mirror of
				https://github.com/neovim/neovim.git
				synced 2025-11-04 09:44:31 +00:00 
			
		
		
		
	refactor: split predicates and directives
This commit is contained in:
		@@ -299,6 +299,8 @@ local function on_line_impl(self, buf, line, is_spell_nav)
 | 
				
			|||||||
        state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1)
 | 
					        state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1)
 | 
				
			||||||
    end
 | 
					    end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    local captures = state.highlighter_query:query().captures
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    while line >= state.next_row do
 | 
					    while line >= state.next_row do
 | 
				
			||||||
      local capture, node, metadata, match = state.iter(line)
 | 
					      local capture, node, metadata, match = state.iter(line)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -311,7 +313,7 @@ local function on_line_impl(self, buf, line, is_spell_nav)
 | 
				
			|||||||
      if capture then
 | 
					      if capture then
 | 
				
			||||||
        local hl = state.highlighter_query:get_hl_from_capture(capture)
 | 
					        local hl = state.highlighter_query:get_hl_from_capture(capture)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        local capture_name = state.highlighter_query:query().captures[capture]
 | 
					        local capture_name = captures[capture]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        local spell, spell_pri_offset = get_spell(capture_name)
 | 
					        local spell, spell_pri_offset = get_spell(capture_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,6 +7,59 @@ local memoize = vim.func._memoize
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
local M = {}
 | 
					local M = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					local function is_directive(name)
 | 
				
			||||||
 | 
					  return string.sub(name, -1) == '!'
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@nodoc
 | 
				
			||||||
 | 
					---@class vim.treesitter.query.ProcessedPredicate
 | 
				
			||||||
 | 
					---@field [1] string predicate name
 | 
				
			||||||
 | 
					---@field [2] boolean should match
 | 
				
			||||||
 | 
					---@field [3] (integer|string)[] the original predicate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@alias vim.treesitter.query.ProcessedDirective (integer|string)[]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---@nodoc
 | 
				
			||||||
 | 
					---@class vim.treesitter.query.ProcessedPattern {
 | 
				
			||||||
 | 
					---@field predicates vim.treesitter.query.ProcessedPredicate[]
 | 
				
			||||||
 | 
					---@field directives vim.treesitter.query.ProcessedDirective[]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					--- Splits the query patterns into predicates and directives.
 | 
				
			||||||
 | 
					---@param patterns table<integer, (integer|string)[][]>
 | 
				
			||||||
 | 
					---@return table<integer, vim.treesitter.query.ProcessedPattern>
 | 
				
			||||||
 | 
					local function process_patterns(patterns)
 | 
				
			||||||
 | 
					  ---@type table<integer, vim.treesitter.query.ProcessedPattern>
 | 
				
			||||||
 | 
					  local processed_patterns = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for k, pattern_list in pairs(patterns) do
 | 
				
			||||||
 | 
					    ---@type vim.treesitter.query.ProcessedPredicate[]
 | 
				
			||||||
 | 
					    local predicates = {}
 | 
				
			||||||
 | 
					    ---@type vim.treesitter.query.ProcessedDirective[]
 | 
				
			||||||
 | 
					    local directives = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for _, pattern in ipairs(pattern_list) do
 | 
				
			||||||
 | 
					      -- Note: tree-sitter strips the leading # from predicates for us.
 | 
				
			||||||
 | 
					      local pred_name = pattern[1]
 | 
				
			||||||
 | 
					      ---@cast pred_name string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      if is_directive(pred_name) then
 | 
				
			||||||
 | 
					        table.insert(directives, pattern)
 | 
				
			||||||
 | 
					      else
 | 
				
			||||||
 | 
					        local should_match = true
 | 
				
			||||||
 | 
					        if pred_name:match('^not%-') then
 | 
				
			||||||
 | 
					          pred_name = pred_name:sub(5)
 | 
				
			||||||
 | 
					          should_match = false
 | 
				
			||||||
 | 
					        end
 | 
				
			||||||
 | 
					        table.insert(predicates, { pred_name, should_match, pattern })
 | 
				
			||||||
 | 
					      end
 | 
				
			||||||
 | 
					    end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    processed_patterns[k] = { predicates = predicates, directives = directives }
 | 
				
			||||||
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return processed_patterns
 | 
				
			||||||
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---@nodoc
 | 
					---@nodoc
 | 
				
			||||||
---Parsed query, see |vim.treesitter.query.parse()|
 | 
					---Parsed query, see |vim.treesitter.query.parse()|
 | 
				
			||||||
---
 | 
					---
 | 
				
			||||||
@@ -15,6 +68,7 @@ local M = {}
 | 
				
			|||||||
---@field captures string[] list of (unique) capture names defined in query
 | 
					---@field captures string[] list of (unique) capture names defined in query
 | 
				
			||||||
---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives)
 | 
					---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives)
 | 
				
			||||||
