mirror of
				https://github.com/neovim/neovim.git
				synced 2025-10-26 12:27:24 +00:00 
			
		
		
		
	feat(lua): add vim.func._memoize
Memoizes a function, using a custom function to hash the arguments. Private for now until: - There are other places in the codebase that could benefit from this (e.g. LSP), but might require other changes to accommodate. - Invalidation of the cache needs to be controllable. Using weak tables is an acceptable invalidation policy, but it shouldn't be the only one. - I don't think the story around `hash_fn` is completely thought out. We may be able to have a good default hash_fn by hashing each argument, so basically a better 'concat'.
This commit is contained in:
		
				
					committed by
					
						
						Lewis Russell
					
				
			
			
				
	
			
			
			
						parent
						
							11865dbe39
						
					
				
				
					commit
					877d04d0fb
				
			@@ -29,6 +29,7 @@ for k, v in pairs({
 | 
			
		||||
  treesitter = true,
 | 
			
		||||
  filetype = true,
 | 
			
		||||
  loader = true,
 | 
			
		||||
  func = true,
 | 
			
		||||
  F = true,
 | 
			
		||||
  lsp = true,
 | 
			
		||||
  highlight = true,
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,7 @@ vim._watch = require('vim._watch')
 | 
			
		||||
vim.diagnostic = require('vim.diagnostic')
 | 
			
		||||
vim.filetype = require('vim.filetype')
 | 
			
		||||
vim.fs = require('vim.fs')
 | 
			
		||||
vim.func = require('vim.func')
 | 
			
		||||
vim.health = require('vim.health')
 | 
			
		||||
vim.highlight = require('vim.highlight')
 | 
			
		||||
vim.iter = require('vim.iter')
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										41
									
								
								runtime/lua/vim/func.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								runtime/lua/vim/func.lua
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,41 @@
 | 
			
		||||
local M = {}
 | 
			
		||||
 | 
			
		||||
-- TODO(lewis6991): Private for now until:
 | 
			
		||||
-- - There are other places in the codebase that could benefit from this
 | 
			
		||||
--   (e.g. LSP), but might require other changes to accommodate.
 | 
			
		||||
-- - Invalidation of the cache needs to be controllable. Using weak tables
 | 
			
		||||
--   is an acceptable invalidation policy, but it shouldn't be the only
 | 
			
		||||
--   one.
 | 
			
		||||
-- - I don't think the story around `hash` is completely thought out. We
 | 
			
		||||
--   may be able to have a good default hash by hashing each argument,
 | 
			
		||||
--   so basically a better 'concat'.
 | 
			
		||||
-- - Need to support multi level caches. Can be done by allow `hash` to
 | 
			
		||||
--   return multiple values.
 | 
			
		||||
--
 | 
			
		||||
--- Memoizes a function {fn} using {hash} to hash the arguments.
 | 
			
		||||
---
 | 
			
		||||
--- Internally uses a |lua-weaktable| to cache the results of {fn} meaning the
 | 
			
		||||
--- cache will be invalidated whenever Lua does garbage collection.
 | 
			
		||||
---
 | 
			
		||||
--- The memoized function returns shared references so be wary about
 | 
			
		||||
--- mutating return values.
 | 
			
		||||
---
 | 
			
		||||
--- @generic F: function
 | 
			
		||||
--- @param hash integer|string|function Hash function to create a hash to use as a key to
 | 
			
		||||
---     store results. Possible values:
 | 
			
		||||
---     - When integer, refers to the index of an argument of {fn} to hash.
 | 
			
		||||
---     This argument can have any type.
 | 
			
		||||
---     - When function, is evaluated using the same arguments passed to {fn}.
 | 
			
		||||
---     - When `concat`, the hash is determined by string concatenating all the
 | 
			
		||||
---     arguments passed to {fn}.
 | 
			
		||||
---     - When `concat-n`, the hash is determined by string concatenating the
 | 
			
		||||
---     first n arguments passed to {fn}.
 | 
			
		||||
---
 | 
			
		||||
--- @param fn F Function to memoize.
 | 
			
		||||
--- @return F # Memoized version of {fn}
 | 
			
		||||
--- @nodoc
 | 
			
		||||
function M._memoize(hash, fn)
 | 
			
		||||
  return require('vim.func._memoize')(hash, fn)
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
return M
 | 
			
		||||
							
								
								
									
										59
									
								
								runtime/lua/vim/func/_memoize.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								runtime/lua/vim/func/_memoize.lua
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,59 @@
 | 
			
		||||
--- Module for private utility functions
 | 
			
		||||
 | 
			
		||||
--- @param argc integer?
 | 
			
		||||
--- @return fun(...): any
 | 
			
		||||
local function concat_hash(argc)
 | 
			
		||||
  return function(...)
 | 
			
		||||
    return table.concat({ ... }, '%%', 1, argc)
 | 
			
		||||
  end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- @param idx integer
 | 
			
		||||
--- @return fun(...): any
 | 
			
		||||
local function idx_hash(idx)
 | 
			
		||||
  return function(...)
 | 
			
		||||
    return select(idx, ...)
 | 
			
		||||
  end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- @param hash integer|string|fun(...): any
 | 
			
		||||
--- @return fun(...): any
 | 
			
		||||
local function resolve_hash(hash)
 | 
			
		||||
  if type(hash) == 'number' then
 | 
			
		||||
    hash = idx_hash(hash)
 | 
			
		||||
  elseif type(hash) == 'string' then
 | 
			
		||||
    local c = hash == 'concat' or hash:match('^concat%-(%d+)')
 | 
			
		||||
    if c then
 | 
			
		||||
      hash = concat_hash(tonumber(c))
 | 
			
		||||
    else
 | 
			
		||||
      error('invalid value for hash: ' .. hash)
 | 
			
		||||
    end
 | 
			
		||||
  end
 | 
			
		||||
  --- @cast hash -integer
 | 
			
		||||
  return hash
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- @generic F: function
 | 
			
		||||
--- @param hash integer|string|fun(...): any
 | 
			
		||||
--- @param fn F
 | 
			
		||||
--- @return F
 | 
			
		||||
return function(hash, fn)
 | 
			
		||||
  vim.validate({
 | 
			
		||||
    hash = { hash, { 'number', 'string', 'function' } },
 | 
			
		||||
    fn = { fn, 'function' },
 | 
			
		||||
  })
 | 
			
		||||
 | 
			
		||||
  ---@type table<any,table<any,any>>
 | 
			
		||||
  local cache = setmetatable({}, { __mode = 'kv' })
 | 
			
		||||
 | 
			
		||||
  hash = resolve_hash(hash)
 | 
			
		||||
 | 
			
		||||
  return function(...)
 | 
			
		||||
    local key = hash(...)
 | 
			
		||||
    if cache[key] == nil then
 | 
			
		||||
      cache[key] = vim.F.pack_len(fn(...))
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    return vim.F.unpack_len(cache[key])
 | 
			
		||||
  end
 | 
			
		||||
end
 | 
			
		||||
@@ -10,20 +10,12 @@ local M = {}
 | 
			
		||||
 | 
			
		||||
--- @alias vim.treesitter.ParseError {msg: string, range: Range4}
 | 
			
		||||
 | 
			
		||||
--- @private
 | 
			
		||||
--- Caches parse results for queries for each language.
 | 
			
		||||
--- Entries of parse_cache[lang][query_text] will either be true for successful parse or contain the
 | 
			
		||||
--- message and range of the parse error.
 | 
			
		||||
--- @type table<string,table<string,vim.treesitter.ParseError|true>>
 | 
			
		||||
local parse_cache = {}
 | 
			
		||||
 | 
			
		||||
--- Contains language dependent context for the query linter
 | 
			
		||||
--- @class QueryLinterLanguageContext
 | 
			
		||||
--- @field lang string? Current `lang` of the targeted parser
 | 
			
		||||
--- @field parser_info table? Parser info returned by vim.treesitter.language.inspect
 | 
			
		||||
--- @field is_first_lang boolean Whether this is the first language of a linter run checking queries for multiple `langs`
 | 
			
		||||
 | 
			
		||||
--- @private
 | 
			
		||||
--- Adds a diagnostic for node in the query buffer
 | 
			
		||||
--- @param diagnostics Diagnostic[]
 | 
			
		||||
--- @param range Range4
 | 
			
		||||
@@ -42,7 +34,6 @@ local function add_lint_for_node(diagnostics, range, lint, lang)
 | 
			
		||||
  }
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- @private
 | 
			
		||||
--- Determines the target language of a query file by its path: <lang>/<query_type>.scm
 | 
			
		||||
--- @param buf integer
 | 
			
		||||
--- @return string?
 | 
			
		||||
@@ -53,7 +44,6 @@ local function guess_query_lang(buf)
 | 
			
		||||
  end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- @private
 | 
			
		||||
--- @param buf integer
 | 
			
		||||
--- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil
 | 
			
		||||
--- @return QueryLinterNormalizedOpts
 | 
			
		||||
@@ -87,7 +77,6 @@ local lint_query = [[;; query
 | 
			
		||||
  (ERROR) @error
 | 
			
		||||
]]
 | 
			
		||||
 | 
			
		||||
--- @private
 | 
			
		||||
--- @param err string
 | 
			
		||||
--- @param node TSNode
 | 
			
		||||
--- @return vim.treesitter.ParseError
 | 
			
		||||
@@ -112,38 +101,26 @@ local function get_error_entry(err, node)
 | 
			
		||||
  }
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- @private
 | 
			
		||||
