mirror of
				https://github.com/neovim/neovim.git
				synced 2025-10-26 12:27:24 +00:00 
			
		
		
		
	feat(lua): add vim.func._memoize
Memoizes a function, using a custom function to hash the arguments. Private for now until: - There are other places in the codebase that could benefit from this (e.g. LSP), but might require other changes to accommodate. - Invalidation of the cache needs to be controllable. Using weak tables is an acceptable invalidation policy, but it shouldn't be the only one. - I don't think the story around `hash_fn` is completely thought out. We may be able to have a good default hash_fn by hashing each argument, so basically a better 'concat'.
This commit is contained in:
		 Lewis Russell
					Lewis Russell
				
			
				
					committed by
					
						 Lewis Russell
						Lewis Russell
					
				
			
			
				
	
			
			
			 Lewis Russell
						Lewis Russell
					
				
			
						parent
						
							11865dbe39
						
					
				
				
					commit
					877d04d0fb
				
			| @@ -29,6 +29,7 @@ for k, v in pairs({ | |||||||
|   treesitter = true, |   treesitter = true, | ||||||
|   filetype = true, |   filetype = true, | ||||||
|   loader = true, |   loader = true, | ||||||
|  |   func = true, | ||||||
|   F = true, |   F = true, | ||||||
|   lsp = true, |   lsp = true, | ||||||
|   highlight = true, |   highlight = true, | ||||||
|   | |||||||
| @@ -10,6 +10,7 @@ vim._watch = require('vim._watch') | |||||||
| vim.diagnostic = require('vim.diagnostic') | vim.diagnostic = require('vim.diagnostic') | ||||||
| vim.filetype = require('vim.filetype') | vim.filetype = require('vim.filetype') | ||||||
| vim.fs = require('vim.fs') | vim.fs = require('vim.fs') | ||||||
|  | vim.func = require('vim.func') | ||||||
| vim.health = require('vim.health') | vim.health = require('vim.health') | ||||||
| vim.highlight = require('vim.highlight') | vim.highlight = require('vim.highlight') | ||||||
| vim.iter = require('vim.iter') | vim.iter = require('vim.iter') | ||||||
|   | |||||||
							
								
								
									
										41
									
								
								runtime/lua/vim/func.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								runtime/lua/vim/func.lua
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | |||||||
|  | local M = {} | ||||||
|  |  | ||||||
|  | -- TODO(lewis6991): Private for now until: | ||||||
|  | -- - There are other places in the codebase that could benefit from this | ||||||
|  | --   (e.g. LSP), but might require other changes to accommodate. | ||||||
|  | -- - Invalidation of the cache needs to be controllable. Using weak tables | ||||||
|  | --   is an acceptable invalidation policy, but it shouldn't be the only | ||||||
|  | --   one. | ||||||
|  | -- - I don't think the story around `hash` is completely thought out. We | ||||||
|  | --   may be able to have a good default hash by hashing each argument, | ||||||
|  | --   so basically a better 'concat'. | ||||||
|  | -- - Need to support multi level caches. Can be done by allow `hash` to | ||||||
|  | --   return multiple values. | ||||||
|  | -- | ||||||
|  | --- Memoizes a function {fn} using {hash} to hash the arguments. | ||||||
|  | --- | ||||||
|  | --- Internally uses a |lua-weaktable| to cache the results of {fn} meaning the | ||||||
|  | --- cache will be invalidated whenever Lua does garbage collection. | ||||||
|  | --- | ||||||
|  | --- The memoized function returns shared references so be wary about | ||||||
|  | --- mutating return values. | ||||||
|  | --- | ||||||
|  | --- @generic F: function | ||||||
|  | --- @param hash integer|string|function Hash function to create a hash to use as a key to | ||||||
|  | ---     store results. Possible values: | ||||||
|  | ---     - When integer, refers to the index of an argument of {fn} to hash. | ||||||
|  | ---     This argument can have any type. | ||||||
|  | ---     - When function, is evaluated using the same arguments passed to {fn}. | ||||||
|  | ---     - When `concat`, the hash is determined by string concatenating all the | ||||||
|  | ---     arguments passed to {fn}. | ||||||
|  | ---     - When `concat-n`, the hash is determined by string concatenating the | ||||||
|  | ---     first n arguments passed to {fn}. | ||||||
|  | --- | ||||||
|  | --- @param fn F Function to memoize. | ||||||
|  | --- @return F # Memoized version of {fn} | ||||||
|  | --- @nodoc | ||||||
|  | function M._memoize(hash, fn) | ||||||
|  |   return require('vim.func._memoize')(hash, fn) | ||||||
|  | end | ||||||
|  |  | ||||||
|  | return M | ||||||
							
								
								
									
										59
									
								
								runtime/lua/vim/func/_memoize.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								runtime/lua/vim/func/_memoize.lua
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | |||||||
|  | --- Module for private utility functions | ||||||
|  |  | ||||||
|  | --- @param argc integer? | ||||||
|  | --- @return fun(...): any | ||||||
|  | local function concat_hash(argc) | ||||||
|  |   return function(...) | ||||||
|  |     return table.concat({ ... }, '%%', 1, argc) | ||||||
|  |   end | ||||||
|  | end | ||||||
|  |  | ||||||
|  | --- @param idx integer | ||||||
|  | --- @return fun(...): any | ||||||
|  | local function idx_hash(idx) | ||||||
|  |   return function(...) | ||||||
|  |     return select(idx, ...) | ||||||
|  |   end | ||||||
|  | end | ||||||
|  |  | ||||||
|  | --- @param hash integer|string|fun(...): any | ||||||
|  | --- @return fun(...): any | ||||||
|  | local function resolve_hash(hash) | ||||||
|  |   if type(hash) == 'number' then | ||||||
|  |     hash = idx_hash(hash) | ||||||
|  |   elseif type(hash) == 'string' then | ||||||
|  |     local c = hash == 'concat' or hash:match('^concat%-(%d+)') | ||||||
|  |     if c then | ||||||
|  |       hash = concat_hash(tonumber(c)) | ||||||
|  |     else | ||||||
|  |       error('invalid value for hash: ' .. hash) | ||||||
|  |     end | ||||||
|  |   end | ||||||
|  |   --- @cast hash -integer | ||||||
|  |   return hash | ||||||
|  | end | ||||||
|  |  | ||||||
|  | --- @generic F: function | ||||||
|  | --- @param hash integer|string|fun(...): any | ||||||
|  | --- @param fn F | ||||||
|  | --- @return F | ||||||
|  | return function(hash, fn) | ||||||
|  |   vim.validate({ | ||||||
|  |     hash = { hash, { 'number', 'string', 'function' } }, | ||||||
|  |     fn = { fn, 'function' }, | ||||||
|  |   }) | ||||||
|  |  | ||||||
|  |   ---@type table<any,table<any,any>> | ||||||
|  |   local cache = setmetatable({}, { __mode = 'kv' }) | ||||||
|  |  | ||||||
|  |   hash = resolve_hash(hash) | ||||||
|  |  | ||||||
|  |   return function(...) | ||||||
|  |     local key = hash(...) | ||||||
|  |     if cache[key] == nil then | ||||||
|  |       cache[key] = vim.F.pack_len(fn(...)) | ||||||
|  |     end | ||||||
|  |  | ||||||
|  |     return vim.F.unpack_len(cache[key]) | ||||||
|  |   end | ||||||
|  | end | ||||||
| @@ -10,20 +10,12 @@ local M = {} | |||||||
|  |  | ||||||
| --- @alias vim.treesitter.ParseError {msg: string, range: Range4} | --- @alias vim.treesitter.ParseError {msg: string, range: Range4} | ||||||
|  |  | ||||||
| --- @private |  | ||||||
| --- Caches parse results for queries for each language. |  | ||||||
| --- Entries of parse_cache[lang][query_text] will either be true for successful parse or contain the |  | ||||||
| --- message and range of the parse error. |  | ||||||
| --- @type table<string,table<string,vim.treesitter.ParseError|true>> |  | ||||||
| local parse_cache = {} |  | ||||||
|  |  | ||||||
| --- Contains language dependent context for the query linter | --- Contains language dependent context for the query linter | ||||||
| --- @class QueryLinterLanguageContext | --- @class QueryLinterLanguageContext | ||||||
| --- @field lang string? Current `lang` of the targeted parser | --- @field lang string? Current `lang` of the targeted parser | ||||||
| --- @field parser_info table? Parser info returned by vim.treesitter.language.inspect | --- @field parser_info table? Parser info returned by vim.treesitter.language.inspect | ||||||
| --- @field is_first_lang boolean Whether this is the first language of a linter run checking queries for multiple `langs` | --- @field is_first_lang boolean Whether this is the first language of a linter run checking queries for multiple `langs` | ||||||
|  |  | ||||||
| --- @private |  | ||||||
| --- Adds a diagnostic for node in the query buffer | --- Adds a diagnostic for node in the query buffer | ||||||
| --- @param diagnostics Diagnostic[] | --- @param diagnostics Diagnostic[] | ||||||
| --- @param range Range4 | --- @param range Range4 | ||||||
| @@ -42,7 +34,6 @@ local function add_lint_for_node(diagnostics, range, lint, lang) | |||||||
|   } |   } | ||||||
| end | end | ||||||
|  |  | ||||||
| --- @private |  | ||||||
| --- Determines the target language of a query file by its path: <lang>/<query_type>.scm | --- Determines the target language of a query file by its path: <lang>/<query_type>.scm | ||||||
| --- @param buf integer | --- @param buf integer | ||||||
| --- @return string? | --- @return string? | ||||||
| @@ -53,7 +44,6 @@ local function guess_query_lang(buf) | |||||||
|   end |   end | ||||||
| end | end | ||||||
|  |  | ||||||
| --- @private |  | ||||||
| --- @param buf integer | --- @param buf integer | ||||||
| --- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil | --- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil | ||||||
| --- @return QueryLinterNormalizedOpts | --- @return QueryLinterNormalizedOpts | ||||||
| @@ -87,7 +77,6 @@ local lint_query = [[;; query | |||||||
|   (ERROR) @error |   (ERROR) @error | ||||||
| ]] | ]] | ||||||
|  |  | ||||||
| --- @private |  | ||||||
| --- @param err string | --- @param err string | ||||||
| --- @param node TSNode | --- @param node TSNode | ||||||
| --- @return vim.treesitter.ParseError | --- @return vim.treesitter.ParseError | ||||||
| @@ -112,38 +101,26 @@ local function get_error_entry(err, node) | |||||||
|   } |   } | ||||||
| end | end | ||||||
|  |  | ||||||
| --- @private |  | ||||||
| --- @param node TSNode | --- @param node TSNode | ||||||
| --- @param buf integer | --- @param buf integer | ||||||
| --- @param lang string | --- @param lang string | ||||||
| --- @param diagnostics Diagnostic[] | local function hash_parse(node, buf, lang) | ||||||
| local function check_toplevel(node, buf, lang, diagnostics) |   return tostring(node:id()) .. tostring(buf) .. tostring(vim.b[buf].changedtick) .. lang | ||||||
|   local query_text = vim.treesitter.get_node_text(node, buf) |  | ||||||
|  |  | ||||||
|   if not parse_cache[lang] then |  | ||||||
|     parse_cache[lang] = {} |  | ||||||
|   end |  | ||||||
|  |  | ||||||
|   local lang_cache = parse_cache[lang] |  | ||||||
|  |  | ||||||
|   if lang_cache[query_text] == nil then |  | ||||||
|     local cache_val, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|Query |  | ||||||
|  |  | ||||||
|     if not cache_val and type(err) == 'string' then |  | ||||||
|       cache_val = get_error_entry(err, node) |  | ||||||
|     end |  | ||||||
|  |  | ||||||
|     lang_cache[query_text] = cache_val |  | ||||||
|   end |  | ||||||
|  |  | ||||||
|   local cache_entry = lang_cache[query_text] |  | ||||||
|  |  | ||||||
|   if type(cache_entry) ~= 'boolean' then |  | ||||||
|     add_lint_for_node(diagnostics, cache_entry.range, cache_entry.msg, lang) |  | ||||||
|   end |  | ||||||
| end | end | ||||||
|  |  | ||||||
| --- @private | --- @param node TSNode | ||||||
|  | --- @param buf integer | ||||||
|  | --- @param lang string | ||||||
|  | --- @return vim.treesitter.ParseError? | ||||||
|  | local parse = vim.func._memoize(hash_parse, function(node, buf, lang) | ||||||
|  |   local query_text = vim.treesitter.get_node_text(node, buf) | ||||||
|  |   local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|Query | ||||||
|  |  | ||||||
|  |   if not ok and type(err) == 'string' then | ||||||
|  |     return get_error_entry(err, node) | ||||||
|  |   end | ||||||
|  | end) | ||||||
|  |  | ||||||
| --- @param buf integer | --- @param buf integer | ||||||
| --- @param match table<integer,TSNode> | --- @param match table<integer,TSNode> | ||||||
| --- @param query Query | --- @param query Query | ||||||
| @@ -164,7 +141,10 @@ local function lint_match(buf, match, query, lang_context, diagnostics) | |||||||
|  |  | ||||||
|     -- other checks rely on Neovim parser introspection |     -- other checks rely on Neovim parser introspection | ||||||
|     if lang and parser_info and cap_id == 'toplevel' then |     if lang and parser_info and cap_id == 'toplevel' then | ||||||
|       check_toplevel(node, buf, lang, diagnostics) |       local err = parse(node, buf, lang) | ||||||
|  |       if err then | ||||||
|  |         add_lint_for_node(diagnostics, err.range, err.msg, lang) | ||||||
|  |       end | ||||||
|     end |     end | ||||||
|   end |   end | ||||||
| end | end | ||||||
|   | |||||||
| @@ -738,12 +738,14 @@ local function add_injection(t, tree_index, pattern, lang, combined, ranges) | |||||||
| end | end | ||||||
|  |  | ||||||
| -- TODO(clason): replace by refactored `ts.has_parser` API (without registering) | -- TODO(clason): replace by refactored `ts.has_parser` API (without registering) | ||||||
| ---@param lang string parser name | --- The result of this function is cached to prevent nvim_get_runtime_file from being | ||||||
| ---@return boolean # true if parser for {lang} exists on rtp | --- called too often | ||||||
| local has_parser = function(lang) | --- @param lang string parser name | ||||||
|  | --- @return boolean # true if parser for {lang} exists on rtp | ||||||
|  | local has_parser = vim.func._memoize(1, function(lang) | ||||||
|   return vim._ts_has_language(lang) |   return vim._ts_has_language(lang) | ||||||
|     or #vim.api.nvim_get_runtime_file('parser/' .. lang .. '.*', false) > 0 |     or #vim.api.nvim_get_runtime_file('parser/' .. lang .. '.*', false) > 0 | ||||||
| end | end) | ||||||
|  |  | ||||||
| --- Return parser name for language (if exists) or filetype (if registered and exists). | --- Return parser name for language (if exists) or filetype (if registered and exists). | ||||||
| --- Also attempts with the input lower-cased. | --- Also attempts with the input lower-cased. | ||||||
|   | |||||||
| @@ -191,12 +191,6 @@ function M.set(lang, query_name, text) | |||||||
|   explicit_queries[lang][query_name] = M.parse(lang, text) |   explicit_queries[lang][query_name] = M.parse(lang, text) | ||||||
| end | end | ||||||
|  |  | ||||||
| --- `false` if query files didn't exist or were empty |  | ||||||
| ---@type table<string, table<string, Query|false>> |  | ||||||
| local query_get_cache = vim.defaulttable(function() |  | ||||||
|   return setmetatable({}, { __mode = 'v' }) |  | ||||||
| end) |  | ||||||
|  |  | ||||||
| ---@deprecated | ---@deprecated | ||||||
| function M.get_query(...) | function M.get_query(...) | ||||||
|   vim.deprecate('vim.treesitter.query.get_query()', 'vim.treesitter.query.get()', '0.10') |   vim.deprecate('vim.treesitter.query.get_query()', 'vim.treesitter.query.get()', '0.10') | ||||||
| @@ -209,34 +203,19 @@ end | |||||||
| ---@param query_name string Name of the query (e.g. "highlights") | ---@param query_name string Name of the query (e.g. "highlights") | ||||||
| --- | --- | ||||||
| ---@return Query|nil Parsed query | ---@return Query|nil Parsed query | ||||||
| function M.get(lang, query_name) | M.get = vim.func._memoize('concat-2', function(lang, query_name) | ||||||
|   if explicit_queries[lang][query_name] then |   if explicit_queries[lang][query_name] then | ||||||
|     return explicit_queries[lang][query_name] |     return explicit_queries[lang][query_name] | ||||||
|   end |   end | ||||||
|  |  | ||||||
|   local cached = query_get_cache[lang][query_name] |  | ||||||
|   if cached then |  | ||||||
|     return cached |  | ||||||
|   elseif cached == false then |  | ||||||
|     return nil |  | ||||||
|   end |  | ||||||
|  |  | ||||||
|   local query_files = M.get_files(lang, query_name) |   local query_files = M.get_files(lang, query_name) | ||||||
|   local query_string = read_query_files(query_files) |   local query_string = read_query_files(query_files) | ||||||
|  |  | ||||||
|   if #query_string == 0 then |   if #query_string == 0 then | ||||||
|     query_get_cache[lang][query_name] = false |  | ||||||
|     return nil |     return nil | ||||||
|   end |   end | ||||||
|  |  | ||||||
|   local query = M.parse(lang, query_string) |   return M.parse(lang, query_string) | ||||||
|   query_get_cache[lang][query_name] = query |  | ||||||
|   return query |  | ||||||
| end |  | ||||||
|  |  | ||||||
| ---@type table<string, table<string, Query>> |  | ||||||
| local query_parse_cache = vim.defaulttable(function() |  | ||||||
|   return setmetatable({}, { __mode = 'v' }) |  | ||||||
| end) | end) | ||||||
|  |  | ||||||
| ---@deprecated | ---@deprecated | ||||||
| @@ -262,20 +241,15 @@ end | |||||||
| ---@param query string Query in s-expr syntax | ---@param query string Query in s-expr syntax | ||||||
| --- | --- | ||||||
| ---@return Query Parsed query | ---@return Query Parsed query | ||||||
| function M.parse(lang, query) | M.parse = vim.func._memoize('concat-2', function(lang, query) | ||||||
|   language.add(lang) |   language.add(lang) | ||||||
|   local cached = query_parse_cache[lang][query] |  | ||||||
|   if cached then |  | ||||||
|     return cached |  | ||||||
|   end |  | ||||||
|  |  | ||||||
|   local self = setmetatable({}, Query) |   local self = setmetatable({}, Query) | ||||||
|   self.query = vim._ts_parse_query(lang, query) |   self.query = vim._ts_parse_query(lang, query) | ||||||
|   self.info = self.query:inspect() |   self.info = self.query:inspect() | ||||||
|   self.captures = self.info.captures |   self.captures = self.info.captures | ||||||
|   query_parse_cache[lang][query] = self |  | ||||||
|   return self |   return self | ||||||
| end | end) | ||||||
|  |  | ||||||
| ---@deprecated | ---@deprecated | ||||||
| function M.get_range(...) | function M.get_range(...) | ||||||
|   | |||||||
| @@ -361,6 +361,12 @@ local function process_line(line, in_stream, generics) | |||||||
|     return process_block_comment(line:sub(5), in_stream) |     return process_block_comment(line:sub(5), in_stream) | ||||||
|   end |   end | ||||||
|  |  | ||||||
|  |   -- Hax... I'm sorry | ||||||
|  |   -- M.fun = vim.memoize(function(...) | ||||||
|  |   --   -> | ||||||
|  |   -- function M.fun(...) | ||||||
|  |   line = line:gsub('^(.+) = .*_memoize%([^,]+, function%((.*)%)$', 'function %1(%2)') | ||||||
|  |  | ||||||
|   if line:find('^function') or line:find('^local%s+function') then |   if line:find('^function') or line:find('^local%s+function') then | ||||||
|     return process_function_header(line) |     return process_function_header(line) | ||||||
|   end |   end | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user