mirror of
				https://github.com/neovim/neovim.git
				synced 2025-11-04 09:44:31 +00:00 
			
		
		
		
	feat(func): allow manual cache invalidation for _memoize
This commit also adds some tests for the existing memoization functionality.
This commit is contained in:
		
				
					committed by
					
						
						Christian Clason
					
				
			
			
				
	
			
			
			
						parent
						
							54ac406649
						
					
				
				
					commit
					b61051ccb4
				
			@@ -3,9 +3,6 @@ local M = {}
 | 
			
		||||
-- TODO(lewis6991): Private for now until:
 | 
			
		||||
-- - There are other places in the codebase that could benefit from this
 | 
			
		||||
--   (e.g. LSP), but might require other changes to accommodate.
 | 
			
		||||
-- - Invalidation of the cache needs to be controllable. Using weak tables
 | 
			
		||||
--   is an acceptable invalidation policy, but it shouldn't be the only
 | 
			
		||||
--   one.
 | 
			
		||||
-- - I don't think the story around `hash` is completely thought out. We
 | 
			
		||||
--   may be able to have a good default hash by hashing each argument,
 | 
			
		||||
--   so basically a better 'concat'.
 | 
			
		||||
@@ -17,6 +14,10 @@ local M = {}
 | 
			
		||||
--- Internally uses a |lua-weaktable| to cache the results of {fn} meaning the
 | 
			
		||||
--- cache will be invalidated whenever Lua does garbage collection.
 | 
			
		||||
---
 | 
			
		||||
--- The cache can also be manually invalidated by calling `:clear()` on the returned object.
 | 
			
		||||
--- Calling this function with no arguments clears the entire cache; otherwise, the arguments will
 | 
			
		||||
--- be interpreted as function inputs, and only the cache entry at their hash will be cleared.
 | 
			
		||||
---
 | 
			
		||||
--- The memoized function returns shared references so be wary about
 | 
			
		||||
--- mutating return values.
 | 
			
		||||
---
 | 
			
		||||
@@ -32,11 +33,12 @@ local M = {}
 | 
			
		||||
---     first n arguments passed to {fn}.
 | 
			
		||||
---
 | 
			
		||||
--- @param fn F Function to memoize.
 | 
			
		||||
--- @param strong? boolean Do not use a weak table
 | 
			
		||||
--- @param weak? boolean Use a weak table (default `true`)
 | 
			
		||||
--- @return F # Memoized version of {fn}
 | 
			
		||||
--- @nodoc
 | 
			
		||||
function M._memoize(hash, fn, strong)
 | 
			
		||||
  return require('vim.func._memoize')(hash, fn, strong)
 | 
			
		||||
function M._memoize(hash, fn, weak)
 | 
			
		||||
  -- this is wrapped in a function to lazily require the module
 | 
			
		||||
  return require('vim.func._memoize')(hash, fn, weak)
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
return M
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,7 @@
 | 
			
		||||
--- Module for private utility functions
 | 
			
		||||
 | 
			
		||||
--- @alias vim.func.MemoObj { _hash: (fun(...): any), _weak: boolean?, _cache: table<any> }
 | 
			
		||||
 | 
			
		||||
--- @param argc integer?
 | 
			
		||||
--- @return fun(...): any
 | 
			
		||||
local function concat_hash(argc)
 | 
			
		||||
@@ -33,29 +35,49 @@ local function resolve_hash(hash)
 | 
			
		||||
  return hash
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- @param weak boolean?
 | 
			
		||||
--- @return table
 | 
			
		||||
local create_cache = function(weak)
 | 
			
		||||
  return setmetatable({}, {
 | 
			
		||||
    __mode = weak ~= false and 'kv',
 | 
			
		||||
  })
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
--- @generic F: function
 | 
			
		||||
--- @param hash integer|string|fun(...): any
 | 
			
		||||
--- @param fn F
 | 
			
		||||
--- @param strong? boolean
 | 
			
		||||
--- @param weak? boolean
 | 
			
		||||
--- @return F
 | 
			
		||||
return function(hash, fn, strong)
 | 
			
		||||
