mirror of
				https://github.com/neovim/neovim.git
				synced 2025-11-04 01:34:25 +00:00 
			
		
		
		
	perf(treesitter): smarter languagetree invalidation
Problem:
  Treesitter injections are slow because all injected trees are invalidated on every change.
Solution:
    Implement smarter invalidation to avoid reparsing injected regions.
    - In on_bytes, try to update self._regions as best we can. This PR only offsets regions that come after the change.
    - Add valid flags for each region in self._regions.
    - Call on_bytes recursively for all children.
       - We still need to run the query every time for the top level tree. I don't know how to avoid this. However, if the new injection ranges don't change, then we re-use the old trees and avoid reparsing children.
This should result in roughly a 2-3x reduction in tree parsing when the comment injections are enabled.
			
			
This commit is contained in:
		
							
								
								
									
										126
									
								
								runtime/lua/vim/treesitter/_range.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								runtime/lua/vim/treesitter/_range.lua
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,126 @@
 | 
			
		||||
local api = vim.api
 | 
			
		||||
 | 
			
		||||
local M = {}
 | 
			
		||||
 | 
			
		||||
---@alias Range4 {[1]: integer, [2]: integer, [3]: integer, [4]: integer}
 | 
			
		||||
---@alias Range6 {[1]: integer, [2]: integer, [3]: integer, [4]: integer, [5]: integer, [6]: integer}
 | 
			
		||||
 | 
			
		||||
---@private
--- Three-way comparison of two (row, col) positions. Rows decide the order;
--- columns are only consulted when both positions share the same row.
---@param a_row integer
---@param a_col integer
---@param b_row integer
---@param b_col integer
---@return integer
---  1: a > b
---  0: a == b
--- -1: a < b
local function cmp_pos(a_row, a_col, b_row, b_col)
  -- Different rows: the row alone settles the ordering.
  if a_row ~= b_row then
    return (a_row > b_row) and 1 or -1
  end

  -- Same row: fall back to the column.
  if a_col > b_col then
    return 1
  end
  if a_col < b_col then
    return -1
  end

  return 0
end
 | 
			
		||||
 | 
			
		||||
--- Position comparison helpers. Every field takes
--- `(a_row, a_col, b_row, b_col)` and returns a boolean derived from the
--- three-way result of cmp_pos(). The table itself is also callable and then
--- returns the raw three-way result (1, 0 or -1).
M.cmp_pos = {
  --- a < b
  lt = function(...)
    return cmp_pos(...) == -1
  end,
  --- a <= b
  le = function(...)
    return cmp_pos(...) ~= 1
  end,
  --- a > b
  gt = function(...)
    return cmp_pos(...) == 1
  end,
  --- a >= b
  ge = function(...)
    return cmp_pos(...) ~= -1
  end,
  --- a == b
  eq = function(...)
    return cmp_pos(...) == 0
  end,
  --- a ~= b
  ne = function(...)
    return cmp_pos(...) ~= 0
  end,
}

-- Lua passes the called table itself as the first argument to `__call`, so it
-- must be dropped before the coordinates are forwarded; otherwise every
-- argument would be shifted by one and the table compared against a number.
setmetatable(M.cmp_pos, {
  __call = function(_, ...)
    return cmp_pos(...)
  end,
})
 | 
			
		||||
 | 
			
		||||
---@private
--- Checks whether two ranges overlap. Ranges that merely touch (one ends at
--- the exact position where the other starts) are not considered overlapping.
---@param r1 Range4|Range6
---@param r2 Range4|Range6
---@return boolean
function M.intercepts(r1, r2)
  -- A Range6 stores a byte offset after the start (row, col) pair, which
  -- pushes the end (row, col) one slot to the right compared to a Range4.
  -- Each range needs its own offset since the two may be of different kinds.
  local off_1 = #r1 == 6 and 1 or 0
  local off_2 = #r2 == 6 and 1 or 0

  local srow_1, scol_1, erow_1, ecol_1 = r1[1], r1[2], r1[3 + off_1], r1[4 + off_1]
  local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]

  -- r1 is above r2
  if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
    return false
  end

  -- r1 is below r2
  if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
    return false
  end

  return true
