diff --git a/runtime/lua/vim/treesitter/dev.lua b/runtime/lua/vim/treesitter/dev.lua index ee4d981c16..14e21901df 100644 --- a/runtime/lua/vim/treesitter/dev.lua +++ b/runtime/lua/vim/treesitter/dev.lua @@ -45,7 +45,7 @@ local TSTreeView = {} ---@param depth integer Current recursion depth ---@param field string|nil The field of the current node ---@param lang string Language of the tree currently being traversed ----@param injections table Mapping of node ids to root nodes +---@param injections table Mapping of node ids to root nodes --- of injected language trees (see explanation above) ---@param tree vim.treesitter.dev.Node[] Output table containing a list of tables each representing a node in the tree local function traverse(node, depth, field, lang, injections, tree) @@ -56,8 +56,7 @@ local function traverse(node, depth, field, lang, injections, tree) field = field, }) - local injection = injections[node:id()] - if injection then + for _, injection in ipairs(injections[node:id()] or {}) do traverse(injection.root, depth + 1, nil, injection.lang, injections, tree) end @@ -94,7 +93,7 @@ function TSTreeView:new(bufnr, lang) -- the primary tree that contains that root. Add a mapping from the node in the primary tree to -- the root in the child tree to the {injections} table. local root = parser:parse(true)[1]:root() - local injections = {} ---@type table + local injections = {} ---@type table> parser:for_each_tree(function(parent_tree, parent_ltree) local parent = parent_tree:root() @@ -106,18 +105,32 @@ function TSTreeView:new(bufnr, lang) if Range.contains(parent_range, r_range) then local node = assert(parent:named_descendant_for_range(r:range())) local id = node:id() - if not injections[id] or r:byte_length() > injections[id].root:byte_length() then - injections[id] = { - lang = child:lang(), - root = r, - } + local ilang = child:lang() + injections[id] = injections[id] or {} + local injection = injections[id][ilang] + if not injection or r:byte_length() > injection:byte_length() then + injections[id][ilang] = r end end end end end) - local nodes = traverse(root, 0, nil, parser:lang(), injections, {}) + local sorted_injections = {} ---@type table + for id, lang_injections in pairs(injections) do + local langs = vim.tbl_keys(lang_injections) + ---@param a string + ---@param b string + table.sort(langs, function(a, b) + return lang_injections[a]:byte_length() > lang_injections[b]:byte_length() + end) + ---@param ilang string + sorted_injections[id] = vim.tbl_map(function(ilang) + return { lang = ilang, root = lang_injections[ilang] } + end, langs) + end + + local nodes = traverse(root, 0, nil, parser:lang(), sorted_injections, {}) local named = {} ---@type vim.treesitter.dev.Node[] for _, v in ipairs(nodes) do diff --git a/test/functional/treesitter/inspect_tree_spec.lua b/test/functional/treesitter/inspect_tree_spec.lua index fb87618fa9..88c855df28 100644 --- a/test/functional/treesitter/inspect_tree_spec.lua +++ b/test/functional/treesitter/inspect_tree_spec.lua @@ -120,6 +120,45 @@ describe('vim.treesitter.inspect_tree', function() ]] end) + it('works with multiple injection on the same node', function() + insert([[--* #include]]) + exec_lua(function() + vim.treesitter.query.set( + 'lua', + 'injections', + [[ + (comment + content: (_) @injection.content + (#set! injection.language "markdown")) + (comment + content: (_) @injection.content + (#set! injection.language "c") + (#offset! @injection.content 0 1 0 0)) + ]] + ) + vim.treesitter.start(0, 'lua') + vim.treesitter.get_parser():parse(true) + vim.treesitter.inspect_tree() + end) + feed('I') + expect_tree [[ + (chunk ; [0, 0] - [1, 0] lua + (comment ; [0, 0] - [0, 21] lua + content: (comment_content ; [0, 2] - [0, 21] lua + (document ; [0, 2] - [0, 21] markdown + (section ; [0, 2] - [0, 21] markdown + (list ; [0, 2] - [0, 21] markdown + (list_item ; [0, 2] - [0, 21] markdown + (list_marker_star) ; [0, 2] - [0, 4] markdown + (paragraph ; [0, 4] - [0, 21] markdown + (inline ; [0, 4] - [0, 21] markdown + (inline))))))) ; [0, 4] - [0, 21] markdown_inline + (translation_unit ; [0, 4] - [0, 21] c + (preproc_include ; [0, 4] - [0, 21] c + path: (system_lib_string)))))) ; [0, 12] - [0, 21] c + ]] + end) + it('can toggle to show languages', function() insert([[ ```lua