--- @param node TSNode
 | 
			
		||||
--- @param buf integer
 | 
			
		||||
--- @param lang string
 | 
			
		||||
--- @param diagnostics Diagnostic[]
 | 
			
		||||
local function check_toplevel(node, buf, lang, diagnostics)
 | 
			
		||||
  local query_text = vim.treesitter.get_node_text(node, buf)
 | 
			
		||||
 | 
			
		||||
  if not parse_cache[lang] then
 | 
			
		||||
    parse_cache[lang] = {}
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  local lang_cache = parse_cache[lang]
 | 
			
		||||
 | 
			
		||||
  if lang_cache[query_text] == nil then
 | 
			
		||||
    local cache_val, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|Query
 | 
			
		||||
 | 
			
		||||
    if not cache_val and type(err) == 'string' then
 | 
			
		||||
      cache_val = get_error_entry(err, node)
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    lang_cache[query_text] = cache_val
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  local cache_entry = lang_cache[query_text]
 | 
			
		||||
 | 
			
		||||
  if type(cache_entry) ~= 'boolean' then
 | 
			
		||||
    add_lint_for_node(diagnostics, cache_entry.range, cache_entry.msg, lang)
 | 
			
		||||
  end
 | 
			
		||||
local function hash_parse(node, buf, lang)
 | 
			
		||||
  return tostring(node:id()) .. tostring(buf) .. tostring(vim.b[buf].changedtick) .. lang
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- @private
 | 
			
		||||