end
 | 
			
		||||
 | 
			
		||||
---@private
--- Checks whether `r1` fully encloses `r2`. Shared start or end positions
--- still count as contained.
---@param r1 Range4|Range6
---@param r2 Range4|Range6
---@return boolean whether r1 contains r2
function M.contains(r1, r2)
  -- A Range6 stores a byte offset after the start (row, col) pair, which
  -- pushes the end (row, col) one slot to the right compared to a Range4.
  -- Each range needs its own offset since the two may be of different kinds.
  local off_1 = #r1 == 6 and 1 or 0
  local off_2 = #r2 == 6 and 1 or 0

  local srow_1, scol_1, erow_1, ecol_1 = r1[1], r1[2], r1[3 + off_1], r1[4 + off_1]
  local srow_2, scol_2, erow_2, ecol_2 = r2[1], r2[2], r2[3 + off_2], r2[4 + off_2]

  -- start doesn't fit
  if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
    return false
  end

  -- end doesn't fit
  if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
    return false
  end

  return true
end
 | 
			
		||||
 | 
			
		||||
---@private
--- Expands a 4-element range into a 6-element range by inserting a byte
--- offset after the start and after the end (row, col) pair.
---@param source integer|string Buffer handle or source string
---@param range Range4
---@return Range6
function M.add_bytes(source, range)
  local srow, scol, erow, ecol = range[1], range[2], range[3], range[4]
  local source_kind = type(source)

  -- Both offsets stay 0 when the source is neither a number nor a string.
  local sbyte, ebyte = 0, 0

  -- TODO(vigoux): proper byte computation here, and account for EOL ?
  if source_kind == 'number' then
    -- Buffer source: byte offset of the row plus the column.
    sbyte = api.nvim_buf_get_offset(source, srow) + scol
    ebyte = api.nvim_buf_get_offset(source, erow) + ecol
  elseif source_kind == 'string' then
    -- String source (single `\n` delimited string): only the columns are used.
    sbyte = vim.fn.byteidx(source, scol)
    ebyte = vim.fn.byteidx(source, ecol)
  end

  return { srow, scol, sbyte, erow, ecol, ebyte }
end
 | 
			
		||||
 | 
			
		||||
return M
 | 
			
		||||
@@ -1,9 +1,8 @@
 | 
			
		||||
local a = vim.api
 | 
			
		||||
local query = require('vim.treesitter.query')
 | 
			
		||||
local language = require('vim.treesitter.language')
 | 
			
		||||
local Range = require('vim.treesitter._range')
 | 
			
		||||
 | 
			
		||||
---@alias Range {[1]: integer, [2]: integer, [3]: integer, [4]: integer}
 | 
			
		||||
--
 | 
			
		||||
---@alias TSCallbackName
 | 
			
		||||
---| 'changedtree'
 | 
			
		||||
---| 'bytes'
 | 
			
		||||
@@ -24,11 +23,13 @@ local language = require('vim.treesitter.language')
 | 
			
		||||
---@field private _injection_query Query Queries defining injected languages
 | 
			
		||||
---@field private _opts table Options
 | 
			
		||||
---@field private _parser TSParser Parser for language
 | 
			
		||||
---@field private _regions Range[][] List of regions this tree should manage and parse
 | 
			
		||||
---@field private _regions Range6[][] List of regions this tree should manage and parse
 | 
			
		||||
---@field private _lang string Language name
 | 
			
		||||
---@field private _source (integer|string) Buffer or string to parse
 | 
			
		||||
---@field private _trees TSTree[] Reference to parsed tree (one for each language)
 | 
			
		||||
---@field private _valid boolean If the parsed tree is valid
 | 
			
		||||
---@field private _valid boolean|table<integer,true> If the parsed tree is valid
 | 
			
		||||
--- TODO(lewis6991): combine _regions, _valid and _trees
 | 
			
		||||
---@field private _is_child boolean
 | 
			
		||||
local LanguageTree = {}
 | 
			
		||||
 | 
			
		||||
---@class LanguageTreeOpts
 | 
			
		||||
@@ -114,6 +115,9 @@ end
 | 
			
		||||
