fix(treesitter): injected lang ranges may cross capture boundaries #32549

Problem:
treesitter injected language ranges sometimes cross over the capture
boundaries when `@combined`.

Solution:
Clip child regions to not spill out of parent regions within
languagetree.lua, and only apply highlights within those regions in
highlighter.lua.


Co-authored-by: Cormac Relf <web@cormacrelf.net>
This commit is contained in:
Riley Bruins
2025-04-13 14:22:17 -07:00
committed by GitHub
parent ee3f9a1e03
commit 0977f70f4d
4 changed files with 213 additions and 42 deletions

View File

@@ -114,6 +114,19 @@ function M.intercepts(r1, r2)
return true
end
---@private
---@param r1 Range6
---@param r2 Range6
---@return Range6?
function M.intersection(r1, r2)
if not M.intercepts(r1, r2) then
return nil
end
local rs = M.cmp_pos.le(r1[1], r1[2], r2[1], r2[2]) and r2 or r1
local re = M.cmp_pos.ge(r1[4], r1[5], r2[4], r2[5]) and r2 or r1
return { rs[1], rs[2], rs[3], re[4], re[5], re[6] }
end
---@private
---@param r Range
---@return integer, integer, integer, integer

View File

@@ -322,6 +322,8 @@ local function on_line_impl(self, buf, line, on_spell, on_conceal)
return
end
local tree_region = state.tstree:included_ranges(true)
if state.iter == nil or state.next_row < line then
-- Mainly used to skip over folds
@@ -336,56 +338,63 @@ local function on_line_impl(self, buf, line, on_spell, on_conceal)
while line >= state.next_row do
local capture, node, metadata, match = state.iter(line)
local range = { root_end_row + 1, 0, root_end_row + 1, 0 }
local outer_range = { root_end_row + 1, 0, root_end_row + 1, 0 }
if node then
range = vim.treesitter.get_range(node, buf, metadata and metadata[capture])
outer_range = vim.treesitter.get_range(node, buf, metadata and metadata[capture])
end
local start_row, start_col, end_row, end_col = Range.unpack4(range)
local outer_range_start_row = outer_range[1]
if capture then
local hl = state.highlighter_query:get_hl_from_capture(capture)
for _, range in ipairs(tree_region) do
local intersection = Range.intersection(range, outer_range)
if intersection then
local start_row, start_col, end_row, end_col = Range.unpack4(intersection)
local capture_name = captures[capture]
if capture then
local hl = state.highlighter_query:get_hl_from_capture(capture)
local spell, spell_pri_offset = get_spell(capture_name)
local capture_name = captures[capture]
-- The "priority" attribute can be set at the pattern level or on a particular capture
local priority = (
tonumber(metadata.priority or metadata[capture] and metadata[capture].priority)
or vim.hl.priorities.treesitter
) + spell_pri_offset
local spell, spell_pri_offset = get_spell(capture_name)
-- The "conceal" attribute can be set at the pattern level or on a particular capture
local conceal = metadata.conceal or metadata[capture] and metadata[capture].conceal
-- The "priority" attribute can be set at the pattern level or on a particular capture
local priority = (
tonumber(metadata.priority or metadata[capture] and metadata[capture].priority)
or vim.hl.priorities.treesitter
) + spell_pri_offset
local url = get_url(match, buf, capture, metadata)
-- The "conceal" attribute can be set at the pattern level or on a particular capture
local conceal = metadata.conceal or metadata[capture] and metadata[capture].conceal
if hl and end_row >= line and not on_conceal and (not on_spell or spell ~= nil) then
api.nvim_buf_set_extmark(buf, ns, start_row, start_col, {
end_line = end_row,
end_col = end_col,
hl_group = hl,
ephemeral = true,
priority = priority,
conceal = conceal,
spell = spell,
url = url,
})
end
local url = get_url(match, buf, capture, metadata)
if
(metadata.conceal_lines or metadata[capture] and metadata[capture].conceal_lines)
and #api.nvim_buf_get_extmarks(buf, ns, { start_row, 0 }, { start_row, 0 }, {}) == 0
then
api.nvim_buf_set_extmark(buf, ns, start_row, 0, {
end_line = end_row,
conceal_lines = '',
})
if hl and end_row >= line and not on_conceal and (not on_spell or spell ~= nil) then
api.nvim_buf_set_extmark(buf, ns, start_row, start_col, {
end_line = end_row,
end_col = end_col,
hl_group = hl,
ephemeral = true,
priority = priority,
conceal = conceal,
spell = spell,
url = url,
})
end
if
(metadata.conceal_lines or metadata[capture] and metadata[capture].conceal_lines)
and #api.nvim_buf_get_extmarks(buf, ns, { start_row, 0 }, { start_row, 0 }, {}) == 0
then
api.nvim_buf_set_extmark(buf, ns, start_row, 0, {
end_line = end_row,
conceal_lines = '',
})
end
end
end
end
if start_row > line then
state.next_row = start_row
if outer_range_start_row > line then
state.next_row = outer_range_start_row
end
end
end)

