refactor(unit): add type annotations

This commit is contained in:
Lewis Russell
2023-04-03 12:01:23 +01:00
parent 5465adcbab
commit 3d29424fb9
5 changed files with 313 additions and 248 deletions

View File

@@ -20,12 +20,16 @@ local module = {
REMOVE_THIS = {}, REMOVE_THIS = {},
} }
--- @param p string
--- @return string
local function relpath(p) local function relpath(p)
p = vim.fs.normalize(p) p = vim.fs.normalize(p)
local cwd = luv.cwd() local cwd = luv.cwd()
return p:gsub("^" .. cwd) return p:gsub("^" .. cwd)
end end
--- @param path string
--- @return boolean
function module.isdir(path) function module.isdir(path)
if not path then if not path then
return false return false
@@ -37,6 +41,8 @@ function module.isdir(path)
return stat.type == 'directory' return stat.type == 'directory'
end end
--- @param path string
--- @return boolean
function module.isfile(path) function module.isfile(path)
if not path then if not path then
return false return false
@@ -48,6 +54,7 @@ function module.isfile(path)
return stat.type == 'file' return stat.type == 'file'
end end
--- @return string
function module.argss_to_cmd(...) function module.argss_to_cmd(...)
local cmd = '' local cmd = ''
for i = 1, select('#', ...) do for i = 1, select('#', ...) do
@@ -462,6 +469,7 @@ function module.check_cores(app, force) -- luacheck: ignore
end end
end end
--- @return string?
function module.repeated_read_cmd(...) function module.repeated_read_cmd(...)
for _ = 1, 10 do for _ = 1, 10 do
local stream = module.popen_r(...) local stream = module.popen_r(...)
@@ -561,6 +569,9 @@ function module.concat_tables(...)
return ret return ret
end end
--- @param str string
--- @param leave_indent? boolean
--- @return string
function module.dedent(str, leave_indent) function module.dedent(str, leave_indent)
-- find minimum common indent across lines -- find minimum common indent across lines
local indent = nil local indent = nil

View File