--- If the tree is invalid, call `parse()`.
 | 
			
		||||
--- This will return the updated tree.
 | 
			
		||||
function LanguageTree:is_valid()
  local valid = self._valid

  if type(valid) == 'table' then
    -- `valid` is a sparse set keyed by region index: invalidated regions are
    -- stored as nil. The length operator is unspecified for tables with
    -- holes, so every region is checked explicitly instead of using `#valid`.
    for i = 1, #self._regions do
      if not valid[i] then
        return false
      end
    end
    return true
  end

  -- Plain boolean flag covering the whole tree.
  return valid
end
 | 
			
		||||
 | 
			
		||||
@@ -127,6 +131,16 @@ function LanguageTree:source()
 | 
			
		||||
  return self._source
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
---@private
--- Runs the parser against the source, optionally reusing `old_tree`, and
--- fires the 'changedtree' callback with the result.
--- This is only exposed so it can be wrapped for profiling
---@param old_tree TSTree
---@return TSTree, integer[]
function LanguageTree:_parse_tree(old_tree)
  local parser = self._parser
  local new_tree, changed_ranges = parser:parse(old_tree, self._source)

  -- Listeners receive the changes first, then the freshly parsed tree.
  self:_do_callback('changedtree', changed_ranges, new_tree)

  return new_tree, changed_ranges
end
 | 
			
		||||
 | 
			
		||||
--- Parses all defined regions using a treesitter parser
 | 
			
		||||
--- for the language this tree represents.
 | 
			
		||||
--- This will run the injection query for this language to
 | 
			
		||||
@@ -135,35 +149,27 @@ end
 | 
			
		||||
---@return TSTree[]
 | 
			
		||||
---@return table|nil Change list
 | 
			
		||||
function LanguageTree:parse()
 | 
			
		||||
  if self._valid then
 | 
			
		||||
  if self:is_valid() then
 | 
			
		||||
    return self._trees
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  local parser = self._parser
 | 
			
		||||
  local changes = {}
 | 
			
		||||
 | 
			
		||||
  local old_trees = self._trees
 | 
			
		||||
  self._trees = {}
 | 
			
		||||
 | 
			
		||||
  -- If there are no ranges, set to an empty list
 | 
			
		||||
  -- so the included ranges in the parser are cleared.
 | 
			
		||||
  if self._regions and #self._regions > 0 then
 | 
			
		||||
  if #self._regions > 0 then
 | 
			
		||||
    for i, ranges in ipairs(self._regions) do
 | 
			
		||||
      local old_tree = old_trees[i]
 | 
			
		||||
      parser:set_included_ranges(ranges)
 | 
			
		||||
 | 
			
		||||
      local tree, tree_changes = parser:parse(old_tree, self._source)
 | 
			
		||||
      self:_do_callback('changedtree', tree_changes, tree)
 | 
			
		||||
 | 
			
		||||
      table.insert(self._trees, tree)
 | 
			
		||||
      vim.list_extend(changes, tree_changes)
 | 
			
		||||
      if not self._valid or not self._valid[i] then
 | 
			
		||||
        self._parser:set_included_ranges(ranges)
 | 
			
		||||
        local tree, tree_changes = self:_parse_tree(self._trees[i])
 | 
			
		||||
        self._trees[i] = tree
 | 
			
		||||
        vim.list_extend(changes, tree_changes)
 | 
			
		||||
      end
 | 
			
		||||
    end
 | 
			
		||||
  else
 | 
			
		||||
    local tree, tree_changes = parser:parse(old_trees[1], self._source)
 | 
			
		||||
    self:_do_callback('changedtree', tree_changes, tree)
 | 
			
		||||
 | 
			
		||||
    table.insert(self._trees, tree)
 | 
			
		||||
    vim.list_extend(changes, tree_changes)
 | 
			
		||||
    local tree, tree_changes = self:_parse_tree(self._trees[1])
 | 
			
		||||
    self._trees = { tree }
 | 
			
		||||
    changes = tree_changes
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  local injections_by_lang = self:_get_injections()
 | 
			
		||||