--- @param node TSNode
 | 
			
		||||
--- @param buf integer
 | 
			
		||||
--- @param lang string
 | 
			
		||||
--- @return vim.treesitter.ParseError?
 | 
			
		||||
local parse = vim.func._memoize(hash_parse, function(node, buf, lang)
 | 
			
		||||
  local query_text = vim.treesitter.get_node_text(node, buf)
 | 
			
		||||
  local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|Query
 | 
			
		||||
 | 
			
		||||
  if not ok and type(err) == 'string' then
 | 
			
		||||
    return get_error_entry(err, node)
 | 
			
		||||
  end
 | 
			
		||||
end)
 | 
			
		||||
 | 
			
		||||
--- @param buf integer
 | 
			
		||||
--- @param match table<integer,TSNode>
 | 
			
		||||
--- @param query Query
 | 
			
		||||
@@ -164,7 +141,10 @@ local function lint_match(buf, match, query, lang_context, diagnostics)
 | 
			
		||||
 | 
			
		||||
    -- other checks rely on Neovim parser introspection
 | 
			
		||||
    if lang and parser_info and cap_id == 'toplevel' then
 | 
			
		||||
      check_toplevel(node, buf, lang, diagnostics)
 | 
			
		||||
      local err = parse(node, buf, lang)
 | 
			
		||||
      if err then
 | 
			
		||||
        add_lint_for_node(diagnostics, err.range, err.msg, lang)
 | 
			
		||||
      end
 | 
			
		||||
    end
 | 
			
		||||
  end
 | 
			
		||||
end
 | 
			
		||||
 
 | 
			
		||||
@@ -738,12 +738,14 @@ local function add_injection(t, tree_index, pattern, lang, combined, ranges)
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
-- TODO(clason): replace by refactored `ts.has_parser` API (without registering)
 | 
			
		||||
---@param lang string parser name
 | 
			
		||||