View File

@@ -874,6 +874,39 @@ local function get_node_ranges(node, source, metadata, include_children)
return ranges
end
---Finds the intersection between two regions, assuming they are sorted in ascending order by
---starting point.
---@param region1 Range6[]
---@param region2 Range6[]?
---@return Range6[]
local function clip_regions(region1, region2)
if not region2 then
return region1
end
local result = {}
local i, j = 1, 1
while i <= #region1 and j <= #region2 do
local r1 = region1[i]
local r2 = region2[j]
local intersection = Range.intersection(r1, r2)
if intersection then
table.insert(result, intersection)
end
-- Advance the range that ends earlier
if Range.cmp_pos.le(r1[3], r1[4], r2[3], r2[4]) then
i = i + 1
else
j = j + 1
end
end
return result
end
---@nodoc
---@class vim.treesitter.languagetree.InjectionElem
---@field combined boolean
@@ -886,8 +919,9 @@ end
---@param lang string
---@param combined boolean
---@param ranges Range6[]
---@param parent_ranges Range6[]?
---@param result table<string,Range6[][]>
local function add_injection(t, pattern, lang, combined, ranges, result)
local function add_injection(t, pattern, lang, combined, ranges, parent_ranges, result)
if #ranges == 0 then
-- Make sure not to add an empty range set as this is interpreted to mean the whole buffer.
return
@@ -898,7 +932,7 @@ local function add_injection(t, pattern, lang, combined, ranges, result)
end
if not combined then
table.insert(result[lang], ranges)
table.insert(result[lang], clip_regions(ranges, parent_ranges))
return
end
@@ -914,7 +948,7 @@ local function add_injection(t, pattern, lang, combined, ranges, result)
table.insert(result[lang], regions)
end
for _, range in ipairs(ranges) do
for _, range in ipairs(clip_regions(ranges, parent_ranges)) do
table.insert(t[lang][pattern], range)
end
end
@@ -1007,10 +1041,11 @@ function LanguageTree:_get_injections(range, thread_state)
local full_scan = range == true or self._injection_query.has_combined_injections
for _, tree in pairs(self._trees) do
for tree_index, tree in pairs(self._trees) do
---@type vim.treesitter.languagetree.Injection
local injections = {}
local root_node = tree:root()
local parent_ranges = self._regions and self._regions[tree_index] or nil
local start_line, end_line ---@type integer, integer
if full_scan then
start_line, _, end_line = root_node:range()
@@ -1023,7 +1058,7 @@ function LanguageTree:_get_injections(range, thread_state)
do
local lang, combined, ranges = self:_get_injection(match, metadata)
if lang then
add_injection(injections, pattern, lang, combined, ranges, result)
add_injection(injections, pattern, lang, combined, ranges, parent_ranges, result)
else
self:_log('match from injection query failed for pattern', pattern)
end