@@ -249,6 +255,7 @@ function LanguageTree:add_child(lang)
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  self._children[lang] = LanguageTree.new(self._source, lang, self._opts)
 | 
			
		||||
  self._children[lang]._is_child = true
 | 
			
		||||
 | 
			
		||||
  self:invalidate()
 | 
			
		||||
  self:_do_callback('child_added', self._children[lang])
 | 
			
		||||
@@ -298,43 +305,35 @@ end
 | 
			
		||||
--- This allows for embedded languages to be parsed together across different
 | 
			
		||||
--- nodes, which is useful for templating languages like ERB and EJS.
 | 
			
		||||
---
 | 
			
		||||
--- Note: This call invalidates the tree and requires it to be parsed again.
 | 
			
		||||
---
 | 
			
		||||
---@private
 | 
			
		||||
---@param regions integer[][][] List of regions this tree should manage and parse.
 | 
			
		||||
---@param regions Range4[][] List of regions this tree should manage and parse.
 | 
			
		||||
function LanguageTree:set_included_regions(regions)
 | 
			
		||||
  -- Transform the tables from 4 element long to 6 element long (with byte offset)
 | 
			
		||||
  for _, region in ipairs(regions) do
 | 
			
		||||
    for i, range in ipairs(region) do
 | 
			
		||||
      if type(range) == 'table' and #range == 4 then
 | 
			
		||||
        ---@diagnostic disable-next-line:no-unknown
 | 
			
		||||
        local start_row, start_col, end_row, end_col = unpack(range)
 | 
			
		||||
        local start_byte = 0
 | 
			
		||||
        local end_byte = 0
 | 
			
		||||
        local source = self._source
 | 
			
		||||
        -- TODO(vigoux): proper byte computation here, and account for EOL ?
 | 
			
		||||
        if type(source) == 'number' then
 | 
			
		||||
          -- Easy case, this is a buffer parser
 | 
			
		||||
          start_byte = a.nvim_buf_get_offset(source, start_row) + start_col
 | 
			
		||||
          end_byte = a.nvim_buf_get_offset(source, end_row) + end_col
 | 
			
		||||
        elseif type(self._source) == 'string' then
 | 
			
		||||
          -- string parser, single `\n` delimited string
 | 
			
		||||
          start_byte = vim.fn.byteidx(self._source, start_col)
 | 
			
		||||
          end_byte = vim.fn.byteidx(self._source, end_col)
 | 
			
		||||
        end
 | 
			
		||||
        region[i] = Range.add_bytes(self._source, range)
 | 
			
		||||
      end
 | 
			
		||||
    end
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
        region[i] = { start_row, start_col, start_byte, end_row, end_col, end_byte }
 | 
			
		||||
  if #self._regions ~= #regions then
 | 
			
		||||
    self._trees = {}
 | 
			
		||||
    self:invalidate()
 | 
			
		||||
  elseif self._valid ~= false then
 | 
			
		||||
    if self._valid == true then
 | 
			
		||||
      self._valid = {}
 | 
			
		||||
    end
 | 
			
		||||
    for i = 1, #regions do
 | 
			
		||||
      self._valid[i] = true
 | 
			
		||||
      if not vim.deep_equal(self._regions[i], regions[i]) then
 | 
			
		||||
        self._valid[i] = nil
 | 
			
		||||
        self._trees[i] = nil
 | 
			
		||||
      end
 | 
			
		||||
    end
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  self._regions = regions
 | 
			
		||||
  -- Trees are no longer valid now that we have changed regions.
 | 
			
		||||
  -- TODO(vigoux,steelsojka): Look into doing this smarter so we can use some of the
 | 
			
		||||
  --                          old trees for incremental parsing. Currently, this only
 | 
			
		||||
  --                          affects injected languages.
 | 
			
		||||
  self._trees = {}
 | 
			
		||||
  self:invalidate()
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- Gets the set of included regions
 | 
			
		||||
@@ -346,10 +345,10 @@ end
 | 
			
		||||
---@param node TSNode
 | 
			
		||||
---@param id integer
 | 
			
		||||
---@param metadata TSMetadata
 | 
			
		||||
---@return Range
 | 
			
		||||
---@return Range4
 | 
			
		||||