---@return boolean # true if parser for {lang} exists on rtp
 | 
			
		||||
local has_parser = function(lang)
 | 
			
		||||
--- The result of this function is cached to prevent nvim_get_runtime_file from being
 | 
			
		||||
--- called too often
 | 
			
		||||
--- @param lang string parser name
 | 
			
		||||
--- @return boolean # true if parser for {lang} exists on rtp
 | 
			
		||||
local has_parser = vim.func._memoize(1, function(lang)
 | 
			
		||||
  return vim._ts_has_language(lang)
 | 
			
		||||
    or #vim.api.nvim_get_runtime_file('parser/' .. lang .. '.*', false) > 0
 | 
			
		||||
end
 | 
			
		||||
end)
 | 
			
		||||
 | 
			
		||||
--- Return parser name for language (if exists) or filetype (if registered and exists).
 | 
			
		||||
--- Also attempts with the input lower-cased.
 | 
			
		||||
 
 | 
			
		||||
@@ -191,12 +191,6 @@ function M.set(lang, query_name, text)
 | 
			
		||||
  explicit_queries[lang][query_name] = M.parse(lang, text)
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- `false` if query files didn't exist or were empty
 | 
			
		||||
---@type table<string, table<string, Query|false>>
 | 
			
		||||
local query_get_cache = vim.defaulttable(function()
 | 
			
		||||
  return setmetatable({}, { __mode = 'v' })
 | 
			
		||||
end)
 | 
			
		||||
 | 
			
		||||
---@deprecated
 | 
			
		||||
function M.get_query(...)
 | 
			
		||||
  vim.deprecate('vim.treesitter.query.get_query()', 'vim.treesitter.query.get()', '0.10')
 | 
			
		||||
@@ -209,34 +203,19 @@ end
 | 
			
		||||
---@param query_name string Name of the query (e.g. "highlights")
 | 
			
		||||
---
 | 
			
		||||
---@return Query|nil Parsed query
 | 
			
		||||
function M.get(lang, query_name)
 | 
			
		||||
M.get = vim.func._memoize('concat-2', function(lang, query_name)
 | 
			
		||||
  if explicit_queries[lang][query_name] then
 | 
			
		||||
    return explicit_queries[lang][query_name]
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  local cached = query_get_cache[lang][query_name]
 | 
			
		||||
  if cached then
 | 
			
		||||
    return cached
 | 
			
		||||
  elseif cached == false then
 | 
			
		||||
    return nil
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  local query_files = M.get_files(lang, query_name)
 | 
			
		||||
  local query_string = read_query_files(query_files)
 | 
			
		||||
 | 
			
		||||
  if #query_string == 0 then
 | 
			
		||||
    query_get_cache[lang][query_name] = false
 | 
			
		||||
    return nil
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  local query = M.parse(lang, query_string)
 | 
			
		||||
  query_get_cache[lang][query_name] = query
 | 
			
		||||
  return query
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
---@type table<string, table<string, Query>>
 | 
			
		||||
local query_parse_cache = vim.defaulttable(function()
 | 
			
		||||
  return setmetatable({}, { __mode = 'v' })
 | 
			
		||||
  return M.parse(lang, query_string)
 | 
			
		||||
end)
 | 
			
		||||
 | 
			
		||||
---@deprecated
 | 
			
		||||
@@ -262,20 +241,15 @@ end
 | 
			
		||||
---@param query string Query in s-expr syntax
 | 
			
		||||
---
 | 
			
		||||
---@return Query Parsed query
 | 
			
		||||
function M.parse(lang, query)
 | 
			
		||||
M.parse = vim.func._memoize('concat-2', function(lang, query)
 | 
			
		||||
  language.add(lang)
 | 
			
		||||
  local cached = query_parse_cache[lang][query]
 | 
			
		||||
  if cached then
 | 
			
		||||
    return cached
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  local self = setmetatable({}, Query)
 | 
			
		||||
  self.query = vim._ts_parse_query(lang, query)
 | 
			
		||||
  self.info = self.query:inspect()
 | 
			
		||||
  self.captures = self.info.captures
 | 
			
		||||
  query_parse_cache[lang][query] = self
 | 
			
		||||
  return self
 | 
			
		||||
end
 | 
			
		||||
end)
 | 
			
		||||
 | 
			
		||||
---@deprecated
 | 
			
		||||
function M.get_range(...)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user