refactor(treesitter): simplify injection retrieval #33104

Simplify the logic for retrieving the injection ranges for the language
tree. The trees are now also sorted by starting position, regardless of
whether they are part of a combined injection or not. This would be
helpful if ranges are ever to be stored in an interval tree or other
kind of sorted tree structure.
This commit is contained in:
Riley Bruins
2025-03-28 04:38:47 -07:00
committed by GitHub
parent 18fa61049a
commit 75cbd9a8ae
2 changed files with 44 additions and 70 deletions

View File

@@ -868,35 +868,42 @@ end
---@alias vim.treesitter.languagetree.Injection table<string,table<integer,vim.treesitter.languagetree.InjectionElem>> ---@alias vim.treesitter.languagetree.Injection table<string,table<integer,vim.treesitter.languagetree.InjectionElem>>
---@param t table<integer,vim.treesitter.languagetree.Injection> ---@param t vim.treesitter.languagetree.Injection
---@param tree_index integer
---@param pattern integer ---@param pattern integer
---@param lang string ---@param lang string
---@param combined boolean ---@param combined boolean
---@param ranges Range6[] ---@param ranges Range6[]
local function add_injection(t, tree_index, pattern, lang, combined, ranges) ---@param result table<string,Range6[][]>
local function add_injection(t, pattern, lang, combined, ranges, result)
if #ranges == 0 then if #ranges == 0 then
-- Make sure not to add an empty range set as this is interpreted to mean the whole buffer. -- Make sure not to add an empty range set as this is interpreted to mean the whole buffer.
return return
end end
-- Each tree index should be isolated from the other nodes. if not result[lang] then
if not t[tree_index] then result[lang] = {}
t[tree_index] = {}
end end
if not t[tree_index][lang] then if not combined then
t[tree_index][lang] = {} table.insert(result[lang], ranges)
return
end end
-- Key this by pattern. If combined is set to true all captures of this pattern if not t[lang] then
t[lang] = {}
end
-- Key this by pattern. For combined injections, all captures of this pattern
-- will be parsed by treesitter as the same "source". -- will be parsed by treesitter as the same "source".
-- If combined is false, each "region" will be parsed as a single source. if not t[lang][pattern] then
if not t[tree_index][lang][pattern] then local regions = {}
t[tree_index][lang][pattern] = { combined = combined, regions = {} } t[lang][pattern] = regions
table.insert(result[lang], regions)
end end
table.insert(t[tree_index][lang][pattern].regions, ranges) for _, range in ipairs(ranges) do
table.insert(t[lang][pattern], range)
end
end end
-- TODO(clason): replace by refactored `ts.has_parser` API (without side effects) -- TODO(clason): replace by refactored `ts.has_parser` API (without side effects)
@@ -964,19 +971,6 @@ function LanguageTree:_get_injection(match, metadata)
return lang, combined, ranges return lang, combined, ranges
end end
--- Can't use vim.tbl_flatten since a range is just a table.
---@param regions Range6[][]
---@return Range6[]
local function combine_regions(regions)
local result = {} ---@type Range6[]
for _, region in ipairs(regions) do
for _, range in ipairs(region) do
result[#result + 1] = range
end
end
return result
end
--- Gets language injection regions by language. --- Gets language injection regions by language.
--- ---
--- This is where most of the injection processing occurs. --- This is where most of the injection processing occurs.
@@ -993,13 +987,16 @@ function LanguageTree:_get_injections(range, thread_state)
return {} return {}
end end
---@type table<integer,vim.treesitter.languagetree.Injection>
local injections = {}
local start = vim.uv.hrtime() local start = vim.uv.hrtime()
---@type table<string,Range6[][]>
local result = {}
local full_scan = range == true or self._injection_query.has_combined_injections local full_scan = range == true or self._injection_query.has_combined_injections
for index, tree in pairs(self._trees) do for _, tree in pairs(self._trees) do
---@type vim.treesitter.languagetree.Injection
local injections = {}
local root_node = tree:root() local root_node = tree:root()
local start_line, end_line ---@type integer, integer local start_line, end_line ---@type integer, integer
if full_scan then if full_scan then
@@ -1013,7 +1010,7 @@ function LanguageTree:_get_injections(range, thread_state)
do do
local lang, combined, ranges = self:_get_injection(match, metadata) local lang, combined, ranges = self:_get_injection(match, metadata)
if lang then if lang then
add_injection(injections, index, pattern, lang, combined, ranges) add_injection(injections, pattern, lang, combined, ranges, result)
else else
self:_log('match from injection query failed for pattern', pattern) self:_log('match from injection query failed for pattern', pattern)
end end
@@ -1025,29 +1022,6 @@ function LanguageTree:_get_injections(range, thread_state)
end end
end end
---@type table<string,Range6[][]>
local result = {}
-- Generate a map by lang of node lists.
-- Each list is a set of ranges that should be parsed together.
for _, lang_map in pairs(injections) do
for lang, patterns in pairs(lang_map) do
if not result[lang] then
result[lang] = {}
end
for _, entry in pairs(patterns) do
if entry.combined then
table.insert(result[lang], combine_regions(entry.regions))
else
for _, ranges in pairs(entry.regions) do
table.insert(result[lang], ranges)
end
end
end
end
end
if full_scan then if full_scan then
self._processed_injection_range = entire_document_range self._processed_injection_range = entire_document_range
else else

View File

@@ -575,22 +575,22 @@ int x = INT_MAX;
eq(5, exec_lua('return #parser:children().c:trees()')) eq(5, exec_lua('return #parser:children().c:trees()'))
eq({ eq({
{ 0, 0, 7, 0 }, -- root tree { 0, 0, 7, 0 }, -- root tree
{ 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 14, 3, 17 }, -- VALUE 123 { 3, 14, 3, 17 }, -- VALUE 123
{ 4, 15, 4, 18 }, -- VALUE1 123 { 4, 15, 4, 18 }, -- VALUE1 123
{ 5, 15, 5, 18 }, -- VALUE2 123 { 5, 15, 5, 18 }, -- VALUE2 123
{ 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
}, get_ranges()) }, get_ranges())
n.feed('ggo<esc>') n.feed('ggo<esc>')
eq(5, exec_lua('return #parser:children().c:trees()')) eq(5, exec_lua('return #parser:children().c:trees()'))
eq({ eq({
{ 0, 0, 8, 0 }, -- root tree { 0, 0, 8, 0 }, -- root tree
{ 2, 26, 2, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 29, 3, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
{ 4, 14, 4, 17 }, -- VALUE 123 { 4, 14, 4, 17 }, -- VALUE 123
{ 5, 15, 5, 18 }, -- VALUE1 123 { 5, 15, 5, 18 }, -- VALUE1 123
{ 6, 15, 6, 18 }, -- VALUE2 123 { 6, 15, 6, 18 }, -- VALUE2 123
{ 2, 26, 2, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 29, 3, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
}, get_ranges()) }, get_ranges())
end) end)
end) end)
@@ -613,11 +613,11 @@ int x = INT_MAX;
eq(2, exec_lua('return #parser:children().c:trees()')) eq(2, exec_lua('return #parser:children().c:trees()'))
eq({ eq({
{ 0, 0, 7, 0 }, -- root tree { 0, 0, 7, 0 }, -- root tree
{ 1, 26, 2, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 14, 5, 18 }, -- VALUE 123 { 3, 14, 5, 18 }, -- VALUE 123
-- VALUE1 123 -- VALUE1 123
-- VALUE2 123 -- VALUE2 123
{ 1, 26, 2, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
}, get_ranges()) }, get_ranges())
n.feed('ggo<esc>') n.feed('ggo<esc>')
@@ -625,11 +625,11 @@ int x = INT_MAX;
eq(2, exec_lua('return #parser:children().c:trees()')) eq(2, exec_lua('return #parser:children().c:trees()'))
eq({ eq({
{ 0, 0, 8, 0 }, -- root tree { 0, 0, 8, 0 }, -- root tree
{ 4, 14, 6, 18 }, -- VALUE 123
-- VALUE1 123
-- VALUE2 123
{ 2, 26, 3, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) { 2, 26, 3, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
-- VALUE 123
{ 4, 14, 6, 18 }, -- VALUE1 123
-- VALUE2 123
}, get_ranges()) }, get_ranges())
n.feed('7ggI//<esc>') n.feed('7ggI//<esc>')
@@ -638,10 +638,10 @@ int x = INT_MAX;
eq(2, exec_lua('return #parser:children().c:trees()')) eq(2, exec_lua('return #parser:children().c:trees()'))
eq({ eq({
{ 0, 0, 8, 0 }, -- root tree { 0, 0, 8, 0 }, -- root tree
{ 4, 14, 5, 18 }, -- VALUE 123
-- VALUE1 123
{ 2, 26, 3, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y)) { 2, 26, 3, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y)) -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
-- VALUE 123
{ 4, 14, 5, 18 }, -- VALUE1 123
}, get_ranges()) }, get_ranges())
end) end)
@@ -794,22 +794,22 @@ int x = INT_MAX;
eq(5, exec_lua('return #parser:children().c:trees()')) eq(5, exec_lua('return #parser:children().c:trees()'))
eq({ eq({
{ 0, 0, 7, 0 }, -- root tree { 0, 0, 7, 0 }, -- root tree
{ 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 14, 3, 17 }, -- VALUE 123 { 3, 14, 3, 17 }, -- VALUE 123
{ 4, 15, 4, 18 }, -- VALUE1 123 { 4, 15, 4, 18 }, -- VALUE1 123
{ 5, 15, 5, 18 }, -- VALUE2 123 { 5, 15, 5, 18 }, -- VALUE2 123
{ 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
}, get_ranges()) }, get_ranges())
n.feed('ggo<esc>') n.feed('ggo<esc>')
eq(5, exec_lua('return #parser:children().c:trees()')) eq(5, exec_lua('return #parser:children().c:trees()'))
eq({ eq({
{ 0, 0, 8, 0 }, -- root tree { 0, 0, 8, 0 }, -- root tree
{ 2, 26, 2, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 29, 3, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
{ 4, 14, 4, 17 }, -- VALUE 123 { 4, 14, 4, 17 }, -- VALUE 123
{ 5, 15, 5, 18 }, -- VALUE1 123 { 5, 15, 5, 18 }, -- VALUE1 123
{ 6, 15, 6, 18 }, -- VALUE2 123 { 6, 15, 6, 18 }, -- VALUE2 123
{ 2, 26, 2, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 29, 3, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
}, get_ranges()) }, get_ranges())
end) end)
end) end)
@@ -831,11 +831,11 @@ int x = INT_MAX;
eq('table', exec_lua('return type(parser:children().c)')) eq('table', exec_lua('return type(parser:children().c)'))
eq({ eq({
{ 0, 0, 7, 0 }, -- root tree { 0, 0, 7, 0 }, -- root tree
{ 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 16, 3, 16 }, -- VALUE 123 { 3, 16, 3, 16 }, -- VALUE 123
{ 4, 17, 4, 17 }, -- VALUE1 123 { 4, 17, 4, 17 }, -- VALUE1 123
{ 5, 17, 5, 17 }, -- VALUE2 123 { 5, 17, 5, 17 }, -- VALUE2 123
{ 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
}, get_ranges()) }, get_ranges())
end) end)
it('should list all directives', function() it('should list all directives', function()