---@field query TSQuery userdata query object
 | 
					---@field query TSQuery userdata query object
 | 
				
			||||||
 | 
					---@field private _processed_patterns table<integer, vim.treesitter.query.ProcessedPattern>
 | 
				
			||||||
local Query = {}
 | 
					local Query = {}
 | 
				
			||||||
Query.__index = Query
 | 
					Query.__index = Query
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -33,6 +87,7 @@ function Query.new(lang, ts_query)
 | 
				
			|||||||
    patterns = query_info.patterns,
 | 
					    patterns = query_info.patterns,
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  self.captures = self.info.captures
 | 
					  self.captures = self.info.captures
 | 
				
			||||||
 | 
					  self._processed_patterns = process_patterns(self.info.patterns)
 | 
				
			||||||
  return self
 | 
					  return self
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -751,67 +806,50 @@ function M.list_predicates()
 | 
				
			|||||||
  return vim.tbl_keys(predicate_handlers)
 | 
					  return vim.tbl_keys(predicate_handlers)
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
local function xor(x, y)
 | 
					 | 
				
			||||||
  return (x or y) and not (x and y)
 | 
					 | 
				
			||||||
end
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
local function is_directive(name)
 | 
					 | 
				
			||||||
  return string.sub(name, -1) == '!'
 | 
					 | 
				
			||||||
end
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
---@private
 | 
					---@private
 | 
				
			||||||
---@param match TSQueryMatch
 | 
					---@param pattern_i integer
 | 
				
			||||||
 | 
					---@param predicates vim.treesitter.query.ProcessedPredicate[]
 | 
				
			||||||
 | 
					---@param captures table<integer, TSNode[]>
 | 
				
			||||||
---@param source integer|string
 | 
					---@param source integer|string
 | 
				
			||||||