return function(hash, fn, weak)
 | 
			
		||||
  vim.validate('hash', hash, { 'number', 'string', 'function' })
 | 
			
		||||
  vim.validate('fn', fn, 'function')
 | 
			
		||||
  vim.validate('weak', weak, 'boolean', true)
 | 
			
		||||
 | 
			
		||||
  ---@type table<any,table<any,any>>
 | 
			
		||||
  local cache = {}
 | 
			
		||||
  if not strong then
 | 
			
		||||
    setmetatable(cache, { __mode = 'kv' })
 | 
			
		||||
  --- @type vim.func.MemoObj
 | 
			
		||||
  local obj = {
 | 
			
		||||
    _cache = create_cache(weak),
 | 
			
		||||
    _hash = resolve_hash(hash),
 | 
			
		||||
    _weak = weak,
 | 
			
		||||
    --- @param self vim.func.MemoObj
 | 
			
		||||
    clear = function(self, ...)
 | 
			
		||||
      if select('#', ...) == 0 then
 | 
			
		||||
        self._cache = create_cache(self._weak)
 | 
			
		||||
        return
 | 
			
		||||
      end
 | 
			
		||||
      local key = self._hash(...)
 | 
			
		||||
      self._cache[key] = nil
 | 
			
		||||
    end,
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  hash = resolve_hash(hash)
 | 
			
		||||
 | 
			
		||||
  return function(...)
 | 
			
		||||
    local key = hash(...)
 | 
			
		||||
  return setmetatable(obj, {
 | 
			
		||||
    --- @param self vim.func.MemoObj
 | 
			
		||||
    __call = function(self, ...)
 | 
			
		||||
      local key = self._hash(...)
 | 
			
		||||
      local cache = self._cache
 | 
			
		||||
      if cache[key] == nil then
 | 
			
		||||
        cache[key] = vim.F.pack_len(fn(...))
 | 
			
		||||
      end
 | 
			
		||||
 | 
			
		||||
      return vim.F.unpack_len(cache[key])
 | 
			
		||||
  end
 | 
			
		||||
    end,
 | 
			
		||||
  })
 | 
			
		||||
end
 | 
			
		||||
 
 | 
			
		||||
@@ -902,8 +902,8 @@ function Query:iter_captures(node, source, start, stop)
 | 
			
		||||
 | 
			
		||||
  local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })
 | 
			
		||||
 | 
			
		||||
  local apply_directives = memoize(match_id_hash, self.apply_directives, true)
 | 
			
		||||
  local match_preds = memoize(match_id_hash, self.match_preds, true)
 | 
			
		||||
  local apply_directives = memoize(match_id_hash, self.apply_directives, false)
 | 
			
		||||
  local match_preds = memoize(match_id_hash, self.match_preds, false)
 | 
			
		||||
 | 
			
		||||
  local function iter(end_line)
 | 
			
		||||
    local capture, captured_node, match = cursor:next_capture()
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										142
									
								
								test/functional/func/memoize_spec.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								test/functional/func/memoize_spec.lua
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,142 @@
 | 
			
		||||
local t = require('test.testutil')
 | 
			
		||||
local n = require('test.functional.testnvim')()
 | 
			
		||||
local clear = n.clear
 | 
			
		||||
local exec_lua = n.exec_lua
 | 
			
		||||
local eq = t.eq
 | 
			
		||||
 | 
			
		||||