@@ -154,6 +154,8 @@ local C_keywords = set { -- luacheck: ignore
-- --
-- The first one will have a lot of false positives (the line '{' for -- The first one will have a lot of false positives (the line '{' for
-- example), the second one is more unique. -- example), the second one is more unique.
--- @param string
--- @return string
local function formatc(str) local function formatc(str)
local toks = TokeniseC(str) local toks = TokeniseC(str)
local result = {} local result = {}

View File

@@ -14,20 +14,15 @@ local map = global_helpers.tbl_map
local eq = global_helpers.eq local eq = global_helpers.eq
local trim = global_helpers.trim local trim = global_helpers.trim
-- C constants.
local NULL = ffi.cast('void*', 0)
local OK = 1
local FAIL = 0
local cimport
-- add some standard header locations -- add some standard header locations
for _, p in ipairs(Paths.include_paths) do for _, p in ipairs(Paths.include_paths) do
Preprocess.add_to_include_path(p) Preprocess.add_to_include_path(p)
end end
local child_pid = nil local child_pid = nil --- @type integer
--- @generic F: function
--- @param func F
--- @return F
local function only_separate(func) local function only_separate(func)
return function(...) return function(...)
if child_pid ~= 0 then if child_pid ~= 0 then
@@ -36,9 +31,20 @@ local function only_separate(func)
return func(...) return func(...)
end end
end end
local child_calls_init = {}
local child_calls_mod = nil --- @class ChildCall
local child_calls_mod_once = nil --- @field func function
--- @field args any[]
--- @class ChildCallLog
--- @field func string
--- @field args any[]
--- @field ret any?
local child_calls_init = {} --- @type ChildCall[]
local child_calls_mod = nil --- @type ChildCall[]
local child_calls_mod_once = nil --- @type ChildCall[]?
local function child_call(func, ret) local function child_call(func, ret)
return function(...) return function(...)
local child_calls = child_calls_mod or child_calls_init local child_calls = child_calls_mod or child_calls_init
@@ -53,16 +59,16 @@ end
-- Run some code at the start of the child process, before running the test -- Run some code at the start of the child process, before running the test
-- itself. Is supposed to be run in `before_each`. -- itself. Is supposed to be run in `before_each`.
--- @param func function
local function child_call_once(func, ...) local function child_call_once(func, ...)
if child_pid ~= 0 then if child_pid ~= 0 then
child_calls_mod_once[#child_calls_mod_once + 1] = { child_calls_mod_once[#child_calls_mod_once + 1] = { func = func, args = {...} }
func=func, args={...}}
else else
func(...) func(...)
end end
end end
local child_cleanups_mod_once = nil local child_cleanups_mod_once = nil --- @type ChildCall[]?
-- Run some code at the end of the child process, before exiting. Is supposed to -- Run some code at the end of the child process, before exiting. Is supposed to
-- be run in `before_each` because `after_each` is run after child has exited. -- be run in `before_each` because `after_each` is run after child has exited.
@@ -125,8 +131,9 @@ local pragma_pack_id = 1
-- some things are just too complex for the LuaJIT C parser to digest. We -- some things are just too complex for the LuaJIT C parser to digest. We
-- usually don't need them anyway. -- usually don't need them anyway.
--- @param body string
local function filter_complex_blocks(body) local function filter_complex_blocks(body)
local result = {} local result = {} --- @type string[]
for line in body:gmatch("[^\r\n]+") do for line in body:gmatch("[^\r\n]+") do
if not (string.find(line, "(^)", 1, true) ~= nil if not (string.find(line, "(^)", 1, true) ~= nil
@@ -153,18 +160,20 @@ typedef struct { char bytes[16]; } __attribute__((aligned(16))) __uint128_t;
typedef struct { char bytes[16]; } __attribute__((aligned(16))) __float128; typedef struct { char bytes[16]; } __attribute__((aligned(16))) __float128;
]] ]]
local preprocess_cache_init = {} local preprocess_cache_init = {} --- @type table<string,string>
local previous_defines_mod = '' local previous_defines_mod = ''
local preprocess_cache_mod = nil local preprocess_cache_mod = nil --- @type table<string,string>
local function is_child_cdefs() local function is_child_cdefs()
return (os.getenv('NVIM_TEST_MAIN_CDEFS') ~= '1') return os.getenv('NVIM_TEST_MAIN_CDEFS') ~= '1'
end end
-- use this helper to import C files, you can pass multiple paths at once, -- use this helper to import C files, you can pass multiple paths at once,
-- this helper will return the C namespace of the nvim library. -- this helper will return the C namespace of the nvim library.
cimport = function(...) local function cimport(...)
local previous_defines, preprocess_cache, cdefs local previous_defines --- @type string
local preprocess_cache --- @type table<string,string>
local cdefs
if is_child_cdefs() and preprocess_cache_mod then if is_child_cdefs() and preprocess_cache_mod then
preprocess_cache = preprocess_cache_mod preprocess_cache = preprocess_cache_mod
previous_defines = previous_defines_mod previous_defines = previous_defines_mod
@@ -180,7 +189,7 @@ cimport = function(...)
path = './' .. path path = './' .. path
end end
if not preprocess_cache[path] then if not preprocess_cache[path] then
local body local body --- @type string
body, previous_defines = Preprocess.preprocess(previous_defines, path) body, previous_defines = Preprocess.preprocess(previous_defines, path)
-- format it (so that the lines are "unique" statements), also filter out -- format it (so that the lines are "unique" statements), also filter out
-- Objective-C blocks -- Objective-C blocks
@@ -202,6 +211,7 @@ cimport = function(...)
-- (they are needed in the right order with the struct definitions, -- (they are needed in the right order with the struct definitions,
-- otherwise luajit has wrong memory layouts for the sturcts) -- otherwise luajit has wrong memory layouts for the sturcts)
if line:match("#pragma%s+pack") then if line:match("#pragma%s+pack") then
--- @type string
line = line .. " // " .. pragma_pack_id line = line .. " // " .. pragma_pack_id
pragma_pack_id = pragma_pack_id + 1 pragma_pack_id = pragma_pack_id + 1
end end
@@ -229,20 +239,21 @@ cimport = function(...)
return lib return lib
end end
local cimport_immediate = function(...) local function cimport_immediate(...)
local saved_pid = child_pid local saved_pid = child_pid
child_pid = 0 child_pid = 0
local err, emsg = pcall(cimport, ...) local err, emsg = pcall(cimport, ...)
child_pid = saved_pid child_pid = saved_pid
if not err then if not err then
emsg = tostring(emsg) io.stderr:write(tostring(emsg) .. '\n')
io.stderr:write(emsg .. '\n')
assert(false) assert(false)
else else
return lib return lib
end end
end end
--- @param preprocess_cache table<string,string[]>
--- @param path string
local function _cimportstr(preprocess_cache, path) local function _cimportstr(preprocess_cache, path)
if imported:contains(path) then if imported:contains(path) then
return lib return lib
@@ -265,12 +276,14 @@ end
local function alloc_log_new() local function alloc_log_new()
local log = { local log = {
log={}, log={}, --- @type ChildCallLog[]
lib=cimport('./src/nvim/memory.h'), lib=cimport('./src/nvim/memory.h'), --- @type table<string,function>
original_functions={}, original_functions={}, --- @type table<string,function>
null={['\0:is_null']=true}, null={['\0:is_null']=true},
} }
local allocator_functions = {'malloc', 'free', 'calloc', 'realloc'} local allocator_functions = {'malloc', 'free', 'calloc', 'realloc'}
function log:save_original_functions() function log:save_original_functions()
for _, funcname in ipairs(allocator_functions) do for _, funcname in ipairs(allocator_functions) do
if not self.original_functions[funcname] then if not self.original_functions[funcname] then
@@ -278,13 +291,16 @@ local function alloc_log_new()
end end
end end
end end
log.save_original_functions = child_call(log.save_original_functions) log.save_original_functions = child_call(log.save_original_functions)
function log:set_mocks() function log:set_mocks()
for _, k in ipairs(allocator_functions) do for _, k in ipairs(allocator_functions) do
do do
local kk = k local kk = k
self.lib['mem_' .. k] = function(...) self.lib['mem_' .. k] = function(...)
local log_entry = {func=kk, args={...}} --- @type ChildCallLog
local log_entry = { func = kk, args = {...} }
self.log[#self.log + 1] = log_entry self.log[#self.log + 1] = log_entry
if kk == 'free' then if kk == 'free' then
self.original_functions[kk](...) self.original_functions[kk](...)
@@ -305,17 +321,21 @@ local function alloc_log_new()
end end
end end
end end
log.set_mocks = child_call(log.set_mocks) log.set_mocks = child_call(log.set_mocks)
function log:clear() function log:clear()
self.log = {} self.log = {}
end end
function log:check(exp) function log:check(exp)
eq(exp, self.log) eq(exp, self.log)
self:clear() self:clear()
end end
function log:clear_tmp_allocs(clear_null_frees) function log:clear_tmp_allocs(clear_null_frees)
local toremove = {} local toremove = {} --- @type integer[]
local allocs = {} local allocs = {} --- @type table<string,integer>
for i, v in ipairs(self.log) do for i, v in ipairs(self.log) do
if v.func == 'malloc' or v.func == 'calloc' then if v.func == 'malloc' or v.func == 'calloc' then
allocs[tostring(v.ret)] = i allocs[tostring(v.ret)] = i
@@ -338,26 +358,20 @@ local function alloc_log_new()
table.remove(self.log, toremove[i]) table.remove(self.log, toremove[i])
end end
end end
function log:restore_original_functions()
-- Do nothing: set mocks live in a separate process
return
--[[
[ for k, v in pairs(self.original_functions) do
[ self.lib['mem_' .. k] = v
[ end
]]
end
function log:setup() function log:setup()
log:save_original_functions() log:save_original_functions()
log:set_mocks() log:set_mocks()
end end
function log:before_each() function log:before_each()
return
end end
function log:after_each() function log:after_each()
log:restore_original_functions()
end end
log:setup() log:setup()
return log return log
end end
@@ -374,11 +388,14 @@ local function to_cstr(string)
end end
cimport_immediate('./test/unit/fixtures/posix.h') cimport_immediate('./test/unit/fixtures/posix.h')
local sc = {
fork = function() local sc = {}
function sc.fork()
return tonumber(ffi.C.fork()) return tonumber(ffi.C.fork())
end, end
pipe = function()
function sc.pipe()
local ret = ffi.new('int[2]', {-1, -1}) local ret = ffi.new('int[2]', {-1, -1})
ffi.errno(0) ffi.errno(0)
local res = ffi.C.pipe(ret) local res = ffi.C.pipe(ret)
@@ -389,8 +406,10 @@ local sc = {
end end
assert(ret[0] ~= -1 and ret[1] ~= -1) assert(ret[0] ~= -1 and ret[1] ~= -1)
return ret[0], ret[1] return ret[0], ret[1]
end, end
read = function(rd, len)
--- @return string
function sc.read(rd, len)
local ret = ffi.new('char[?]', len, {0}) local ret = ffi.new('char[?]', len, {0})
local total_bytes_read = 0 local total_bytes_read = 0
ffi.errno(0) ffi.errno(0)
@@ -412,8 +431,9 @@ local sc = {
end end
end end
return ffi.string(ret, total_bytes_read) return ffi.string(ret, total_bytes_read)
end, end
write = function(wr, s)
function sc.write(wr, s)
local wbuf = to_cstr(s) local wbuf = to_cstr(s)
local total_bytes_written = 0 local total_bytes_written = 0
ffi.errno(0) ffi.errno(0)
@@ -435,9 +455,13 @@ local sc = {
end end
end end
return total_bytes_written return total_bytes_written
end, end
close = ffi.C.close,
wait = function(pid) sc.close = ffi.C.close
--- @param pid integer
--- @return integer
function sc.wait(pid)
ffi.errno(0) ffi.errno(0)
local stat_loc = ffi.new('int[1]', {0}) local stat_loc = ffi.new('int[1]', {0})
while true do while true do
@@ -455,17 +479,18 @@ local sc = {
end end
end end
return stat_loc[0] return stat_loc[0]
end, end
exit = ffi.C._exit,
}
sc.exit = ffi.C._exit
--- @param lst string[]
--- @return string
local function format_list(lst) local function format_list(lst)
local ret = '' local ret = {} --- @type string[]
for _, v in ipairs(lst) do for _, v in ipairs(lst) do
if ret ~= '' then ret = ret .. ', ' end ret[#ret+1] = assert:format({v, n=1})[1]
ret = ret .. assert:format({v, n=1})[1]
end end
return ret return table.concat(ret, ', ')
end end
if os.getenv('NVIM_TEST_PRINT_SYSCALLS') == '1' then if os.getenv('NVIM_TEST_PRINT_SYSCALLS') == '1' then
@@ -513,19 +538,26 @@ local tracehelp = dedent([[
]]) ]])
local function child_sethook(wr) local function child_sethook(wr)
local trace_level = os.getenv('NVIM_TEST_TRACE_LEVEL') local trace_level_str = os.getenv('NVIM_TEST_TRACE_LEVEL')
if not trace_level or trace_level == '' then local trace_level = 0
trace_level = 0 if trace_level_str and trace_level_str ~= '' then
else --- @type number
trace_level = tonumber(trace_level) trace_level = assert(tonumber(trace_level_str))
end end
if trace_level <= 0 then if trace_level <= 0 then
return return
end end
local trace_only_c = trace_level <= 1 local trace_only_c = trace_level <= 1
--- @type debuginfo?, string?, integer
local prev_info, prev_reason, prev_lnum local prev_info, prev_reason, prev_lnum
--- @param reason string
--- @param lnum integer
--- @param use_prev boolean
local function hook(reason, lnum, use_prev) local function hook(reason, lnum, use_prev)
local info = nil local info = nil --- @type debuginfo?
if use_prev then if use_prev then
info = prev_info info = prev_info
elseif reason ~= 'tail return' then -- tail return elseif reason ~= 'tail return' then -- tail return
@@ -533,6 +565,7 @@ local function child_sethook(wr)
end end
if trace_only_c and (not info or info.what ~= 'C') and not use_prev then if trace_only_c and (not info or info.what ~= 'C') and not use_prev then
--- @cast info -nil
if info.source:sub(-9) == '_spec.lua' then if info.source:sub(-9) == '_spec.lua' then
prev_info = info prev_info = info
prev_reason = 'saved' prev_reason = 'saved'
@@ -573,12 +606,8 @@ local function child_sethook(wr)
end end
-- assert(-1 <= lnum and lnum <= 99999) -- assert(-1 <= lnum and lnum <= 99999)
local lnum_s local lnum_s = lnum == -1 and 'nknwn' or ('%u'):format(lnum)
if lnum == -1 then --- @type string
lnum_s = 'nknwn'
else
lnum_s = ('%u'):format(lnum)
end
local msg = ( -- lua does not support %* local msg = ( -- lua does not support %*
'' ''
.. msgchar .. msgchar
@@ -600,6 +629,7 @@ end
local trace_end_msg = ('E%s\n'):format((' '):rep(hook_msglen - 2)) local trace_end_msg = ('E%s\n'):format((' '):rep(hook_msglen - 2))
--- @type function
local _debug_log local _debug_log
local debug_log = only_separate(function(...) local debug_log = only_separate(function(...)
@@ -607,6 +637,7 @@ local debug_log = only_separate(function(...)
end) end)
local function itp_child(wr, func) local function itp_child(wr, func)
--- @param s string
_debug_log = function(s) _debug_log = function(s)
s = s:sub(1, hook_msglen - 2) s = s:sub(1, hook_msglen - 2)
sc.write(wr, '>' .. s .. (' '):rep(hook_msglen - 2 - #s) .. '\n') sc.write(wr, '>' .. s .. (' '):rep(hook_msglen - 2 - #s) .. '\n')
@@ -638,7 +669,7 @@ local function itp_child(wr, func)
end end
local function check_child_err(rd) local function check_child_err(rd)
local trace = {} local trace = {} --- @type string[]
local did_traceline = false local did_traceline = false
local maxtrace = tonumber(os.getenv('NVIM_TEST_MAXTRACE')) or 1024 local maxtrace = tonumber(os.getenv('NVIM_TEST_MAXTRACE')) or 1024
while true do while true do
@@ -668,11 +699,14 @@ local function check_child_err(rd)
local len = tonumber(len_s) local len = tonumber(len_s)
neq(0, len) neq(0, len)
if os.getenv('NVIM_TEST_TRACE_ON_ERROR') == '1' and #trace ~= 0 then if os.getenv('NVIM_TEST_TRACE_ON_ERROR') == '1' and #trace ~= 0 then
--- @type string
err = '\nTest failed, trace:\n' .. tracehelp err = '\nTest failed, trace:\n' .. tracehelp
for _, traceline in ipairs(trace) do for _, traceline in ipairs(trace) do
--- @type string
err = err .. traceline err = err .. traceline
end end
end end
--- @type string
err = err .. sc.read(rd, len + 1) err = err .. sc.read(rd, len + 1)
end end
local eres = sc.read(rd, 2) local eres = sc.read(rd, 2)
@@ -686,10 +720,12 @@ local function check_child_err(rd)
end end
end end
if not did_traceline then if not did_traceline then
--- @type string
err = err .. '\nNo end of trace occurred' err = err .. '\nNo end of trace occurred'
end end
local cc_err, cc_emsg = pcall(check_cores, Paths.test_luajit_prg, true) local cc_err, cc_emsg = pcall(check_cores, Paths.test_luajit_prg, true)
if not cc_err then if not cc_err then
--- @type string
err = err .. '\ncheck_cores failed: ' .. cc_emsg err = err .. '\ncheck_cores failed: ' .. cc_emsg
end end
end end
@@ -822,9 +858,9 @@ local module = {
lib = lib, lib = lib,
cstr = cstr, cstr = cstr,
to_cstr = to_cstr, to_cstr = to_cstr,
NULL = NULL, NULL = ffi.cast('void*', 0),
OK = OK, OK = 1,
FAIL = FAIL, FAIL = 0,
alloc_log_new = alloc_log_new, alloc_log_new = alloc_log_new,
gen_itp = gen_itp, gen_itp = gen_itp,
only_separate = only_separate, only_separate = only_separate,

View File

@@ -7,6 +7,9 @@ local global_helpers = require('test.helpers')
local argss_to_cmd = global_helpers.argss_to_cmd local argss_to_cmd = global_helpers.argss_to_cmd
local repeated_read_cmd = global_helpers.repeated_read_cmd local repeated_read_cmd = global_helpers.repeated_read_cmd
--- @alias Compiler {path: string[], type: string}
--- @type Compiler[]
local ccs = {} local ccs = {}
local env_cc = os.getenv("CC") local env_cc = os.getenv("CC")
@@ -27,6 +30,8 @@ table.insert(ccs, {path = {"/usr/bin/env", "clang"}, type = "clang"})
table.insert(ccs, {path = {"/usr/bin/env", "icc"}, type = "gcc"}) table.insert(ccs, {path = {"/usr/bin/env", "icc"}, type = "gcc"})
-- parse Makefile format dependencies into a Lua table -- parse Makefile format dependencies into a Lua table
--- @param deps string
--- @return string[]
local function parse_make_deps(deps) local function parse_make_deps(deps)
-- remove line breaks and line concatenators -- remove line breaks and line concatenators
deps = deps:gsub("\n", ""):gsub("\\", "") deps = deps:gsub("\n", ""):gsub("\\", "")
@@ -36,7 +41,7 @@ local function parse_make_deps(deps)
deps = deps:gsub(" +", " ") deps = deps:gsub(" +", " ")
-- split according to token (space in this case) -- split according to token (space in this case)
local headers = {} local headers = {} --- @type string[]
for token in deps:gmatch("[^%s]+") do for token in deps:gmatch("[^%s]+") do
-- headers[token] = true -- headers[token] = true
headers[#headers + 1] = token headers[#headers + 1] = token
@@ -53,57 +58,58 @@ local function parse_make_deps(deps)
return headers return headers
end end
-- will produce a string that represents a meta C header file that includes --- will produce a string that represents a meta C header file that includes
-- all the passed in headers. I.e.: --- all the passed in headers. I.e.:
-- ---
-- headerize({"stdio.h", "math.h"}, true) --- headerize({"stdio.h", "math.h"}, true)
-- produces: --- produces:
-- #include <stdio.h> --- #include <stdio.h>
-- #include <math.h> --- #include <math.h>
-- ---
-- headerize({"vim.h", "memory.h"}, false) --- headerize({"vim.h", "memory.h"}, false)
-- produces: --- produces:
-- #include "vim.h" --- #include "vim.h"
-- #include "memory.h" --- #include "memory.h"
--- @param headers string[]
--- @param global? boolean
--- @return string
local function headerize(headers, global) local function headerize(headers, global)
local pre = '"' local fmt = global and '#include <%s>' or '#include "%s"'
local post = pre local formatted = {} --- @type string[]
if global then
pre = "<"
post = ">"
end
local formatted = {}
for _, hdr in ipairs(headers) do for _, hdr in ipairs(headers) do
formatted[#formatted + 1] = "#include " .. formatted[#formatted + 1] = string.format(fmt, hdr)
tostring(pre) ..
tostring(hdr) ..
tostring(post)
end end
return table.concat(formatted, "\n") return table.concat(formatted, "\n")
end end
--- @class Gcc
--- @field path string
--- @field preprocessor_extra_flags string[]
--- @field get_defines_extra_flags string[]
--- @field get_declarations_extra_flags string[]
local Gcc = { local Gcc = {
preprocessor_extra_flags = {}, preprocessor_extra_flags = {},
get_defines_extra_flags = {'-std=c99', '-dM', '-E'}, get_defines_extra_flags = {'-std=c99', '-dM', '-E'},
get_declarations_extra_flags = {'-std=c99', '-P', '-E'}, get_declarations_extra_flags = {'-std=c99', '-P', '-E'},
} }
--- @param name string
--- @param args string[]?
--- @param val string?
function Gcc:define(name, args, val) function Gcc:define(name, args, val)
local define = '-D' .. name local define = string.format('-D%s', name)
if args ~= nil then if args then
define = define .. '(' .. table.concat(args, ',') .. ')' define = string.format('%s(%s)', define, table.concat(args, ','))
end end
if val ~= nil then if val then
define = define .. '=' .. val define = string.format('%s=%s', define, val)
end end
self.preprocessor_extra_flags[#self.preprocessor_extra_flags + 1] = define self.preprocessor_extra_flags[#self.preprocessor_extra_flags + 1] = define
end end
function Gcc:undefine(name) function Gcc:undefine(name)
self.preprocessor_extra_flags[#self.preprocessor_extra_flags + 1] = ( self.preprocessor_extra_flags[#self.preprocessor_extra_flags + 1] = '-U' .. name
'-U' .. name)
end end
function Gcc:init_defines() function Gcc:init_defines()
@@ -128,6 +134,8 @@ function Gcc:init_defines()
self:undefine('__BLOCKS__') self:undefine('__BLOCKS__')
end end
--- @param obj? Compiler
--- @return Gcc
function Gcc:new(obj) function Gcc:new(obj)
obj = obj or {} obj = obj or {}
setmetatable(obj, self) setmetatable(obj, self)
@@ -136,6 +144,7 @@ function Gcc:new(obj)
return obj return obj
end end
--- @param ... string
function Gcc:add_to_include_path(...) function Gcc:add_to_include_path(...)
for i = 1, select('#', ...) do for i = 1, select('#', ...) do
local path = select(i, ...) local path = select(i, ...)
@@ -145,110 +154,115 @@ function Gcc:add_to_include_path(...)
end end
-- returns a list of the headers files upon which this file relies -- returns a list of the headers files upon which this file relies
--- @param hdr string
--- @return string[]?
function Gcc:dependencies(hdr) function Gcc:dependencies(hdr)
--- @type string
local cmd = argss_to_cmd(self.path, {'-M', hdr}) .. ' 2>&1' local cmd = argss_to_cmd(self.path, {'-M', hdr}) .. ' 2>&1'
local out = io.popen(cmd) local out = assert(io.popen(cmd))
local deps = out:read("*a") local deps = out:read("*a")
out:close() out:close()
if deps then if deps then
return parse_make_deps(deps) return parse_make_deps(deps)
else
return nil
end end
end end
--- @param defines string
--- @return string
function Gcc:filter_standard_defines(defines) function Gcc:filter_standard_defines(defines)
if not self.standard_defines then if not self.standard_defines then
local pseudoheader_fname = 'tmp_empty_pseudoheader.h' local pseudoheader_fname = 'tmp_empty_pseudoheader.h'
local pseudoheader_file = io.open(pseudoheader_fname, 'w') local pseudoheader_file = assert(io.open(pseudoheader_fname, 'w'))
pseudoheader_file:close() pseudoheader_file:close()
local standard_defines = repeated_read_cmd(self.path, local standard_defines = assert(repeated_read_cmd(self.path,
self.preprocessor_extra_flags, self.preprocessor_extra_flags,
self.get_defines_extra_flags, self.get_defines_extra_flags,
{pseudoheader_fname}) {pseudoheader_fname}))
os.remove(pseudoheader_fname) os.remove(pseudoheader_fname)
self.standard_defines = {} self.standard_defines = {} --- @type table<string,true>
for line in standard_defines:gmatch('[^\n]+') do for line in standard_defines:gmatch('[^\n]+') do
self.standard_defines[line] = true self.standard_defines[line] = true
end end
end end
local ret = {}
local ret = {} --- @type string[]
for line in defines:gmatch('[^\n]+') do for line in defines:gmatch('[^\n]+') do
if not self.standard_defines[line] then if not self.standard_defines[line] then
ret[#ret + 1] = line ret[#ret + 1] = line
end end
end end
return table.concat(ret, "\n") return table.concat(ret, "\n")
end end
-- returns a stream representing a preprocessed form of the passed-in headers. --- returns a stream representing a preprocessed form of the passed-in headers.
-- Don't forget to close the stream by calling the close() method on it. --- Don't forget to close the stream by calling the close() method on it.
--- @param previous_defines string
--- @param ... string
--- @return string, string
function Gcc:preprocess(previous_defines, ...) function Gcc:preprocess(previous_defines, ...)
-- create pseudo-header -- create pseudo-header
local pseudoheader = headerize({...}, false) local pseudoheader = headerize({...}, false)
local pseudoheader_fname = 'tmp_pseudoheader.h' local pseudoheader_fname = 'tmp_pseudoheader.h'
local pseudoheader_file = io.open(pseudoheader_fname, 'w') local pseudoheader_file = assert(io.open(pseudoheader_fname, 'w'))
pseudoheader_file:write(previous_defines) pseudoheader_file:write(previous_defines)
pseudoheader_file:write("\n") pseudoheader_file:write("\n")
pseudoheader_file:write(pseudoheader) pseudoheader_file:write(pseudoheader)
pseudoheader_file:flush() pseudoheader_file:flush()
pseudoheader_file:close() pseudoheader_file:close()
local defines = repeated_read_cmd(self.path, self.preprocessor_extra_flags, local defines = assert(repeated_read_cmd(self.path, self.preprocessor_extra_flags,
self.get_defines_extra_flags, self.get_defines_extra_flags,
{pseudoheader_fname}) {pseudoheader_fname}))
defines = self:filter_standard_defines(defines) defines = self:filter_standard_defines(defines)
local declarations = repeated_read_cmd(self.path, local declarations = assert(repeated_read_cmd(self.path,
self.preprocessor_extra_flags, self.preprocessor_extra_flags,
self.get_declarations_extra_flags, self.get_declarations_extra_flags,
{pseudoheader_fname}) {pseudoheader_fname}))
os.remove(pseudoheader_fname) os.remove(pseudoheader_fname)
assert(declarations and defines)
return declarations, defines return declarations, defines
end end
local Clang = Gcc:new()
local Msvc = Gcc:new()
local type_to_class = {
["gcc"] = Gcc,
["clang"] = Clang,
["msvc"] = Msvc
}
-- find the best cc. If os.exec causes problems on windows (like popping up -- find the best cc. If os.exec causes problems on windows (like popping up
-- a console window) we might consider using something like this: -- a console window) we might consider using something like this:
-- http://scite-ru.googlecode.com/svn/trunk/pack/tools/LuaLib/shell.html#exec -- http://scite-ru.googlecode.com/svn/trunk/pack/tools/LuaLib/shell.html#exec
--- @param compilers Compiler[]
--- @return Gcc?
local function find_best_cc(compilers) local function find_best_cc(compilers)
for _, meta in pairs(compilers) do for _, meta in pairs(compilers) do
local version = io.popen(tostring(meta.path) .. " -v 2>&1") local version = assert(io.popen(tostring(meta.path) .. " -v 2>&1"))
version:close() version:close()
if version then if version then
return type_to_class[meta.type]:new({path = meta.path}) return Gcc:new({path = meta.path})
end end
end end
return nil
end end
-- find the best cc. If os.exec causes problems on windows (like popping up -- find the best cc. If os.exec causes problems on windows (like popping up
-- a console window) we might consider using something like this: -- a console window) we might consider using something like this:
-- http://scite-ru.googlecode.com/svn/trunk/pack/tools/LuaLib/shell.html#exec -- http://scite-ru.googlecode.com/svn/trunk/pack/tools/LuaLib/shell.html#exec
local cc = nil local cc = assert(find_best_cc(ccs))
if cc == nil then
cc = find_best_cc(ccs) local M = {}
--- @param hdr string
--- @return string[]?
function M.includes(hdr)
return cc:dependencies(hdr)
end end
return { --- @param ... string
includes = function(hdr) --- @return string, string
return cc:dependencies(hdr) function M.preprocess(...)
end,
preprocess = function(...)
return cc:preprocess(...) return cc:preprocess(...)
end, end
add_to_include_path = function(...)
--- @param ... string
function M.add_to_include_path(...)
return cc:add_to_include_path(...) return cc:add_to_include_path(...)
end end
}
return M

View File

@@ -4,10 +4,15 @@
-- other: -- other:
-- 1) index => item -- 1) index => item
-- 2) item => index -- 2) item => index
--- @class Set
--- @field nelem integer
--- @field items string[]
--- @field tbl table
local Set = {} local Set = {}
--- @param items? string[]
function Set:new(items) function Set:new(items)
local obj = {} local obj = {} --- @ type Set
setmetatable(obj, self) setmetatable(obj, self)
self.__index = self self.__index = self
@@ -26,8 +31,9 @@ function Set:new(items)
return obj return obj
end end
--- @return Set
function Set:copy() function Set:copy()
local obj = {} local obj = {} --- @ type Set
obj.nelem = self.nelem obj.nelem = self.nelem
obj.tbl = {} obj.tbl = {}
obj.items = {} obj.items = {}
@@ -43,6 +49,7 @@ function Set:copy()
end end
-- adds the argument Set to this Set -- adds the argument Set to this Set
--- @param other Set
function Set:union(other) function Set:union(other)
for e in other:iterator() do for e in other:iterator() do
self:add(e) self:add(e)
@@ -57,6 +64,7 @@ function Set:union_table(t)
end end
-- subtracts the argument Set from this Set -- subtracts the argument Set from this Set
--- @param other Set
function Set:diff(other) function Set:diff(other)
if other:size() > self:size() then if other:size() > self:size() then
-- this set is smaller than the other set -- this set is smaller than the other set
@@ -75,6 +83,7 @@ function Set:diff(other)
end end
end end
--- @param it string
function Set:add(it) function Set:add(it)
if not self:contains(it) then if not self:contains(it) then
local idx = #self.tbl + 1 local idx = #self.tbl + 1
@@ -84,6 +93,7 @@ function Set:add(it)
end end
end end
--- @param it string
function Set:remove(it) function Set:remove(it)
if self:contains(it) then if self:contains(it) then
local idx = self.items[it] local idx = self.items[it]
@@ -93,10 +103,13 @@ function Set:remove(it)
end end
end end
--- @param it string
--- @return boolean
function Set:contains(it) function Set:contains(it)
return self.items[it] or false return self.items[it] or false
end end
--- @return integer
function Set:size() function Set:size()
return self.nelem return self.nelem
end end
@@ -113,29 +126,18 @@ function Set:iterator()
return pairs(self.items) return pairs(self.items)
end end
--- @return string[]
function Set:to_table() function Set:to_table()
-- there might be gaps in @tbl, so we have to be careful and sort first -- there might be gaps in @tbl, so we have to be careful and sort first
local keys local keys = {} --- @type string[]
do
local _accum_0 = { }
local _len_0 = 1
for idx, _ in pairs(self.tbl) do for idx, _ in pairs(self.tbl) do
_accum_0[_len_0] = idx keys[#keys+1] = idx
_len_0 = _len_0 + 1
end
keys = _accum_0
end end
table.sort(keys) table.sort(keys)
local copy local copy = {} --- @type string[]
do for _, idx in ipairs(keys) do
local _accum_0 = { } copy[#copy+1] = self.tbl[idx]
local _len_0 = 1
for _index_0 = 1, #keys do
local idx = keys[_index_0]
_accum_0[_len_0] = self.tbl[idx]
_len_0 = _len_0 + 1
end
copy = _accum_0
end end
return copy return copy
end end