local function get_range_from_metadata(node, id, metadata)
 | 
			
		||||
  if metadata[id] and metadata[id].range then
 | 
			
		||||
    return metadata[id].range --[[@as Range]]
 | 
			
		||||
    return metadata[id].range --[[@as Range4]]
 | 
			
		||||
  end
 | 
			
		||||
  return { node:range() }
 | 
			
		||||
end
 | 
			
		||||
@@ -378,7 +377,7 @@ function LanguageTree:_get_injections()
 | 
			
		||||
      self._injection_query:iter_matches(root_node, self._source, start_line, end_line + 1)
 | 
			
		||||
    do
 | 
			
		||||
      local lang = nil ---@type string
 | 
			
		||||
      local ranges = {} ---@type Range[]
 | 
			
		||||
      local ranges = {} ---@type Range4[]
 | 
			
		||||
      local combined = metadata.combined ---@type boolean
 | 
			
		||||
 | 
			
		||||
      -- Directives can configure how injections are captured as well as actual node captures.
 | 
			
		||||
@@ -408,6 +407,7 @@ function LanguageTree:_get_injections()
 | 
			
		||||
 | 
			
		||||
        -- Lang should override any other language tag
 | 
			
		||||
        if name == 'language' and not lang then
 | 
			
		||||
          ---@diagnostic disable-next-line
 | 
			
		||||
          lang = query.get_node_text(node, self._source, { metadata = metadata[id] })
 | 
			
		||||
        elseif name == 'combined' then
 | 
			
		||||
          combined = true
 | 
			
		||||
@@ -426,6 +426,8 @@ function LanguageTree:_get_injections()
 | 
			
		||||
        end
 | 
			
		||||
      end
 | 
			
		||||
 | 
			
		||||
      assert(type(lang) == 'string')
 | 
			
		||||
 | 
			
		||||
      -- Each tree index should be isolated from the other nodes.
 | 
			
		||||
      if not injections[tree_index] then
 | 
			
		||||
        injections[tree_index] = {}
 | 
			
		||||
@@ -446,7 +448,7 @@ function LanguageTree:_get_injections()
 | 
			
		||||
    end
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  ---@type table<string,Range[][]>
 | 
			
		||||
  ---@type table<string,Range4[][]>
 | 
			
		||||
  local result = {}
 | 
			
		||||
 | 
			
		||||
  -- Generate a map by lang of node lists.
 | 
			
		||||
@@ -485,6 +487,45 @@ function LanguageTree:_do_callback(cb_name, ...)
 | 
			
		||||
  end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
---@private
--- Walks every region after an edit. A region containing a range that
--- overlaps `old_range` is dropped from the returned set; ranges that start
--- after the end of `old_range` are shifted in place by the row and byte
--- delta between `new_range` and `old_range`.
---@param regions Range6[][]
---@param old_range Range6
---@param new_range Range6
---@return table<integer, true> set of region indices that are still valid
local function update_regions(regions, old_range, new_range)
  ---@type table<integer,true>
  local valid = {}

  for region_idx, region in ipairs(regions or {}) do
    -- Assume the region survives until one of its ranges proves otherwise.
    valid[region_idx] = true

    for range_idx, range in ipairs(region) do
      if Range.intercepts(range, old_range) then
        -- The edit touches this range: the whole region is stale. Remaining
        -- ranges of this region are left as they are.
        valid[region_idx] = nil
        break
      end

      -- Range after change. Adjust
      if Range.cmp_pos.gt(range[1], range[2], old_range[4], old_range[5]) then
        local row_delta = new_range[4] - old_range[4]
        local byte_delta = new_range[6] - old_range[6]

        -- Update the range to avoid invalidation in set_included_regions()
        -- which will compare the regions against the parsed injection regions
        region[range_idx] = {
          range[1] + row_delta,
          range[2],
          range[3] + byte_delta,
          range[4] + row_delta,
          range[5],
          range[6] + byte_delta,
        }
      end
    end
  end

  return valid
end
 | 
			
		||||
 | 
			
		||||
---@private
 | 
			
		||||
---@param bufnr integer
 | 
			
		||||
---@param changed_tick integer
 | 
			
		||||