describe('vim.func._memoize', function()
 | 
			
		||||
  before_each(clear)
 | 
			
		||||
 | 
			
		||||
  it('caches function results based on their parameters', function()
 | 
			
		||||
    exec_lua([[
 | 
			
		||||
      _G.count = 0
 | 
			
		||||
 | 
			
		||||
      local adder = vim.func._memoize('concat', function(arg1, arg2)
 | 
			
		||||
        _G.count = _G.count + 1
 | 
			
		||||
        return arg1 + arg2
 | 
			
		||||
      end)
 | 
			
		||||
 | 
			
		||||
      collectgarbage('stop')
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      collectgarbage('restart')
 | 
			
		||||
    ]])
 | 
			
		||||
 | 
			
		||||
    eq(1, exec_lua([[return _G.count]]))
 | 
			
		||||
  end)
 | 
			
		||||
 | 
			
		||||
  it('caches function results using a weak table by default', function()
 | 
			
		||||
    exec_lua([[
 | 
			
		||||
      _G.count = 0
 | 
			
		||||
 | 
			
		||||
      local adder = vim.func._memoize('concat-2', function(arg1, arg2)
 | 
			
		||||
        _G.count = _G.count + 1
 | 
			
		||||
        return arg1 + arg2
 | 
			
		||||
      end)
 | 
			
		||||
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      collectgarbage()
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      collectgarbage()
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
    ]])
 | 
			
		||||
 | 
			
		||||
    eq(3, exec_lua([[return _G.count]]))
 | 
			
		||||
  end)
 | 
			
		||||
 | 
			
		||||
  it('can cache using a strong table', function()
 | 
			
		||||
    exec_lua([[
 | 
			
		||||
      _G.count = 0
 | 
			
		||||
 | 
			
		||||
      local adder = vim.func._memoize('concat-2', function(arg1, arg2)
 | 
			
		||||
        _G.count = _G.count + 1
 | 
			
		||||
        return arg1 + arg2
 | 
			
		||||
      end, false)
 | 
			
		||||
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      collectgarbage()
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      collectgarbage()
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
    ]])
 | 
			
		||||
 | 
			
		||||
    eq(1, exec_lua([[return _G.count]]))
 | 
			
		||||
  end)
 | 
			
		||||
 | 
			
		||||
  it('can clear a single cache entry', function()
 | 
			
		||||
    exec_lua([[
 | 
			
		||||
      _G.count = 0
 | 
			
		||||
 | 
			
		||||
      local adder = vim.func._memoize(function(arg1, arg2)
 | 
			
		||||
        return tostring(arg1) .. '%%' .. tostring(arg2)
 | 
			
		||||
      end, function(arg1, arg2)
 | 
			
		||||
        _G.count = _G.count + 1
 | 
			
		||||
        return arg1 + arg2
 | 
			
		||||
      end)
 | 
			
		||||
 | 
			
		||||
      collectgarbage('stop')
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      adder:clear(3, -4)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      collectgarbage('restart')
 | 
			
		||||
    ]])
 | 
			
		||||
 | 
			
		||||
    eq(2, exec_lua([[return _G.count]]))
 | 
			
		||||
  end)
 | 
			
		||||
 | 
			
		||||
  it('can clear the entire cache', function()
 | 
			
		||||
    exec_lua([[
 | 
			
		||||
      _G.count = 0
 | 
			
		||||
 | 
			
		||||
      local adder = vim.func._memoize(function(arg1, arg2)
 | 
			
		||||
        return tostring(arg1) .. '%%' .. tostring(arg2)
 | 
			
		||||
      end, function(arg1, arg2)
 | 
			
		||||
        _G.count = _G.count + 1
 | 
			
		||||
        return arg1 + arg2
 | 
			
		||||
      end)
 | 
			
		||||
 | 
			
		||||
      collectgarbage('stop')
 | 
			
		||||
      adder(1, 2)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      adder(1, 2)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      adder(1, 2)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      adder:clear()
 | 
			
		||||
      adder(1, 2)
 | 
			
		||||
      adder(3, -4)
 | 
			
		||||
      collectgarbage('restart')
 | 
			
		||||
    ]])
 | 
			
		||||
 | 
			
		||||
    eq(4, exec_lua([[return _G.count]]))
 | 
			
		||||
  end)
 | 
			
		||||
 | 
			
		||||
  it('can cache functions that return nil', function()
 | 
			
		||||
    exec_lua([[
 | 
			
		||||
      _G.count = 0
 | 
			
		||||
 | 
			
		||||
      local adder = vim.func._memoize('concat', function(arg1, arg2)
 | 
			
		||||
        _G.count = _G.count + 1
 | 
			
		||||
        return nil
 | 
			
		||||
      end)
 | 
			
		||||
 | 
			
		||||
      collectgarbage('stop')
 | 
			
		||||
      adder(1, 2)
 | 
			
		||||
      adder(1, 2)
 | 
			
		||||
      adder(1, 2)
 | 
			
		||||
      adder(1, 2)
 | 
			
		||||
      adder:clear()
 | 
			
		||||
      adder(1, 2)
 | 
			
		||||
      collectgarbage('restart')
 | 
			
		||||
    ]])
 | 
			
		||||
 | 
			
		||||
    eq(2, exec_lua([[return _G.count]]))
 | 
			
		||||
  end)
 | 
			
		||||
end)
 | 
			
		||||
		Reference in New Issue
	
	Block a user