test/helpers: improve pattern with module functions (#10421)

Benefits:

- less lines, especially less results when grepping
- makes it clearer what is exported
This commit is contained in:
Daniel Hahler
2019-07-22 01:13:11 +02:00
committed by GitHub
parent 8d66b6091b
commit 66149ecffe

View File

@@ -15,7 +15,11 @@ local function shell_quote(str)
end
end
local function argss_to_cmd(...)
local module = {
REMOVE_THIS = {},
}
function module.argss_to_cmd(...)
local cmd = ''
for i = 1, select('#', ...) do
local arg = select(i, ...)
@@ -30,16 +34,16 @@ local function argss_to_cmd(...)
return cmd
end
local function popen_r(...)
return io.popen(argss_to_cmd(...), 'r')
function module.popen_r(...)
return io.popen(module.argss_to_cmd(...), 'r')
end
local function popen_w(...)
return io.popen(argss_to_cmd(...), 'w')
function module.popen_w(...)
return io.popen(module.argss_to_cmd(...), 'w')
end
-- sleeps the test runner (_not_ the nvim instance)
local function sleep(ms)
function module.sleep(ms)
luv.sleep(ms)
end
@@ -49,26 +53,26 @@ local check_logs_useless_lines = {
['See README_MISSING_SYSCALL_OR_IOCTL for guidance']=3,
}
local function eq(expected, actual, context)
function module.eq(expected, actual, context)
return assert.are.same(expected, actual, context)
end
local function neq(expected, actual, context)
function module.neq(expected, actual, context)
return assert.are_not.same(expected, actual, context)
end
local function ok(res)
function module.ok(res)
return assert.is_true(res)
end
local function near(actual, expected, tolerance)
function module.near(actual, expected, tolerance)
return assert.is.near(actual, expected, tolerance)
end
local function matches(pat, actual)
function module.matches(pat, actual)
if nil ~= string.match(actual, pat) then
return true
end
error(string.format('Pattern does not match.\nPattern:\n%s\nActual:\n%s', pat, actual))
end
-- Expect an error matching pattern `pat`.
local function expect_err(pat, ...)
function module.expect_err(pat, ...)
local fn = select(1, ...)
local fn_args = {...}
table.remove(fn_args, 1)
@@ -78,7 +82,7 @@ end
-- initial_path: directory to recurse into
-- re: include pattern (string)
-- exc_re: exclude pattern(s) (string or table)
local function glob(initial_path, re, exc_re)
function module.glob(initial_path, re, exc_re)
exc_re = type(exc_re) == 'table' and exc_re or { exc_re }
local paths_to_check = {initial_path}
local ret = {}
@@ -118,7 +122,7 @@ local function glob(initial_path, re, exc_re)
return ret
end
local function check_logs()
function module.check_logs()
local log_dir = os.getenv('LOG_DIR')
local runtime_errors = 0
if log_dir and lfs.attributes(log_dir, 'mode') == 'directory' then
@@ -153,7 +157,7 @@ local function check_logs()
end
-- Tries to get platform name from $SYSTEM_NAME, uname; fallback is "Windows".
local uname = (function()
module.uname = (function()
local platform = nil
return (function()
if platform then
@@ -165,7 +169,7 @@ local uname = (function()
return platform
end
local status, f = pcall(popen_r, 'uname', '-s')
local status, f = pcall(module.popen_r, 'uname', '-s')
if status then
platform = f:read("*l")
f:close()
@@ -185,7 +189,7 @@ local function tmpdir_is_local(dir)
return not not (dir and string.find(dir, 'Xtest'))
end
local tmpname = (function()
module.tmpname = (function()
local seq = 0
local tmpdir = tmpdir_get()
return (function()
@@ -197,11 +201,11 @@ local tmpname = (function()
return fname
else
local fname = os.tmpname()
if uname() == 'Windows' and fname:sub(1, 2) == '\\s' then
if module.uname() == 'Windows' and fname:sub(1, 2) == '\\s' then
-- In Windows tmpname() returns a filename starting with
-- special sequence \s, prepend $TEMP path
return tmpdir..fname
elseif fname:match('^/tmp') and uname() == 'Darwin' then
elseif fname:match('^/tmp') and module.uname() == 'Darwin' then
-- In OS X /tmp links to /private/tmp
return '/private'..fname
else
@@ -211,7 +215,7 @@ local tmpname = (function()
end)
end)()
local function map(func, tab)
function module.map(func, tab)
local rettab = {}
for k, v in pairs(tab) do
rettab[k] = func(v)
@@ -219,7 +223,7 @@ local function map(func, tab)
return rettab
end
local function filter(filter_func, tab)
function module.filter(filter_func, tab)
local rettab = {}
for _, entry in pairs(tab) do
if filter_func(entry) then
@@ -229,7 +233,7 @@ local function filter(filter_func, tab)
return rettab
end
local function hasenv(name)
function module.hasenv(name)
local env = os.getenv(name)
if env and env ~= '' then
return env
@@ -244,7 +248,7 @@ end
local tests_skipped = 0
local function check_cores(app, force)
function module.check_cores(app, force)
app = app or 'build/bin/nvim'
local initial_path, re, exc_re
local gdb_db_cmd = 'gdb -n -batch -ex "thread apply all bt full" "$_NVIM_TEST_APP" -c "$_NVIM_TEST_CORE"'
@@ -256,7 +260,7 @@ local function check_cores(app, force)
and relpath(tmpdir_get()):gsub('^[ ./]+',''):gsub('%/+$',''):gsub('([^%w])', '%%%1')
or nil)
local db_cmd
if hasenv('NVIM_TEST_CORE_GLOB_DIRECTORY') then
if module.hasenv('NVIM_TEST_CORE_GLOB_DIRECTORY') then
initial_path = os.getenv('NVIM_TEST_CORE_GLOB_DIRECTORY')
re = os.getenv('NVIM_TEST_CORE_GLOB_RE')
exc_re = { os.getenv('NVIM_TEST_CORE_EXC_RE'), local_tmpdir }
@@ -279,7 +283,7 @@ local function check_cores(app, force)
tests_skipped = tests_skipped + 1
return
end
local cores = glob(initial_path, re, exc_re)
local cores = module.glob(initial_path, re, exc_re)
local found_cores = 0
local out = io.stdout
for _, core in ipairs(cores) do
@@ -301,8 +305,8 @@ local function check_cores(app, force)
end
end
local function which(exe)
local pipe = popen_r('which', exe)
function module.which(exe)
local pipe = module.popen_r('which', exe)
local ret = pipe:read('*a')
pipe:close()
if ret == '' then
@@ -312,20 +316,20 @@ local function which(exe)
end
end
local function repeated_read_cmd(...)
function module.repeated_read_cmd(...)
for _ = 1, 10 do
local stream = popen_r(...)
local stream = module.popen_r(...)
local ret = stream:read('*a')
stream:close()
if ret then
return ret
end
end
print('ERROR: Failed to execute ' .. argss_to_cmd(...) .. ': nil return after 10 attempts')
print('ERROR: Failed to execute ' .. module.argss_to_cmd(...) .. ': nil return after 10 attempts')
return nil
end
local function shallowcopy(orig)
function module.shallowcopy(orig)
if type(orig) ~= 'table' then
return orig
end
@@ -336,15 +340,13 @@ local function shallowcopy(orig)
return copy
end
local REMOVE_THIS = {}
local function mergedicts_copy(d1, d2)
local ret = shallowcopy(d1)
function module.mergedicts_copy(d1, d2)
local ret = module.shallowcopy(d1)
for k, v in pairs(d2) do
if d2[k] == REMOVE_THIS then
if d2[k] == module.REMOVE_THIS then
ret[k] = nil
elseif type(d1[k]) == 'table' and type(v) == 'table' then
ret[k] = mergedicts_copy(d1[k], v)
ret[k] = module.mergedicts_copy(d1[k], v)
else
ret[k] = v
end
@@ -355,16 +357,16 @@ end
-- dictdiff: find a diff so that mergedicts_copy(d1, diff) is equal to d2
--
-- Note: does not do copies of d2 values used.
local function dictdiff(d1, d2)
function module.dictdiff(d1, d2)
local ret = {}
local hasdiff = false
for k, v in pairs(d1) do
if d2[k] == nil then
hasdiff = true
ret[k] = REMOVE_THIS
ret[k] = module.REMOVE_THIS
elseif type(v) == type(d2[k]) then
if type(v) == 'table' then
local subdiff = dictdiff(v, d2[k])
local subdiff = module.dictdiff(v, d2[k])
if subdiff ~= nil then
hasdiff = true
ret[k] = subdiff
@@ -378,6 +380,7 @@ local function dictdiff(d1, d2)
hasdiff = true
end
end
local shallowcopy = module.shallowcopy
for k, v in pairs(d2) do
if d1[k] == nil then
ret[k] = shallowcopy(v)
@@ -391,7 +394,7 @@ local function dictdiff(d1, d2)
end
end
local function updated(d, d2)
function module.updated(d, d2)
for k, v in pairs(d2) do
d[k] = v
end
@@ -399,7 +402,7 @@ local function updated(d, d2)
end
-- Concat list-like tables.
local function concat_tables(...)
function module.concat_tables(...)
local ret = {}
for i = 1, select('#', ...) do
local tbl = select(i, ...)
@@ -412,7 +415,7 @@ local function concat_tables(...)
return ret
end
local function dedent(str, leave_indent)
function module.dedent(str, leave_indent)
-- find minimum common indent across lines
local indent = nil
for line in str:gmatch('[^\n]+') do
@@ -452,9 +455,7 @@ local SUBTBL = {
'\\030', '\\031',
}
local format_luav
format_luav = function(v, indent, opts)
function module.format_luav(v, indent, opts)
opts = opts or {}
local linesep = '\n'
local next_indent_arg = nil
@@ -484,12 +485,13 @@ format_luav = function(v, indent, opts)
end) .. quote
end
elseif type(v) == 'table' then
if v == REMOVE_THIS then
if v == module.REMOVE_THIS then
ret = 'REMOVE_THIS'
else
local processed_keys = {}
ret = '{' .. linesep
local non_empty = false
local format_luav = module.format_luav
for i, subv in ipairs(v) do
ret = ('%s%s%s,%s'):format(ret, next_indent,
format_luav(subv, next_indent_arg, opts), nl)
@@ -531,7 +533,7 @@ format_luav = function(v, indent, opts)
return ret
end
local function format_string(fmt, ...)
function module.format_string(fmt, ...)
local i = 0
local args = {...}
local function getarg()
@@ -552,7 +554,7 @@ local function format_string(fmt, ...)
-- Builtin %q is replaced here as it gives invalid and inconsistent with
-- luajit results for e.g. "\e" on lua: luajit transforms that into `\27`,
-- lua leaves as-is.
arg = format_luav(arg, nil, {dquote_strings = (subfmt:sub(-1) == 'q')})
arg = module.format_luav(arg, nil, {dquote_strings = (subfmt:sub(-1) == 'q')})
subfmt = subfmt:sub(1, -2) .. 's'
end
if subfmt == '%e' then
@@ -564,7 +566,7 @@ local function format_string(fmt, ...)
return ret
end
local function intchar2lua(ch)
function module.intchar2lua(ch)
ch = tonumber(ch)
return (20 <= ch and ch < 127) and ('%c'):format(ch) or ch
end
@@ -575,20 +577,21 @@ local fixtbl_metatable = {
end,
}
local function fixtbl(tbl)
function module.fixtbl(tbl)
return setmetatable(tbl, fixtbl_metatable)
end
local function fixtbl_rec(tbl)
function module.fixtbl_rec(tbl)
local fixtbl_rec = module.fixtbl_rec
for _, v in pairs(tbl) do
if type(v) == 'table' then
fixtbl_rec(v)
end
end
return fixtbl(tbl)
return module.fixtbl(tbl)
end
local function hexdump(str)
function module.hexdump(str)
local len = string.len(str)
local dump = ""
local hex = ""
@@ -617,7 +620,7 @@ end
--
-- filename: path to file
-- start: start line (1-indexed), negative means "lines before end" (tail)
local function read_file_list(filename, start)
function module.read_file_list(filename, start)
local lnum = (start ~= nil and type(start) == 'number') and start or 1
local tail = (lnum < 0)
local maxlines = tail and math.abs(lnum) or nil
@@ -643,7 +646,7 @@ end
-- Reads the entire contents of `filename` into a string.
--
-- filename: path to file
local function read_file(filename)
function module.read_file(filename)
local file = io.open(filename, 'r')
if not file then
return nil
@@ -654,7 +657,7 @@ local function read_file(filename)
end
-- Dedent the given text and write it to the file name.
local function write_file(name, text, no_dedent, append)
function module.write_file(name, text, no_dedent, append)
local file = io.open(name, (append and 'a' or 'w'))
if type(text) == 'table' then
-- Byte blob
@@ -664,14 +667,14 @@ local function write_file(name, text, no_dedent, append)
text = ('%s%c'):format(text, char)
end
elseif not no_dedent then
text = dedent(text)
text = module.dedent(text)
end
file:write(text)
file:flush()
file:close()
end
local function isCI()
function module.isCI()
local is_travis = nil ~= os.getenv('TRAVIS')
local is_appveyor = nil ~= os.getenv('APPVEYOR')
local is_quickbuild = nil ~= lfs.attributes('/usr/home/quickbuild')
@@ -680,10 +683,11 @@ end
-- Gets the contents of $NVIM_LOG_FILE for printing to the build log.
-- Also removes the file, if the current environment looks like CI.
local function read_nvim_log()
function module.read_nvim_log()
local logfile = os.getenv('NVIM_LOG_FILE') or '.nvimlog'
local keep = isCI() and 999 or 10
local lines = read_file_list(logfile, -keep) or {}
local is_ci = module.isCI()
local keep = is_ci and 999 or 10
local lines = module.read_file_list(logfile, -keep) or {}
local log = (('-'):rep(78)..'\n'
..string.format('$NVIM_LOG_FILE: %s\n', logfile)
..(#lines > 0 and '(last '..tostring(keep)..' lines)\n' or '(empty)\n'))
@@ -691,52 +695,12 @@ local function read_nvim_log()
log = log..line..'\n'
end
log = log..('-'):rep(78)..'\n'
if isCI() then
if is_ci then
os.remove(logfile)
end
return log
end
local module = {
REMOVE_THIS = REMOVE_THIS,
argss_to_cmd = argss_to_cmd,
check_cores = check_cores,
check_logs = check_logs,
concat_tables = concat_tables,
dedent = dedent,
dictdiff = dictdiff,
eq = eq,
expect_err = expect_err,
filter = filter,
fixtbl = fixtbl,
fixtbl_rec = fixtbl_rec,
format_luav = format_luav,
format_string = format_string,
glob = glob,
hasenv = hasenv,
hexdump = hexdump,
intchar2lua = intchar2lua,
isCI = isCI,
map = map,
matches = matches,
mergedicts_copy = mergedicts_copy,
near = near,
neq = neq,
ok = ok,
popen_r = popen_r,
popen_w = popen_w,
read_file = read_file,
read_file_list = read_file_list,
read_nvim_log = read_nvim_log,
repeated_read_cmd = repeated_read_cmd,
shallowcopy = shallowcopy,
sleep = sleep,
tmpname = tmpname,
uname = uname,
updated = updated,
which = which,
write_file = write_file,
}
module = shared.tbl_extend('error', module, Paths, shared)
return module