mirror of
				https://github.com/neovim/neovim.git
				synced 2025-11-04 01:34:25 +00:00 
			
		
		
		
	This reverts the handling of base languages to the way nvim-treesitter used to handle them. When a language extends a base language, it usually wants to be able to override the base language's queries. Related: https://github.com/nvim-treesitter/nvim-treesitter/issues/633
		
			
				
	
	
		
			337 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Lua
		
	
	
	
	
	
			
		
		
	
	
			337 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Lua
		
	
	
	
	
	
local a = vim.api
 | 
						|
local language = require'vim.treesitter.language'
 | 
						|
 | 
						|
-- Module table; the public functions are attached to it and it is returned at the end of the file.
local M = {}

-- query: pattern matching on trees
-- predicate matching is implemented in lua
--
-- `Query` is the prototype of parsed query objects: instances use it as their metatable and
-- resolve their methods through `__index`.
local Query = {}
Query.__index = Query
 | 
						|
 | 
						|
-- Filters the runtime query files found for a language. The layout mirrors regular runtime files,
-- but lives in the `queries` directory. Much like ftplugins, a query can be overridden by adding
-- a file in a `queries` directory, and extended through the `after/queries` directory.
--
-- @param file_list List of query file paths, the one with most priority first
--
-- @return A list made of the first path that is not under an `after` directory (if there is one)
--         followed by every path that is under an `after` directory, in their original order.
local function filter_files(file_list)
  local primary
  local extensions = {}

  for _, path in ipairs(file_list) do
    -- Name of the directory that contains the `queries` directory
    local container = vim.fn.fnamemodify(path, ":p:h:h:h:t")

    if container == "after" then
      extensions[#extensions + 1] = path
    elseif not primary then
      -- Only the first non-`after` file is kept: it is the one with most priority
      primary = path
    end
  end

  if not primary then
    return extensions
  end

  local result = { primary }
  for _, path in ipairs(extensions) do
    result[#result + 1] = path
  end
  return result
end
 | 
						|
 | 
						|
-- Builds the runtime-relative path of the query file {query_name} for {lang}.
--
-- @param lang The language name
-- @param query_name The name of the query (i.e. "highlights")
--
-- @return A path of the form `queries/{lang}/{query_name}.scm`
local function runtime_query_path(lang, query_name)
  local template = 'queries/%s/%s.scm'
  return template:format(lang, query_name)
end
 | 
						|
 | 
						|
-- Looks up every runtime file matching the query {query_name} of {lang} and filters the result
-- through `filter_files`.
--
-- @param lang The language name
-- @param query_name The name of the query (i.e. "highlights")
--
-- @return The filtered list of query file paths
local function filtered_runtime_queries(lang, query_name)
  local relative_path = runtime_query_path(lang, query_name)
  -- `true`: ask for all matching runtime files, not only the first one
  local found = a.nvim_get_runtime_file(relative_path, true)
  return filter_files(found or {})
end
 | 
						|
 | 
						|
-- Collects the query files to load for {lang}, including the files of the languages it inherits
-- from. Files of base languages come first, so that the files of {lang} itself are read last.
--
-- @param lang The language name
-- @param query_name The name of the query (i.e. "highlights")
-- @param is_included Truthy when {lang} is being resolved as a base language of another one;
--                    optional (parenthesized) base languages are then skipped.
--
-- @return The list of query file paths (empty when {lang} has no file for this query)
local function get_query_files(lang, query_name, is_included)
  local lang_files = filtered_runtime_queries(lang, query_name)

  if #lang_files == 0 then
    return {}
  end

  -- The base languages are declared on the first line of every file.
  -- The syntax is the following:
  -- ;+ inherits: ({language},)*{language}
  --
  -- {language} ::= {lang} | ({lang})
  local MODELINE_FORMAT = "^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$"

  local base_langs = {}

  for _, file in ipairs(lang_files) do
    -- Only the first line of the file is read here
    local first_lines = vim.fn.readfile(file, "", 1)
    local langlist = #first_lines == 1 and first_lines[1]:match(MODELINE_FORMAT)

    if langlist then
      for _, entry in ipairs(vim.split(langlist, ',', true)) do
        if not entry:match("%(.*%)") then
          table.insert(base_langs, entry)
        elseif not is_included then
          -- Optional language: drop the first and last characters (the parentheses)
          table.insert(base_langs, entry:sub(2, -2))
        end
      end
    end
  end

  local query_files = {}

  for _, base_lang in ipairs(base_langs) do
    vim.list_extend(query_files, get_query_files(base_lang, query_name, true))
  end

  vim.list_extend(query_files, lang_files)

  return query_files
end
 | 
						|
 | 
						|
-- Reads every file of {filenames}, in order, and joins all of their lines into one string.
--
-- @param filenames List of file paths to read
--
-- @return The lines of all files, separated by newlines
local function read_query_files(filenames)
  local lines = {}

  for _, path in ipairs(filenames) do
    for _, line in ipairs(vim.fn.readfile(path)) do
      lines[#lines + 1] = line
    end
  end

  return table.concat(lines, '\n')
end
 | 
						|
 | 
						|
--- Returns the runtime query {query_name} for {lang}.
--
-- The query text is read from the runtime query files of {lang}, preceded by the files of the
-- languages it inherits from.
--
-- @param lang The language to use for the query
-- @param query_name The name of the query (i.e. "highlights")
--
-- @return The corresponding query, parsed. Nothing is returned when the query text is empty.
function M.get_query(lang, query_name)
  local query_string = read_query_files(get_query_files(lang, query_name))

  if #query_string == 0 then
    return
  end

  return M.parse_query(lang, query_string)
end
 | 
						|
 | 
						|
--- Parses a query.
--
-- @param lang The language
-- @param query A string containing the query (s-expr syntax)
--
-- @returns The query: a `Query` object with the fields `query` (the raw parsed query), `info`
--          (the result of inspecting it) and `captures` (taken from `info.captures`).
function M.parse_query(lang, query)
  language.require_language(lang)

  local parsed = vim._ts_parse_query(lang, query)
  local info = parsed:inspect()

  return setmetatable({
    query = parsed,
    info = info,
    captures = info.captures,
  }, Query)
end
 | 
						|
 | 
						|
-- TODO(vigoux): support multiline nodes too

--- Gets the text corresponding to a given node
--
-- @param node the node
-- @param source the source from which the node is extracted: either a buffer number or a
--        string holding the text itself.
--
-- @return The text of the node. For a buffer source, nil when the node does not start and end
--         on the same row. Nothing is returned when {source} is neither a number nor a string.
function M.get_node_text(node, source)
  local start_row, start_col, start_byte = node:start()
  local end_row, end_col, end_byte = node:end_()

  local source_type = type(source)

  if source_type == "string" then
    -- Byte offsets are 0-based, string indices are 1-based
    return source:sub(start_byte + 1, end_byte)
  end

  if source_type == "number" then
    if start_row ~= end_row then
      return nil
    end

    local lines = a.nvim_buf_get_lines(source, start_row, start_row + 1, true)
    return string.sub(lines[1], start_col + 1, end_col)
  end
end
 | 
						|
 | 
						|
-- Predicate handlers receive the following arguments
-- (match, pattern, source, predicate)
--
-- `predicate[1]` is the predicate name and the following entries are its arguments. Arguments
-- used as keys of `match` designate captured nodes, string arguments are literal values.
-- A handler returns a truthy value when the predicate holds.
local predicate_handlers = {
  -- (#eq? @aa "foo") or (#eq? @aa @bb): the text of the first node equals the string / the text
  -- of the second node.
  ["eq?"] = function(match, _, source, predicate)
      local node = match[predicate[2]]
      local node_text = M.get_node_text(node, source)

      local str
      if type(predicate[3]) == "string" then
        -- (#eq? @aa "foo")
        str = predicate[3]
      else
        -- (#eq? @aa @bb)
        str = M.get_node_text(match[predicate[3]], source)
      end

      -- A nil text (get_node_text could not extract it) never compares equal
      if node_text ~= str or str == nil then
        return false
      end

      return true
  end,

  -- (#lua-match? @aa "pattern"): the text of the node matches the Lua pattern.
  -- Nodes spanning several rows never match.
  ["lua-match?"] = function(match, _, source, predicate)
      local node = match[predicate[2]]
      local regex = predicate[3]
      local start_row, _, end_row, _ = node:range()
      if start_row ~= end_row then
        return false
      end

      return string.find(M.get_node_text(node, source), regex)
  end,

  -- (#match? @aa "regex"): the node matches the vim regex.
  -- Nodes spanning several rows never match.
  ["match?"] = (function()
    local magic_prefixes = {['\\v']=true, ['\\m']=true, ['\\M']=true, ['\\V']=true}

    -- Prepends `\v` to the pattern unless it is shorter than two characters or already starts
    -- with one of the magic prefixes.
    local function check_magic(str)
      if string.len(str) < 2 or magic_prefixes[string.sub(str,1,2)] then
        return str
      end
      return '\\v'..str
    end

    -- Cache of compiled regexes, keyed by the pattern as written in the query: a pattern is
    -- compiled the first time it is looked up and stored for later lookups.
    local compiled_vim_regexes = setmetatable({}, {
      __index = function(t, pattern)
        local res = vim.regex(check_magic(vim.fn.escape(pattern, '\\')))
        rawset(t, pattern, res)
        return res
      end
    })

    return function(match, _, source, pred)
      local node = match[pred[2]]
      local start_row, start_col, end_row, end_col = node:range()
      if start_row ~= end_row then
        return false
      end

      local regex = compiled_vim_regexes[pred[3]]
      return regex:match_line(source, start_row, start_col, end_col)
    end
  end)(),

  -- (#contains? @aa "foo" "bar" ...): the text of the node contains at least one of the strings
  -- (plain search, no pattern matching).
  ["contains?"] = function(match, _, source, predicate)
    local node = match[predicate[2]]
    local node_text = M.get_node_text(node, source)

    -- get_node_text yields nil when it cannot extract the text (e.g. a node spanning several
    -- rows of a buffer); string.find would raise on a nil subject, so report "no match" instead.
    if node_text == nil then
      return false
    end

    for i=3,#predicate do
      if string.find(node_text, predicate[i], 1, true) then
        return true
      end
    end

    return false
  end
}

-- As we provide lua-match? also expose vim-match?
predicate_handlers["vim-match?"] = predicate_handlers["match?"]
 | 
						|
 | 
						|
--- Adds a new predicate to be used in queries
--
-- The handler is always registered. When a handler already exists under {name} and {force} is
-- not set, an "Overriding" error message is written before it is replaced.
--
-- @param name the name of the predicate, without leading #
-- @param handler the handler function to be used
--    signature will be (match, pattern, bufnr, predicate)
-- @param force when truthy, replace an existing handler without writing the error message
function M.add_predicate(name, handler, force)
  local already_registered = predicate_handlers[name]

  if already_registered and not force then
    a.nvim_err_writeln(("Overriding %s"):format(name))
  end

  predicate_handlers[name] = handler
end
 | 
						|
 | 
						|
--- Returns the list of currently supported predicates
--
-- @return The names under which a predicate handler is registered
function M.list_predicates()
  local names = vim.tbl_keys(predicate_handlers)
  return names
end
 | 
						|
 | 
						|
-- Exclusive or on Lua truthiness.
--
-- @return true when exactly one of {x} and {y} is truthy, false when both are truthy; when both
--         are falsy, the falsy value of `x or y` is returned as is.
local function xor(x, y)
  local either = x or y
  if not either then
    return either
  end
  return not (x and y)
end
 | 
						|
 | 
						|
-- Checks a match against the predicates attached to its pattern.
--
-- @param match The match to check
-- @param pattern The id of the pattern that produced the match
-- @param source The source the match comes from, handed over to the predicate handlers
--
-- @return true when every predicate of the pattern is satisfied (or the pattern has none),
--         false otherwise.
function Query:match_preds(match, pattern, source)
  local preds = self.info.patterns[pattern] or {}

  for _, pred in pairs(preds) do
    -- tree-sitter strips the leading # from predicates for us.
    -- A "not-" prefix negates the predicate named by the rest of the string.
    local full_name = pred[1]
    local is_not = string.sub(full_name, 1, 4) == "not-"
    local pred_name = full_name
    if is_not then
      pred_name = string.sub(full_name, 5)
    end

    local handler = predicate_handlers[pred_name]

    -- A predicate without a registered handler is reported and makes the match fail.
    if not handler then
      a.nvim_err_writeln(string.format("No handler for %s", full_name))
      return false
    end

    local pred_matches = handler(match, pattern, source, pred)

    -- Stop at the first predicate whose (possibly negated) outcome does not hold.
    if not xor(is_not, pred_matches) then
      return false
    end
  end

  return true
end
 | 
						|
 | 
						|
--- Iterates over the captures of self on a given range.
--
-- Captures belonging to a match whose predicates are not satisfied are skipped. The outcome of
-- the predicate check is stored in the `active` field of the match.
--
-- @param node The node under which the search will occur
-- @param source The source to search; the buffer number 0 designates the current buffer
-- @param start The starting line of the search
-- @param stop The stopping line of the search (end-exclusive)
--
-- @returns The matching capture id
-- @returns The captured node
function Query:iter_captures(node, source, start, stop)
  if type(source) == "number" and source == 0 then
    source = vim.api.nvim_get_current_buf()
  end

  local raw_iter = node:_rawquery(self.query, true, start, stop)

  return function()
    while true do
      local capture, captured_node, match = raw_iter()

      -- Without a match there are no predicates to check
      if match == nil then
        return capture, captured_node
      end

      local active = self:match_preds(match, match.pattern, source)
      match.active = active

      if active then
        return capture, captured_node
      end
      -- Predicates not satisfied: try the next capture
    end
  end
end
 | 
						|
 | 
						|
--- Iterates over the matches of self on a given range.
--
-- Matches whose predicates are not satisfied are skipped.
--
-- @param node The node under which the search will occur
-- @param source The source to search; the buffer number 0 designates the current buffer
-- @param start The starting line of the search
-- @param stop The stopping line of the search (end-exclusive)
--
-- @returns The matching pattern id
-- @returns The matching match
function Query:iter_matches(node, source, start, stop)
  if type(source) == "number" and source == 0 then
    source = vim.api.nvim_get_current_buf()
  end

  local raw_iter = node:_rawquery(self.query, false, start, stop)

  return function()
    while true do
      local pattern, match = raw_iter()

      -- Hand back the result when there is no match to check, or when its predicates hold;
      -- otherwise try the next match.
      if match == nil or self:match_preds(match, pattern, source) then
        return pattern, match
      end
    end
  end
end
 | 
						|
 | 
						|
return M
 |