refactor(treesitter): redesign query iterating

Problem:

  `TSNode:_rawquery()` is complicated, has known issues, and the Lua and
  C code are awkwardly coupled (see the logic involving `active`).

Solution:

  - Add `TSQueryCursor` and `TSQueryMatch` bindings.
  - Replace `TSNode:_rawquery()` with `TSQueryCursor:next_capture()` and `TSQueryCursor:next_match()`.
  - Do more of the iteration logic in Lua rather than in C.
  - API for `Query:iter_captures()` and `Query:iter_matches()` remains the same.
  - `treesitter.c` no longer contains any logic related to predicates.
  - Add `match_limit` option to `iter_matches()`. Default is still 256.
This commit is contained in:
Lewis Russell
2024-03-18 23:19:01 +00:00
committed by Lewis Russell
parent 16a416cb3c
commit aca2048bcd
6 changed files with 250 additions and 207 deletions

View File

@@ -258,7 +258,7 @@ end)
--- handling the "any" vs "all" semantics. They are called from the
--- predicate_handlers table with the appropriate arguments for each predicate.
local impl = {
--- @param match vim.treesitter.query.TSMatch
--- @param match table<integer,TSNode[]>
--- @param source integer|string
--- @param predicate any[]
--- @param any boolean
@@ -293,7 +293,7 @@ local impl = {
return not any
end,
--- @param match vim.treesitter.query.TSMatch
--- @param match table<integer,TSNode[]>
--- @param source integer|string
--- @param predicate any[]
--- @param any boolean
@@ -333,7 +333,7 @@ local impl = {
end,
})
--- @param match vim.treesitter.query.TSMatch
--- @param match table<integer,TSNode[]>
--- @param source integer|string
--- @param predicate any[]
--- @param any boolean
@@ -356,7 +356,7 @@ local impl = {
end
end)(),
--- @param match vim.treesitter.query.TSMatch
--- @param match table<integer,TSNode[]>
--- @param source integer|string
--- @param predicate any[]
--- @param any boolean
@@ -383,13 +383,7 @@ local impl = {
end,
}
---@nodoc
---@class vim.treesitter.query.TSMatch
---@field pattern? integer
---@field active? boolean
---@field [integer] TSNode[]
---@alias TSPredicate fun(match: vim.treesitter.query.TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean
---@alias TSPredicate fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[]): boolean
-- Predicate handlers receive the following arguments
-- (match, pattern, bufnr, predicate)
@@ -504,7 +498,7 @@ predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?']
---@field [integer] vim.treesitter.query.TSMetadata
---@field [string] integer|string
---@alias TSDirective fun(match: vim.treesitter.query.TSMatch, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata)
---@alias TSDirective fun(match: table<integer,TSNode[]>, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata)
-- Directive handlers receive the following arguments
-- (match, pattern, bufnr, predicate, metadata)
@@ -726,13 +720,19 @@ local function is_directive(name)
end
---@private
---@param match vim.treesitter.query.TSMatch
---@param pattern integer
---@param match TSQueryMatch
---@param source integer|string
function Query:match_preds(match, pattern, source)
function Query:match_preds(match, source)
local _, pattern = match:info()
local preds = self.info.patterns[pattern]
for _, pred in pairs(preds or {}) do
if not preds then
return true
end
local captures = match:captures()
for _, pred in pairs(preds) do
-- Here we only want to return if a predicate DOES NOT match, and
-- continue on the other case. This way unknown predicates will not be considered,
-- which allows some testing and easier user extensibility (#12173).
@@ -754,7 +754,7 @@ function Query:match_preds(match, pattern, source)
return false
end
local pred_matches = handler(match, pattern, source, pred)
local pred_matches = handler(captures, pattern, source, pred)
if not xor(is_not, pred_matches) then
return false
@@ -765,23 +765,33 @@ function Query:match_preds(match, pattern, source)
end
---@private
---@param match vim.treesitter.query.TSMatch
---@param metadata vim.treesitter.query.TSMetadata
function Query:apply_directives(match, pattern, source, metadata)
---@param match TSQueryMatch
---@return vim.treesitter.query.TSMetadata metadata
function Query:apply_directives(match, source)
---@type vim.treesitter.query.TSMetadata
local metadata = {}
local _, pattern = match:info()
local preds = self.info.patterns[pattern]
for _, pred in pairs(preds or {}) do
if not preds then
return metadata
end
local captures = match:captures()
for _, pred in pairs(preds) do
if is_directive(pred[1]) then
local handler = directive_handlers[pred[1]]
if not handler then
error(string.format('No handler for %s', pred[1]))
return
end
handler(match, pattern, source, pred, metadata)
handler(captures, pattern, source, pred, metadata)
end
end
return metadata
end
--- Returns the start and stop value if set else the node's range.
@@ -831,8 +841,10 @@ end
---@param start? integer Starting line for the search. Defaults to `node:start()`.
---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`.
---
---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer, TSNode>):
---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer,TSNode[]>?):
--- capture id, capture node, metadata, match
---
---@note The match table (fourth return value) is only returned if the query pattern of a specific capture contained predicates.
function Query:iter_captures(node, source, start, stop)
if type(source) == 'number' and source == 0 then
source = api.nvim_get_current_buf()
@@ -840,24 +852,38 @@ function Query:iter_captures(node, source, start, stop)
start, stop = value_or_node_range(start, stop, node)
local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, vim.treesitter.query.TSMatch
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })
local max_match_id = -1
local function iter(end_line)
local capture, captured_node, match = raw_iter()
local capture, captured_node, match = cursor:next_capture()
if not capture then
return
end
local captures --- @type table<integer,TSNode[]>?
local match_id, pattern_index = match:info()
local metadata = {}
if match ~= nil then
local active = self:match_preds(match, match.pattern, source)
match.active = active
if not active then
local preds = self.info.patterns[pattern_index] or {}
if #preds > 0 and match_id > max_match_id then
captures = match:captures()
max_match_id = match_id
if not self:match_preds(match, source) then
cursor:remove_match(match_id)
if end_line and captured_node:range() > end_line then
return nil, captured_node, nil
end
return iter(end_line) -- tail call: try next match
end
self:apply_directives(match, match.pattern, source, metadata)
metadata = self:apply_directives(match, source)
end
return capture, captured_node, metadata, match
return capture, captured_node, metadata, captures
end
return iter
end
@@ -899,45 +925,54 @@ end
---@param opts? table Optional keyword arguments:
--- - max_start_depth (integer) if non-zero, sets the maximum start depth
--- for each match. This is used to prevent traversing too deep into a tree.
--- - match_limit (integer) Set the maximum number of in-progress matches (Default: 256).
--- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes.
--- Older versions of iter_matches incorrectly mapped capture IDs to a single node.
--- This option will eventually become the default and then be removed.
---
---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata
function Query:iter_matches(node, source, start, stop, opts)
local all = opts and opts.all
opts = opts or {}
opts.match_limit = opts.match_limit or 256
if type(source) == 'number' and source == 0 then
source = api.nvim_get_current_buf()
end
start, stop = value_or_node_range(start, stop, node)
local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, vim.treesitter.query.TSMatch
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts)
local function iter()
local pattern, match = raw_iter()
local metadata = {}
local match = cursor:next_match()
if match ~= nil then
local active = self:match_preds(match, pattern, source)
if not active then
return iter() -- tail call: try next match
end
self:apply_directives(match, pattern, source, metadata)
if not match then
return
end
if not all then
local match_id, pattern = match:info()
if not self:match_preds(match, source) then
cursor:remove_match(match_id)
return iter() -- tail call: try next match
end
local metadata = self:apply_directives(match, source)
local captures = match:captures()
if not opts.all then
-- Convert the match table into the old buggy version for backward
-- compatibility. This is slow. Plugin authors, if you're reading this, set the "all"
-- option!
local old_match = {} ---@type table<integer, TSNode>
for k, v in pairs(match or {}) do
for k, v in pairs(captures or {}) do
old_match[k] = v[#v]
end
return pattern, old_match, metadata
end
return pattern, match, metadata
return pattern, captures, metadata
end
return iter
end