@@ -510,14 +551,53 @@ function LanguageTree:_on_bytes(
 | 
			
		||||
  new_col,
 | 
			
		||||
  new_byte
 | 
			
		||||
)
 | 
			
		||||
  self:invalidate()
 | 
			
		||||
 | 
			
		||||
  local old_end_col = old_col + ((old_row == 0) and start_col or 0)
 | 
			
		||||
  local new_end_col = new_col + ((new_row == 0) and start_col or 0)
 | 
			
		||||
 | 
			
		||||
  -- Edit all trees recursively, together BEFORE emitting a bytes callback.
 | 
			
		||||
  -- In most cases this callback should only be called from the root tree.
 | 
			
		||||
  self:for_each_tree(function(tree)
 | 
			
		||||
  local old_range = {
 | 
			
		||||
    start_row,
 | 
			
		||||
    start_col,
 | 
			
		||||
    start_byte,
 | 
			
		||||
    start_row + old_row,
 | 
			
		||||
    old_end_col,
 | 
			
		||||
    start_byte + old_byte,
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  local new_range = {
 | 
			
		||||
    start_row,
 | 
			
		||||
    start_col,
 | 
			
		||||
    start_byte,
 | 
			
		||||
    start_row + new_row,
 | 
			
		||||
    new_end_col,
 | 
			
		||||
    start_byte + new_byte,
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  local valid_regions = update_regions(self._regions, old_range, new_range)
 | 
			
		||||
 | 
			
		||||
  if #self._regions == 0 or #valid_regions == 0 then
 | 
			
		||||
    self._valid = false
 | 
			
		||||
  else
 | 
			
		||||
    self._valid = valid_regions
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  for _, child in pairs(self._children) do
 | 
			
		||||
    child:_on_bytes(
 | 
			
		||||
      bufnr,
 | 
			
		||||
      changed_tick,
 | 
			
		||||
      start_row,
 | 
			
		||||
      start_col,
 | 
			
		||||
      start_byte,
 | 
			
		||||
      old_row,
 | 
			
		||||
      old_col,
 | 
			
		||||
      old_byte,
 | 
			
		||||
      new_row,
 | 
			
		||||
      new_col,
 | 
			
		||||
      new_byte
 | 
			
		||||
    )
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  -- Edit trees together BEFORE emitting a bytes callback.
 | 
			
		||||
  for _, tree in ipairs(self._trees) do
 | 
			
		||||
    tree:edit(
 | 
			
		||||
      start_byte,
 | 
			
		||||
      start_byte + old_byte,
 | 
			
		||||
@@ -529,22 +609,24 @@ function LanguageTree:_on_bytes(
 | 
			
		||||
      start_row + new_row,
 | 
			
		||||
      new_end_col
 | 
			
		||||
    )
 | 
			
		||||
  end)
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  self:_do_callback(
 | 
			
		||||
    'bytes',
 | 
			
		||||
    bufnr,
 | 
			
		||||
    changed_tick,
 | 
			
		||||
    start_row,
 | 
			
		||||
    start_col,
 | 
			
		||||
    start_byte,
 | 
			
		||||
    old_row,
 | 
			
		||||
    old_col,
 | 
			
		||||
    old_byte,
 | 
			
		||||
    new_row,
 | 
			
		||||
    new_col,
 | 
			
		||||
    new_byte
 | 
			
		||||
  )
 | 
			
		||||
  if not self._is_child then
 | 
			
		||||
    self:_do_callback(
 | 
			
		||||
      'bytes',
 | 
			
		||||
      bufnr,
 | 
			
		||||
      changed_tick,
 | 
			
		||||
      start_row,
 | 
			
		||||
      start_col,
 | 
			
		||||
      start_byte,
 | 
			
		||||
      old_row,
 | 
			
		||||
      old_col,
 | 
			
		||||
      old_byte,
 | 
			
		||||
      new_row,
 | 
			
		||||
      new_col,
 | 
			
		||||
      new_byte
 | 
			
		||||
    )
 | 
			
		||||
  end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
---@private
 | 
			
		||||
@@ -595,19 +677,15 @@ end
 | 
			
		||||
 | 
			
		||||
---@private
 | 
			
		||||
---@param tree TSTree
 | 
			
		||||
---@param range Range
 | 
			
		||||
---@param range Range4
 | 
			
		||||
---@return boolean
 | 
			
		||||
local function tree_contains(tree, range)
 | 
			
		||||
  local start_row, start_col, end_row, end_col = tree:root():range()
 | 
			
		||||
  local start_fits = start_row < range[1] or (start_row == range[1] and start_col <= range[2])
 | 
			
		||||
  local end_fits = end_row > range[3] or (end_row == range[3] and end_col >= range[4])
 | 
			
		||||
 | 
			
		||||
  return start_fits and end_fits
 | 
			
		||||
  return Range.contains({ tree:root():range() }, range)
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- Determines whether {range} is contained in the |LanguageTree|.
 | 
			
		||||
---
 | 
			
		||||
---@param range Range `{ start_line, start_col, end_line, end_col }`
 | 
			
		||||
---@param range Range4 `{ start_line, start_col, end_line, end_col }`
 | 
			
		||||
---@return boolean
 | 
			
		||||
function LanguageTree:contains(range)
 | 
			
		||||
  for _, tree in pairs(self._trees) do
 | 
			
		||||
@@ -621,7 +699,7 @@ end
 | 
			
		||||
 | 
			
		||||
--- Gets the tree that contains {range}.
 | 
			
		||||
---
 | 
			
		||||
---@param range Range `{ start_line, start_col, end_line, end_col }`
 | 
			
		||||
---@param range Range4 `{ start_line, start_col, end_line, end_col }`
 | 
			
		||||
---@param opts table|nil Optional keyword arguments:
 | 
			
		||||
---             - ignore_injections boolean Ignore injected languages (default true)
 | 
			
		||||
---@return TSTree|nil
 | 
			
		||||
@@ -631,10 +709,9 @@ function LanguageTree:tree_for_range(range, opts)
 | 
			
		||||
 | 
			
		||||
  if not ignore then
 | 
			
		||||
    for _, child in pairs(self._children) do
 | 
			
		||||
      for _, tree in pairs(child:trees()) do
 | 
			
		||||
        if tree_contains(tree, range) then
 | 
			
		||||
          return tree
 | 
			
		||||
        end
 | 
			
		||||
      local tree = child:tree_for_range(range, opts)
 | 
			
		||||
      if tree then
 | 
			
		||||
        return tree
 | 
			
		||||
      end
 | 
			
		||||
    end
 | 
			
		||||
  end
 | 
			
		||||
@@ -650,7 +727,7 @@ end
 | 
			
		||||
 | 
			
		||||
--- Gets the smallest named node that contains {range}.
 | 
			
		||||
---
 | 
			
		||||
---@param range Range `{ start_line, start_col, end_line, end_col }`
 | 
			
		||||
---@param range Range4 `{ start_line, start_col, end_line, end_col }`
 | 
			
		||||
---@param opts table|nil Optional keyword arguments:
 | 
			
		||||
---             - ignore_injections boolean Ignore injected languages (default true)
 | 
			
		||||
---@return TSNode | nil Found node
 | 
			
		||||
@@ -663,7 +740,7 @@ end
 | 
			
		||||
 | 
			
		||||
--- Gets the appropriate language that contains {range}.
 | 
			
		||||
---
 | 
			
		||||
---@param range Range `{ start_line, start_col, end_line, end_col }`
 | 
			
		||||
---@param range Range4 `{ start_line, start_col, end_line, end_col }`
 | 
			
		||||
---@return LanguageTree Managing {range}
 | 
			
		||||
function LanguageTree:language_for_range(range)
 | 
			
		||||
  for _, child in pairs(self._children) do
 | 
			
		||||
 
 | 
			
		||||
@@ -406,7 +406,7 @@ predicate_handlers['vim-match?'] = predicate_handlers['match?']
 | 
			
		||||
---@class TSMetadata
 | 
			
		||||
---@field [integer] TSMetadata
 | 
			
		||||
---@field [string] integer|string
 | 
			
		||||
---@field range Range
 | 
			
		||||
---@field range Range4
 | 
			
		||||
 | 
			
		||||
---@alias TSDirective fun(match: TSMatch, _, _, predicate: any[], metadata: TSMetadata)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -291,7 +291,7 @@ local types = { 'integer', 'number', 'string', 'table', 'list', 'boolean', 'func
 | 
			
		||||
local tagged_types = { 'TSNode', 'LanguageTree' }
 | 
			
		||||
 | 
			
		||||
-- Document these as 'table'
 | 
			
		||||
local alias_types = { 'Range' }
 | 
			
		||||
local alias_types = { 'Range4', 'Range6' }
 | 
			
		||||
 | 
			
		||||
--! \brief run the filter
 | 
			
		||||
function TLua2DoX_filter.readfile(this, AppStamp, Filename)
 | 
			
		||||
 
 | 
			
		||||
@@ -639,6 +639,17 @@ int x = INT_MAX;
 | 
			
		||||
          {1, 26, 1, 65}, -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
 | 
			
		||||
          {2, 29, 2, 68}  -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
 | 
			
		||||
        }, get_ranges())
 | 
			
		||||
 | 
			
		||||
        helpers.feed('ggo<esc>')
 | 
			
		||||
        eq(5, exec_lua("return #parser:children().c:trees()"))
 | 
			
		||||
        eq({
 | 
			
		||||
          {0, 0, 8, 0},   -- root tree
 | 
			
		||||
          {4, 14, 4, 17}, -- VALUE 123
 | 
			
		||||
          {5, 15, 5, 18}, -- VALUE1 123
 | 
			
		||||
          {6, 15, 6, 18}, -- VALUE2 123
 | 
			
		||||
          {2, 26, 2, 65}, -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
 | 
			
		||||
          {3, 29, 3, 68}  -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
 | 
			
		||||
        }, get_ranges())
 | 
			
		||||
      end)
 | 
			
		||||
    end)
 | 
			
		||||
 | 
			
		||||