function Query:match_preds(preds, pattern, captures, source)
 | 
					---@return boolean whether the predicates match
 | 
				
			||||||
  for _, pred in pairs(preds) do
 | 
					function Query:_match_predicates(predicates, pattern_i, captures, source)
 | 
				
			||||||
    -- Here we only want to return if a predicate DOES NOT match, and
 | 
					  for _, predicate in ipairs(predicates) do
 | 
				
			||||||
    -- continue on the other case. This way unknown predicates will not be considered,
 | 
					    local processed_name = predicate[1]
 | 
				
			||||||
    -- which allows some testing and easier user extensibility (#12173).
 | 
					    local should_match = predicate[2]
 | 
				
			||||||
    -- Also, tree-sitter strips the leading # from predicates for us.
 | 
					    local orig_predicate = predicate[3]
 | 
				
			||||||
    local is_not = false
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    -- Skip over directives... they will get processed after all the predicates.
 | 
					    local handler = predicate_handlers[processed_name]
 | 
				
			||||||
    if not is_directive(pred[1]) then
 | 
					    if not handler then
 | 
				
			||||||
      local pred_name = pred[1]
 | 
					      error(string.format('No handler for %s', orig_predicate[1]))
 | 
				
			||||||
      if pred_name:match('^not%-') then
 | 
					      return false
 | 
				
			||||||
        pred_name = pred_name:sub(5)
 | 
					    end
 | 
				
			||||||
        is_not = true
 | 
					 | 
				
			||||||
      end
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
      local handler = predicate_handlers[pred_name]
 | 
					    local does_match = handler(captures, pattern_i, source, orig_predicate)
 | 
				
			||||||
 | 
					    if does_match ~= should_match then
 | 
				
			||||||
      if not handler then
 | 
					      return false
 | 
				
			||||||
        error(string.format('No handler for %s', pred[1]))
 | 
					 | 
				
			||||||
        return false
 | 
					 | 
				
			||||||
      end
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      local pred_matches = handler(captures, pattern, source, pred)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      if not xor(is_not, pred_matches) then
 | 
					 | 
				
			||||||
        return false
 | 
					 | 
				
			||||||
      end
 | 
					 | 
				
			||||||
    end
 | 
					    end
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
  return true
 | 
					  return true
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---@private
 | 
					---@private
 | 
				
			||||||
---@param match TSQueryMatch
 | 
					---@param pattern_i integer
 | 
				
			||||||
 | 
					---@param directives vim.treesitter.query.ProcessedDirective[]
 | 
				
			||||||
 | 
					---@param source integer|string
 | 
				
			||||||
 | 
					---@param captures table<integer, TSNode[]>
 | 
				
			||||||
---@return vim.treesitter.query.TSMetadata metadata
 | 
					---@return vim.treesitter.query.TSMetadata metadata
 | 
				
			||||||
function Query:apply_directives(preds, pattern, captures, source)
 | 
					function Query:_apply_directives(directives, pattern_i, captures, source)
 | 
				
			||||||
  ---@type vim.treesitter.query.TSMetadata
 | 
					  ---@type vim.treesitter.query.TSMetadata
 | 
				
			||||||
  local metadata = {}
 | 
					  local metadata = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  for _, pred in pairs(preds) do
 | 
					  for _, directive in pairs(directives) do
 | 
				
			||||||
    if is_directive(pred[1]) then
 | 
					    local handler = directive_handlers[directive[1]]
 | 
				
			||||||
      local handler = directive_handlers[pred[1]]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
      if not handler then
 | 
					    if not handler then
 | 
				
			||||||
        error(string.format('No handler for %s', pred[1]))
 | 
					      error(string.format('No handler for %s', directive[1]))
 | 
				
			||||||
      end
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      handler(captures, pattern, source, pred, metadata)
 | 
					 | 
				
			||||||
    end
 | 
					    end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    handler(captures, pattern_i, source, directive, metadata)
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return metadata
 | 
					  return metadata
 | 
				
			||||||
@@ -835,12 +873,6 @@ local function value_or_node_range(start, stop, node)
 | 
				
			|||||||
  return start, stop
 | 
					  return start, stop
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
--- @param match TSQueryMatch
 | 
					 | 
				
			||||||
--- @return integer
 | 
					 | 
				
			||||||
local function match_id_hash(_, match)
 | 
					 | 
				
			||||||
  return (match:info())
 | 
					 | 
				
			||||||
end
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
--- Iterates over all captures from all matches in {node}.
 | 
					--- Iterates over all captures from all matches in {node}.
 | 
				
			||||||
---
 | 
					---
 | 
				
			||||||
--- {source} is required if the query contains predicates; then the caller
 | 
					--- {source} is required if the query contains predicates; then the caller
 | 
				
			||||||
@@ -897,7 +929,7 @@ function Query:iter_captures(node, source, start, stop)
 | 
				
			|||||||
      return
 | 
					      return
 | 
				
			||||||
    end
 | 
					    end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    local match_id, pattern = match:info()
 | 
					    local match_id, pattern_i = match:info()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    --- @type vim.treesitter.query.TSMetadata
 | 
					    --- @type vim.treesitter.query.TSMetadata
 | 
				
			||||||
    local metadata
 | 
					    local metadata
 | 
				
			||||||
@@ -906,11 +938,14 @@ function Query:iter_captures(node, source, start, stop)
 | 
				
			|||||||
    end
 | 
					    end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if not metadata then
 | 
					    if not metadata then
 | 
				
			||||||
      local preds = self.info.patterns[pattern]
 | 
					      metadata = {}
 | 
				
			||||||
      if preds then
 | 
					
 | 
				
			||||||
 | 
					      local processed_pattern = self._processed_patterns[pattern_i]
 | 
				
			||||||
 | 
					      if processed_pattern then
 | 
				
			||||||
        local captures = match:captures()
 | 
					        local captures = match:captures()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not self:match_preds(preds, pattern, captures, source) then
 | 
					        local predicates = processed_pattern.predicates
 | 
				
			||||||
 | 
					        if not self:_match_predicates(predicates, pattern_i, captures, source) then
 | 
				
			||||||
          cursor:remove_match(match_id)
 | 
					          cursor:remove_match(match_id)
 | 
				
			||||||
          if end_line and captured_node:range() > end_line then
 | 
					          if end_line and captured_node:range() > end_line then
 | 
				
			||||||
            return nil, captured_node, nil, nil
 | 
					            return nil, captured_node, nil, nil
 | 
				
			||||||
@@ -918,9 +953,8 @@ function Query:iter_captures(node, source, start, stop)
 | 
				
			|||||||
          return iter(end_line) -- tail call: try next match
 | 
					          return iter(end_line) -- tail call: try next match
 | 
				
			||||||
        end
 | 
					        end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        metadata = self:apply_directives(preds, pattern, captures, source)
 | 
					        local directives = processed_pattern.directives
 | 
				
			||||||
      else
 | 
					        metadata = self:_apply_directives(directives, pattern_i, captures, source)
 | 
				
			||||||
        metadata = {}
 | 
					 | 
				
			||||||
      end
 | 
					      end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      highest_cached_match_id = math.max(highest_cached_match_id, match_id)
 | 
					      highest_cached_match_id = math.max(highest_cached_match_id, match_id)
 | 
				
			||||||
@@ -988,20 +1022,20 @@ function Query:iter_matches(node, source, start, stop, opts)
 | 
				
			|||||||
      return
 | 
					      return
 | 
				
			||||||
    end
 | 
					    end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    local match_id, pattern = match:info()
 | 
					    local match_id, pattern_i = match:info()
 | 
				
			||||||
    local preds = self.info.patterns[pattern]
 | 
					    local processed_pattern = self._processed_patterns[pattern_i]
 | 
				
			||||||
    local captures = match:captures()
 | 
					    local captures = match:captures()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    --- @type vim.treesitter.query.TSMetadata
 | 
					    --- @type vim.treesitter.query.TSMetadata
 | 
				
			||||||
    local metadata
 | 
					    local metadata = {}
 | 
				
			||||||
    if preds then
 | 
					    if processed_pattern then
 | 
				
			||||||
      if not self:match_preds(preds, pattern, captures, source) then
 | 
					      local predicates = processed_pattern.predicates
 | 
				
			||||||
 | 
					      if not self:_match_predicates(predicates, pattern_i, captures, source) then
 | 
				
			||||||
        cursor:remove_match(match_id)
 | 
					        cursor:remove_match(match_id)
 | 
				
			||||||
        return iter() -- tail call: try next match
 | 
					        return iter() -- tail call: try next match
 | 
				
			||||||
      end
 | 
					      end
 | 
				
			||||||
      metadata = self:apply_directives(preds, pattern, captures, source)
 | 
					      local directives = processed_pattern.directives
 | 
				
			||||||
    else
 | 
					      metadata = self:_apply_directives(directives, pattern_i, captures, source)
 | 
				
			||||||
      metadata = {}
 | 
					 | 
				
			||||||
    end
 | 
					    end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if opts.all == false then
 | 
					    if opts.all == false then
 | 
				
			||||||
@@ -1012,11 +1046,11 @@ function Query:iter_matches(node, source, start, stop, opts)
 | 
				
			|||||||
      for k, v in pairs(captures or {}) do
 | 
					      for k, v in pairs(captures or {}) do
 | 
				
			||||||
        old_match[k] = v[#v]
 | 
					        old_match[k] = v[#v]
 | 
				
			||||||
      end
 | 
					      end
 | 
				
			||||||
      return pattern, old_match, metadata
 | 
					      return pattern_i, old_match, metadata
 | 
				
			||||||
    end
 | 
					    end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    -- TODO(lewis6991): create a new function that returns {match, metadata}
 | 
					    -- TODO(lewis6991): create a new function that returns {match, metadata}
 | 
				
			||||||
    return pattern, captures, metadata
 | 
					    return pattern_i, captures, metadata
 | 
				
			||||||
  end
 | 
					  end
 | 
				
			||||||
  return iter
 | 
					  return iter
 | 
				
			||||||
end
 | 
					end
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -99,7 +99,6 @@ describe('decor perf', function()
 | 
				
			|||||||
    print('\nTotal ' .. fmt(total) .. '\nDecoration provider: ' .. fmt(provider))
 | 
					    print('\nTotal ' .. fmt(total) .. '\nDecoration provider: ' .. fmt(provider))
 | 
				
			||||||
  end)
 | 
					  end)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
  it('can handle full screen of highlighting', function()
 | 
					  it('can handle full screen of highlighting', function()
 | 
				
			||||||
    Screen.new(100, 51)
 | 
					    Screen.new(100, 51)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -835,9 +835,9 @@ void ui_refresh(void)
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
      local result = exec_lua(function()
 | 
					      local result = exec_lua(function()
 | 
				
			||||||
        local query0 = vim.treesitter.query.parse('c', query)
 | 
					        local query0 = vim.treesitter.query.parse('c', query)
 | 
				
			||||||
        local match_preds = query0.match_preds
 | 
					        local match_preds = query0._match_predicates
 | 
				
			||||||
        local called = 0
 | 
					        local called = 0
 | 
				
			||||||
        function query0:match_preds(...)
 | 
					        function query0:_match_predicates(...)
 | 
				
			||||||
          called = called + 1
 | 
					          called = called + 1
 | 
				
			||||||
          return match_preds(self, ...)
 | 
					          return match_preds(self, ...)
 | 
				
			||||||
        end
 | 
					        end
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user