mirror of
				https://github.com/neovim/neovim.git
				synced 2025-11-04 09:44:31 +00:00 
			
		
		
		
	perf(treesitter): use child_containing_descendant() in has-ancestor? (#28512)
Problem: `has-ancestor?` is O(n²) for the depth of the tree since it iterates over each of the node's ancestors (bottom-up), and each ancestor takes O(n) time. This happens because tree-sitter's nodes don't store their parent nodes, and the tree is searched (top-down) each time a new parent is requested. Solution: Make use of new `ts_node_child_containing_descendant()` in tree-sitter v0.22.6 (which is now the minimum required version) to rewrite the `has-ancestor?` predicate in C to become O(n). For a sample file, decreases the time taken by `has-ancestor?` from 360ms to 6ms.
This commit is contained in:
		@@ -78,6 +78,8 @@ An instance `TSNode` of a treesitter node supports the following methods.
 | 
			
		||||
 | 
			
		||||
TSNode:parent()                                         *TSNode:parent()*
 | 
			
		||||
    Get the node's immediate parent.
 | 
			
		||||
    Prefer |TSNode:child_containing_descendant()|
 | 
			
		||||
    for iterating over the node's ancestors.
 | 
			
		||||
 | 
			
		||||
TSNode:next_sibling()                                   *TSNode:next_sibling()*
 | 
			
		||||
    Get the node's next sibling.
 | 
			
		||||
@@ -114,6 +116,9 @@ TSNode:named_child({index})                              *TSNode:named_child()*
 | 
			
		||||
    Get the node's named child at the given {index}, where zero represents the
 | 
			
		||||
    first named child.
 | 
			
		||||
 | 
			
		||||
TSNode:child_containing_descendant({descendant})  *TSNode:child_containing_descendant()*
 | 
			
		||||
    Get the node's child that contains {descendant}.
 | 
			
		||||
 | 
			
		||||
TSNode:start()                                          *TSNode:start()*
 | 
			
		||||
    Get the node's start position. Return three values: the row, column and
 | 
			
		||||
    total byte count (all zero-based).
 | 
			
		||||
 
 | 
			
		||||
@@ -20,6 +20,7 @@ error('Cannot require a meta file')
 | 
			
		||||
---@field descendant_for_range fun(self: TSNode, start_row: integer, start_col: integer, end_row: integer, end_col: integer): TSNode?
 | 
			
		||||
---@field named_descendant_for_range fun(self: TSNode, start_row: integer, start_col: integer, end_row: integer, end_col: integer): TSNode?
 | 
			
		||||
---@field parent fun(self: TSNode): TSNode?
 | 
			
		||||
---@field child_containing_descendant fun(self: TSNode, descendant: TSNode): TSNode?
 | 
			
		||||
---@field next_sibling fun(self: TSNode): TSNode?
 | 
			
		||||
---@field prev_sibling fun(self: TSNode): TSNode?
 | 
			
		||||
---@field next_named_sibling fun(self: TSNode): TSNode?
 | 
			
		||||
 
 | 
			
		||||
@@ -457,17 +457,8 @@ local predicate_handlers = {
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    for _, node in ipairs(nodes) do
 | 
			
		||||
      local ancestor_types = {} --- @type table<string, boolean>
 | 
			
		||||
      for _, type in ipairs({ unpack(predicate, 3) }) do
 | 
			
		||||
        ancestor_types[type] = true
 | 
			
		||||
      end
 | 
			
		||||
 | 
			
		||||
      local cur = node:parent()
 | 
			
		||||
      while cur do
 | 
			
		||||
        if ancestor_types[cur:type()] then
 | 
			
		||||
          return true
 | 
			
		||||
        end
 | 
			
		||||
        cur = cur:parent()
 | 
			
		||||
      if node:__has_ancestor(predicate) then
 | 
			
		||||
        return true
 | 
			
		||||
      end
 | 
			
		||||
    end
 | 
			
		||||
    return false
 | 
			
		||||
 
 | 
			
		||||
@@ -33,7 +33,7 @@ find_package(Libuv 1.28.0 REQUIRED)
 | 
			
		||||
find_package(Libvterm 0.3.3 REQUIRED)
 | 
			
		||||
find_package(Lpeg REQUIRED)
 | 
			
		||||
find_package(Msgpack 1.0.0 REQUIRED)
 | 
			
		||||
find_package(Treesitter 0.20.9 REQUIRED)
 | 
			
		||||
find_package(Treesitter 0.22.6 REQUIRED)
 | 
			
		||||
find_package(Unibilium 2.0 REQUIRED)
 | 
			
		||||
 | 
			
		||||
target_link_libraries(main_lib INTERFACE
 | 
			
		||||
 
 | 
			
		||||
@@ -725,6 +725,8 @@ static struct luaL_Reg node_meta[] = {
 | 
			
		||||
  { "descendant_for_range", node_descendant_for_range },
 | 
			
		||||
  { "named_descendant_for_range", node_named_descendant_for_range },
 | 
			
		||||
  { "parent", node_parent },
 | 
			
		||||
  { "__has_ancestor", __has_ancestor },
 | 
			
		||||
  { "child_containing_descendant", node_child_containing_descendant },
 | 
			
		||||
  { "iter_children", node_iter_children },
 | 
			
		||||
  { "next_sibling", node_next_sibling },
 | 
			
		||||
  { "prev_sibling", node_prev_sibling },
 | 
			
		||||
@@ -1052,6 +1054,49 @@ static int node_parent(lua_State *L)
 | 
			
		||||
  return 1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static int __has_ancestor(lua_State *L)
 | 
			
		||||
{
 | 
			
		||||
  TSNode descendant = node_check(L, 1);
 | 
			
		||||
  if (lua_type(L, 2) != LUA_TTABLE) {
 | 
			
		||||
    lua_pushboolean(L, false);
 | 
			
		||||
    return 1;
 | 
			
		||||
  }
 | 
			
		||||
  int const pred_len = (int)lua_objlen(L, 2);
 | 
			
		||||
 | 
			
		||||
  TSNode node = ts_tree_root_node(descendant.tree);
 | 
			
		||||
  while (!ts_node_is_null(node)) {
 | 
			
		||||
    char const *node_type = ts_node_type(node);
 | 
			
		||||
    size_t node_type_len = strlen(node_type);
 | 
			
		||||
 | 
			
		||||
    for (int i = 3; i <= pred_len; i++) {
 | 
			
		||||
      lua_rawgeti(L, 2, i);
 | 
			
		||||
      if (lua_type(L, -1) == LUA_TSTRING) {
 | 
			
		||||
        size_t check_len;
 | 
			
		||||
        char const *check_str = lua_tolstring(L, -1, &check_len);
 | 
			
		||||
        if (node_type_len == check_len && memcmp(node_type, check_str, check_len) == 0) {
 | 
			
		||||
          lua_pushboolean(L, true);
 | 
			
		||||
          return 1;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      lua_pop(L, 1);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    node = ts_node_child_containing_descendant(node, descendant);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  lua_pushboolean(L, false);
 | 
			
		||||
  return 1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static int node_child_containing_descendant(lua_State *L)
 | 
			
		||||
{
 | 
			
		||||
  TSNode node = node_check(L, 1);
 | 
			
		||||
  TSNode descendant = node_check(L, 2);
 | 
			
		||||
  TSNode child = ts_node_child_containing_descendant(node, descendant);
 | 
			
		||||
  push_node(L, child, 1);
 | 
			
		||||
  return 1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static int node_next_sibling(lua_State *L)
 | 
			
		||||
{
 | 
			
		||||
  TSNode node = node_check(L, 1);
 | 
			
		||||
 
 | 
			
		||||
@@ -143,4 +143,30 @@ describe('treesitter node API', function()
 | 
			
		||||
    eq(28, lua_eval('root:byte_length()'))
 | 
			
		||||
    eq(3, lua_eval('child:byte_length()'))
 | 
			
		||||
  end)
 | 
			
		||||
 | 
			
		||||
  it('child_containing_descendant() works', function()
 | 
			
		||||
    insert([[
 | 
			
		||||
      int main() {
 | 
			
		||||
        int x = 3;
 | 
			
		||||
      }]])
 | 
			
		||||
 | 
			
		||||
    exec_lua([[
 | 
			
		||||
      tree = vim.treesitter.get_parser(0, "c"):parse()[1]
 | 
			
		||||
      root = tree:root()
 | 
			
		||||
      main = root:child(0)
 | 
			
		||||
      body = main:child(2)
 | 
			
		||||
      statement = body:child(1)
 | 
			
		||||
      declarator = statement:child(1)
 | 
			
		||||
      value = declarator:child(1)
 | 
			
		||||
    ]])
 | 
			
		||||
 | 
			
		||||
    eq(lua_eval('main:type()'), lua_eval('root:child_containing_descendant(value):type()'))
 | 
			
		||||
    eq(lua_eval('body:type()'), lua_eval('main:child_containing_descendant(value):type()'))
 | 
			
		||||
    eq(lua_eval('statement:type()'), lua_eval('body:child_containing_descendant(value):type()'))
 | 
			
		||||
    eq(
 | 
			
		||||
      lua_eval('declarator:type()'),
 | 
			
		||||
      lua_eval('statement:child_containing_descendant(value):type()')
 | 
			
		||||
    )
 | 
			
		||||
    eq(vim.NIL, lua_eval('declarator:child_containing_descendant(value)'))
 | 
			
		||||
  end)
 | 
			
		||||
end)
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,22 @@ local pcall_err = t.pcall_err
 | 
			
		||||
local api = n.api
 | 
			
		||||
local fn = n.fn
 | 
			
		||||
 | 
			
		||||
local get_query_result_code = [[
 | 
			
		||||
  function get_query_result(query_text)
 | 
			
		||||
    cquery = vim.treesitter.query.parse("c", query_text)
 | 
			
		||||
    parser = vim.treesitter.get_parser(0, "c")
 | 
			
		||||
    tree = parser:parse()[1]
 | 
			
		||||
    res = {}
 | 
			
		||||
    for cid, node in cquery:iter_captures(tree:root(), 0) do
 | 
			
		||||
      -- can't transmit node over RPC. just check the name, range, and text
 | 
			
		||||
      local text = vim.treesitter.get_node_text(node, 0)
 | 
			
		||||
      local range = {node:range()}
 | 
			
		||||
      table.insert(res, { cquery.captures[cid], node:type(), range, text })
 | 
			
		||||
    end
 | 
			
		||||
    return res
 | 
			
		||||
  end
 | 
			
		||||
]]
 | 
			
		||||
 | 
			
		||||
describe('treesitter query API', function()
 | 
			
		||||
  before_each(function()
 | 
			
		||||
    clear()
 | 
			
		||||
@@ -291,21 +307,7 @@ void ui_refresh(void)
 | 
			
		||||
        return 0;
 | 
			
		||||
      }
 | 
			
		||||
    ]])
 | 
			
		||||
    exec_lua([[
 | 
			
		||||
      function get_query_result(query_text)
 | 
			
		||||
        cquery = vim.treesitter.query.parse("c", query_text)
 | 
			
		||||
        parser = vim.treesitter.get_parser(0, "c")
 | 
			
		||||
        tree = parser:parse()[1]
 | 
			
		||||
        res = {}
 | 
			
		||||
        for cid, node in cquery:iter_captures(tree:root(), 0) do
 | 
			
		||||
          -- can't transmit node over RPC. just check the name, range, and text
 | 
			
		||||
          local text = vim.treesitter.get_node_text(node, 0)
 | 
			
		||||
          local range = {node:range()}
 | 
			
		||||
          table.insert(res, { cquery.captures[cid], node:type(), range, text })
 | 
			
		||||
        end
 | 
			
		||||
        return res
 | 
			
		||||
      end
 | 
			
		||||
    ]])
 | 
			
		||||
    exec_lua(get_query_result_code)
 | 
			
		||||
 | 
			
		||||
    local res0 = exec_lua(
 | 
			
		||||
      [[return get_query_result(...)]],
 | 
			
		||||
@@ -333,6 +335,38 @@ void ui_refresh(void)
 | 
			
		||||
    }, res1)
 | 
			
		||||
  end)
 | 
			
		||||
 | 
			
		||||
  it('supports builtin predicate has-ancestor?', function()
 | 
			
		||||
    insert([[
 | 
			
		||||
      int x = 123;
 | 
			
		||||
      enum C { y = 124 };
 | 
			
		||||
      int main() { int z = 125; }]])
 | 
			
		||||
    exec_lua(get_query_result_code)
 | 
			
		||||
 | 
			
		||||
    local result = exec_lua(
 | 
			
		||||
      [[return get_query_result(...)]],
 | 
			
		||||
      [[((number_literal) @literal (#has-ancestor? @literal "function_definition"))]]
 | 
			
		||||
    )
 | 
			
		||||
    eq({ { 'literal', 'number_literal', { 2, 21, 2, 24 }, '125' } }, result)
 | 
			
		||||
 | 
			
		||||
    result = exec_lua(
 | 
			
		||||
      [[return get_query_result(...)]],
 | 
			
		||||
      [[((number_literal) @literal (#has-ancestor? @literal "function_definition" "enum_specifier"))]]
 | 
			
		||||
    )
 | 
			
		||||
    eq({
 | 
			
		||||
      { 'literal', 'number_literal', { 1, 13, 1, 16 }, '124' },
 | 
			
		||||
      { 'literal', 'number_literal', { 2, 21, 2, 24 }, '125' },
 | 
			
		||||
    }, result)
 | 
			
		||||
 | 
			
		||||
    result = exec_lua(
 | 
			
		||||
      [[return get_query_result(...)]],
 | 
			
		||||
      [[((number_literal) @literal (#not-has-ancestor? @literal "enum_specifier"))]]
 | 
			
		||||
    )
 | 
			
		||||
    eq({
 | 
			
		||||
      { 'literal', 'number_literal', { 0, 8, 0, 11 }, '123' },
 | 
			
		||||
      { 'literal', 'number_literal', { 2, 21, 2, 24 }, '125' },
 | 
			
		||||
    }, result)
 | 
			
		||||
  end)
 | 
			
		||||
 | 
			
		||||
  it('allows loading query with escaped quotes and capture them `#{lua,vim}-match`?', function()
 | 
			
		||||
    insert('char* astring = "Hello World!";')
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user