@@ -660,6 +671,18 @@ int x = INT_MAX;
 | 
			
		||||
          {1, 26, 2, 68}  -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
 | 
			
		||||
                          -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
 | 
			
		||||
        }, get_ranges())
 | 
			
		||||
 | 
			
		||||
        helpers.feed('ggo<esc>')
 | 
			
		||||
        eq("table", exec_lua("return type(parser:children().c)"))
 | 
			
		||||
        eq(2, exec_lua("return #parser:children().c:trees()"))
 | 
			
		||||
        eq({
 | 
			
		||||
          {0, 0, 8, 0},   -- root tree
 | 
			
		||||
          {4, 14, 6, 18}, -- VALUE 123
 | 
			
		||||
                          -- VALUE1 123
 | 
			
		||||
                          -- VALUE2 123
 | 
			
		||||
          {2, 26, 3, 68}  -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
 | 
			
		||||
                          -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
 | 
			
		||||
        }, get_ranges())
 | 
			
		||||
      end)
 | 
			
		||||
    end)
 | 
			
		||||
 | 
			
		||||
@@ -688,6 +711,18 @@ int x = INT_MAX;
 | 
			
		||||
          {1, 26, 2, 68}  -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
 | 
			
		||||
                          -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
 | 
			
		||||
        }, get_ranges())
 | 
			
		||||
 | 
			
		||||
        helpers.feed('ggo<esc>')
 | 
			
		||||
        eq("table", exec_lua("return type(parser:children().c)"))
 | 
			
		||||
        eq(2, exec_lua("return #parser:children().c:trees()"))
 | 
			
		||||
        eq({
 | 
			
		||||
          {0, 0, 8, 0},   -- root tree
 | 
			
		||||
          {4, 14, 6, 18}, -- VALUE 123
 | 
			
		||||
                          -- VALUE1 123
 | 
			
		||||
                          -- VALUE2 123
 | 
			
		||||
          {2, 26, 3, 68}  -- READ_STRING(x, y) (char_u *)read_string((x), (size_t)(y))
 | 
			
		||||
                          -- READ_STRING_OK(x, y) (char_u *)read_string((x), (size_t)(y))
 | 
			
		||||
        }, get_ranges())
 | 
			
		||||
      end)
 | 
			
		||||
 | 
			
		||||
      it("should not inject bad languages", function()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user