feat(api, lua): support converting nested Funcref back to LuaRef (#17749)

This commit is contained in:
zeertzjq
2022-03-18 03:21:47 +08:00
committed by GitHub
parent 09a3b33d36
commit cac90d2de7
3 changed files with 95 additions and 36 deletions

View File

@@ -64,7 +64,14 @@ typedef struct {
#define TYPVAL_ENCODE_CONV_FUNC_START(tv, fun) \ #define TYPVAL_ENCODE_CONV_FUNC_START(tv, fun) \
do { \ do { \
ufunc_T *fp = find_func(fun); \
assert(fp != NULL); \
if (fp->uf_cb == nlua_CFunction_func_call) { \
LuaRef ref = api_new_luaref(((LuaCFunctionState *)fp->uf_cb_state)->lua_callable.func_ref); \
kvi_push(edata->stack, LUAREF_OBJ(ref)); \
} else { \
TYPVAL_ENCODE_CONV_NIL(tv); \ TYPVAL_ENCODE_CONV_NIL(tv); \
} \
goto typval_encode_stop_converting_one_item; \ goto typval_encode_stop_converting_one_item; \
} while (0) } while (0)
@@ -231,14 +238,6 @@ static inline void typval_encode_dict_end(EncodedData *const edata)
/// @return The converted value /// @return The converted value
Object vim_to_object(typval_T *obj) Object vim_to_object(typval_T *obj)
{ {
if (obj->v_type == VAR_FUNC) {
ufunc_T *fp = find_func(obj->vval.v_string);
assert(fp != NULL);
if (fp->uf_cb == nlua_CFunction_func_call) {
LuaRef ref = api_new_luaref(((LuaCFunctionState *)fp->uf_cb_state)->lua_callable.func_ref);
return LUAREF_OBJ(ref);
}
}
EncodedData edata; EncodedData edata;
kvi_init(edata.stack); kvi_init(edata.stack);
const int evo_ret = encode_vim_to_object(&edata, obj, const int evo_ret = encode_vim_to_object(&edata, obj,

View File

@@ -476,7 +476,13 @@ static bool typval_conv_special = false;
#define TYPVAL_ENCODE_CONV_FUNC_START(tv, fun) \ #define TYPVAL_ENCODE_CONV_FUNC_START(tv, fun) \
do { \ do { \
ufunc_T *fp = find_func(fun); \
assert(fp != NULL); \
if (fp->uf_cb == nlua_CFunction_func_call) { \
nlua_pushref(lstate, ((LuaCFunctionState *)fp->uf_cb_state)->lua_callable.func_ref); \
} else { \
TYPVAL_ENCODE_CONV_NIL(tv); \ TYPVAL_ENCODE_CONV_NIL(tv); \
} \
goto typval_encode_stop_converting_one_item; \ goto typval_encode_stop_converting_one_item; \
} while (0) } while (0)
@@ -615,14 +621,6 @@ bool nlua_push_typval(lua_State *lstate, typval_T *const tv, bool special)
semsg(_("E1502: Lua failed to grow stack to %i"), initial_size + 4); semsg(_("E1502: Lua failed to grow stack to %i"), initial_size + 4);
return false; return false;
} }
if (tv->v_type == VAR_FUNC) {
ufunc_T *fp = find_func(tv->vval.v_string);
assert(fp != NULL);
if (fp->uf_cb == nlua_CFunction_func_call) {
nlua_pushref(lstate, ((LuaCFunctionState *)fp->uf_cb_state)->lua_callable.func_ref);
return true;
}
}
if (encode_vim_to_lua(lstate, tv, "nlua_push_typval argument") == FAIL) { if (encode_vim_to_lua(lstate, tv, "nlua_push_typval argument") == FAIL) {
return false; return false;
} }

View File

@@ -14,7 +14,7 @@ local feed = helpers.feed
local pcall_err = helpers.pcall_err local pcall_err = helpers.pcall_err
local exec_lua = helpers.exec_lua local exec_lua = helpers.exec_lua
local matches = helpers.matches local matches = helpers.matches
local source = helpers.source local exec = helpers.exec
local NIL = helpers.NIL local NIL = helpers.NIL
local retry = helpers.retry local retry = helpers.retry
local next_msg = helpers.next_msg local next_msg = helpers.next_msg
@@ -743,7 +743,7 @@ describe('lua stdlib', function()
-- compat: nvim_call_function uses "special" value for vimL float -- compat: nvim_call_function uses "special" value for vimL float
eq(false, exec_lua([[return vim.api.nvim_call_function('sin', {0.0}) == 0.0 ]])) eq(false, exec_lua([[return vim.api.nvim_call_function('sin', {0.0}) == 0.0 ]]))
source([[ exec([[
func! FooFunc(test) func! FooFunc(test)
let g:test = a:test let g:test = a:test
return {} return {}
@@ -771,6 +771,12 @@ describe('lua stdlib', function()
-- error handling -- error handling
eq({false, 'Vim:E897: List or Blob required'}, exec_lua([[return {pcall(vim.fn.add, "aa", "bb")}]])) eq({false, 'Vim:E897: List or Blob required'}, exec_lua([[return {pcall(vim.fn.add, "aa", "bb")}]]))
-- conversion between LuaRef and Vim Funcref
eq(true, exec_lua([[
local x = vim.fn.VarArg(function() return 'foo' end, function() return 'bar' end)
return #x == 2 and x[1]() == 'foo' and x[2]() == 'bar'
]]))
end) end)
it('vim.fn should error when calling API function', function() it('vim.fn should error when calling API function', function()
@@ -993,8 +999,11 @@ describe('lua stdlib', function()
exec_lua [[ exec_lua [[
local counter = 0 local counter = 0
vim.g.AddCounter = function() counter = counter + 1 end local function add_counter() counter = counter + 1 end
vim.g.GetCounter = function() return counter end local function get_counter() return counter end
vim.g.AddCounter = add_counter
vim.g.GetCounter = get_counter
vim.g.funcs = {add = add_counter, get = get_counter}
]] ]]
eq(0, eval('g:GetCounter()')) eq(0, eval('g:GetCounter()'))
@@ -1006,11 +1015,18 @@ describe('lua stdlib', function()
eq(3, exec_lua([[return vim.g.GetCounter()]])) eq(3, exec_lua([[return vim.g.GetCounter()]]))
exec_lua([[vim.api.nvim_get_var('AddCounter')()]]) exec_lua([[vim.api.nvim_get_var('AddCounter')()]])
eq(4, exec_lua([[return vim.api.nvim_get_var('GetCounter')()]])) eq(4, exec_lua([[return vim.api.nvim_get_var('GetCounter')()]]))
exec_lua([[vim.g.funcs.add()]])
eq(5, exec_lua([[return vim.g.funcs.get()]]))
exec_lua([[vim.api.nvim_get_var('funcs').add()]])
eq(6, exec_lua([[return vim.api.nvim_get_var('funcs').get()]]))
exec_lua [[ exec_lua [[
local counter = 0 local counter = 0
vim.api.nvim_set_var('AddCounter', function() counter = counter + 1 end) local function add_counter() counter = counter + 1 end
vim.api.nvim_set_var('GetCounter', function() return counter end) local function get_counter() return counter end
vim.api.nvim_set_var('AddCounter', add_counter)
vim.api.nvim_set_var('GetCounter', get_counter)
vim.api.nvim_set_var('funcs', {add = add_counter, get = get_counter})
]] ]]
eq(0, eval('g:GetCounter()')) eq(0, eval('g:GetCounter()'))
@@ -1022,6 +1038,10 @@ describe('lua stdlib', function()
eq(3, exec_lua([[return vim.g.GetCounter()]])) eq(3, exec_lua([[return vim.g.GetCounter()]]))
exec_lua([[vim.api.nvim_get_var('AddCounter')()]]) exec_lua([[vim.api.nvim_get_var('AddCounter')()]])
eq(4, exec_lua([[return vim.api.nvim_get_var('GetCounter')()]])) eq(4, exec_lua([[return vim.api.nvim_get_var('GetCounter')()]]))
exec_lua([[vim.g.funcs.add()]])
eq(5, exec_lua([[return vim.g.funcs.get()]]))
exec_lua([[vim.api.nvim_get_var('funcs').add()]])
eq(6, exec_lua([[return vim.api.nvim_get_var('funcs').get()]]))
-- Check if autoload works properly -- Check if autoload works properly
local pathsep = helpers.get_pathsep() local pathsep = helpers.get_pathsep()
@@ -1072,8 +1092,11 @@ describe('lua stdlib', function()
exec_lua [[ exec_lua [[
local counter = 0 local counter = 0
vim.b.AddCounter = function() counter = counter + 1 end local function add_counter() counter = counter + 1 end
vim.b.GetCounter = function() return counter end local function get_counter() return counter end
vim.b.AddCounter = add_counter
vim.b.GetCounter = get_counter
vim.b.funcs = {add = add_counter, get = get_counter}
]] ]]
eq(0, eval('b:GetCounter()')) eq(0, eval('b:GetCounter()'))
@@ -1085,11 +1108,18 @@ describe('lua stdlib', function()
eq(3, exec_lua([[return vim.b.GetCounter()]])) eq(3, exec_lua([[return vim.b.GetCounter()]]))
exec_lua([[vim.api.nvim_buf_get_var(0, 'AddCounter')()]]) exec_lua([[vim.api.nvim_buf_get_var(0, 'AddCounter')()]])
eq(4, exec_lua([[return vim.api.nvim_buf_get_var(0, 'GetCounter')()]])) eq(4, exec_lua([[return vim.api.nvim_buf_get_var(0, 'GetCounter')()]]))
exec_lua([[vim.b.funcs.add()]])
eq(5, exec_lua([[return vim.b.funcs.get()]]))
exec_lua([[vim.api.nvim_buf_get_var(0, 'funcs').add()]])
eq(6, exec_lua([[return vim.api.nvim_buf_get_var(0, 'funcs').get()]]))
exec_lua [[ exec_lua [[
local counter = 0 local counter = 0
vim.api.nvim_buf_set_var(0, 'AddCounter', function() counter = counter + 1 end) local function add_counter() counter = counter + 1 end
vim.api.nvim_buf_set_var(0, 'GetCounter', function() return counter end) local function get_counter() return counter end
vim.api.nvim_buf_set_var(0, 'AddCounter', add_counter)
vim.api.nvim_buf_set_var(0, 'GetCounter', get_counter)
vim.api.nvim_buf_set_var(0, 'funcs', {add = add_counter, get = get_counter})
]] ]]
eq(0, eval('b:GetCounter()')) eq(0, eval('b:GetCounter()'))
@@ -1101,6 +1131,10 @@ describe('lua stdlib', function()
eq(3, exec_lua([[return vim.b.GetCounter()]])) eq(3, exec_lua([[return vim.b.GetCounter()]]))
exec_lua([[vim.api.nvim_buf_get_var(0, 'AddCounter')()]]) exec_lua([[vim.api.nvim_buf_get_var(0, 'AddCounter')()]])
eq(4, exec_lua([[return vim.api.nvim_buf_get_var(0, 'GetCounter')()]])) eq(4, exec_lua([[return vim.api.nvim_buf_get_var(0, 'GetCounter')()]]))
exec_lua([[vim.b.funcs.add()]])
eq(5, exec_lua([[return vim.b.funcs.get()]]))
exec_lua([[vim.api.nvim_buf_get_var(0, 'funcs').add()]])
eq(6, exec_lua([[return vim.api.nvim_buf_get_var(0, 'funcs').get()]]))
exec_lua [[ exec_lua [[
vim.cmd "vnew" vim.cmd "vnew"
@@ -1141,8 +1175,11 @@ describe('lua stdlib', function()
exec_lua [[ exec_lua [[
local counter = 0 local counter = 0
vim.w.AddCounter = function() counter = counter + 1 end local function add_counter() counter = counter + 1 end
vim.w.GetCounter = function() return counter end local function get_counter() return counter end
vim.w.AddCounter = add_counter
vim.w.GetCounter = get_counter
vim.w.funcs = {add = add_counter, get = get_counter}
]] ]]
eq(0, eval('w:GetCounter()')) eq(0, eval('w:GetCounter()'))
@@ -1154,11 +1191,18 @@ describe('lua stdlib', function()
eq(3, exec_lua([[return vim.w.GetCounter()]])) eq(3, exec_lua([[return vim.w.GetCounter()]]))
exec_lua([[vim.api.nvim_win_get_var(0, 'AddCounter')()]]) exec_lua([[vim.api.nvim_win_get_var(0, 'AddCounter')()]])
eq(4, exec_lua([[return vim.api.nvim_win_get_var(0, 'GetCounter')()]])) eq(4, exec_lua([[return vim.api.nvim_win_get_var(0, 'GetCounter')()]]))
exec_lua([[vim.w.funcs.add()]])
eq(5, exec_lua([[return vim.w.funcs.get()]]))
exec_lua([[vim.api.nvim_win_get_var(0, 'funcs').add()]])
eq(6, exec_lua([[return vim.api.nvim_win_get_var(0, 'funcs').get()]]))
exec_lua [[ exec_lua [[
local counter = 0 local counter = 0
vim.api.nvim_win_set_var(0, 'AddCounter', function() counter = counter + 1 end) local function add_counter() counter = counter + 1 end
vim.api.nvim_win_set_var(0, 'GetCounter', function() return counter end) local function get_counter() return counter end
vim.api.nvim_win_set_var(0, 'AddCounter', add_counter)
vim.api.nvim_win_set_var(0, 'GetCounter', get_counter)
vim.api.nvim_win_set_var(0, 'funcs', {add = add_counter, get = get_counter})
]] ]]
eq(0, eval('w:GetCounter()')) eq(0, eval('w:GetCounter()'))
@@ -1170,6 +1214,10 @@ describe('lua stdlib', function()
eq(3, exec_lua([[return vim.w.GetCounter()]])) eq(3, exec_lua([[return vim.w.GetCounter()]]))
exec_lua([[vim.api.nvim_win_get_var(0, 'AddCounter')()]]) exec_lua([[vim.api.nvim_win_get_var(0, 'AddCounter')()]])
eq(4, exec_lua([[return vim.api.nvim_win_get_var(0, 'GetCounter')()]])) eq(4, exec_lua([[return vim.api.nvim_win_get_var(0, 'GetCounter')()]]))
exec_lua([[vim.w.funcs.add()]])
eq(5, exec_lua([[return vim.w.funcs.get()]]))
exec_lua([[vim.api.nvim_win_get_var(0, 'funcs').add()]])
eq(6, exec_lua([[return vim.api.nvim_win_get_var(0, 'funcs').get()]]))
exec_lua [[ exec_lua [[
vim.cmd "vnew" vim.cmd "vnew"
@@ -1205,8 +1253,11 @@ describe('lua stdlib', function()
exec_lua [[ exec_lua [[
local counter = 0 local counter = 0
vim.t.AddCounter = function() counter = counter + 1 end local function add_counter() counter = counter + 1 end
vim.t.GetCounter = function() return counter end local function get_counter() return counter end
vim.t.AddCounter = add_counter
vim.t.GetCounter = get_counter
vim.t.funcs = {add = add_counter, get = get_counter}
]] ]]
eq(0, eval('t:GetCounter()')) eq(0, eval('t:GetCounter()'))
@@ -1218,11 +1269,18 @@ describe('lua stdlib', function()
eq(3, exec_lua([[return vim.t.GetCounter()]])) eq(3, exec_lua([[return vim.t.GetCounter()]]))
exec_lua([[vim.api.nvim_tabpage_get_var(0, 'AddCounter')()]]) exec_lua([[vim.api.nvim_tabpage_get_var(0, 'AddCounter')()]])
eq(4, exec_lua([[return vim.api.nvim_tabpage_get_var(0, 'GetCounter')()]])) eq(4, exec_lua([[return vim.api.nvim_tabpage_get_var(0, 'GetCounter')()]]))
exec_lua([[vim.t.funcs.add()]])
eq(5, exec_lua([[return vim.t.funcs.get()]]))
exec_lua([[vim.api.nvim_tabpage_get_var(0, 'funcs').add()]])
eq(6, exec_lua([[return vim.api.nvim_tabpage_get_var(0, 'funcs').get()]]))
exec_lua [[ exec_lua [[
local counter = 0 local counter = 0
vim.api.nvim_tabpage_set_var(0, 'AddCounter', function() counter = counter + 1 end) local function add_counter() counter = counter + 1 end
vim.api.nvim_tabpage_set_var(0, 'GetCounter', function() return counter end) local function get_counter() return counter end
vim.api.nvim_tabpage_set_var(0, 'AddCounter', add_counter)
vim.api.nvim_tabpage_set_var(0, 'GetCounter', get_counter)
vim.api.nvim_tabpage_set_var(0, 'funcs', {add = add_counter, get = get_counter})
]] ]]
eq(0, eval('t:GetCounter()')) eq(0, eval('t:GetCounter()'))
@@ -1234,6 +1292,10 @@ describe('lua stdlib', function()
eq(3, exec_lua([[return vim.t.GetCounter()]])) eq(3, exec_lua([[return vim.t.GetCounter()]]))
exec_lua([[vim.api.nvim_tabpage_get_var(0, 'AddCounter')()]]) exec_lua([[vim.api.nvim_tabpage_get_var(0, 'AddCounter')()]])
eq(4, exec_lua([[return vim.api.nvim_tabpage_get_var(0, 'GetCounter')()]])) eq(4, exec_lua([[return vim.api.nvim_tabpage_get_var(0, 'GetCounter')()]]))
exec_lua([[vim.t.funcs.add()]])
eq(5, exec_lua([[return vim.t.funcs.get()]]))
exec_lua([[vim.api.nvim_tabpage_get_var(0, 'funcs').add()]])
eq(6, exec_lua([[return vim.api.nvim_tabpage_get_var(0, 'funcs').get()]]))
exec_lua [[ exec_lua [[
vim.cmd "tabnew" vim.cmd "tabnew"