perf(lua): memoize key_fn results

This commit is contained in:
Yi Ming
2026-05-03 04:19:52 +08:00
parent 14819d55fb
commit 8394775241
3 changed files with 51 additions and 22 deletions

View File

@@ -1817,7 +1817,8 @@ vim.list.bisect({t}, {val}, {opts}) *vim.list.bisect()*
• {hi}? (`integer`, default: `#t + 1`) End index of the list,
exclusive.
• {key}? (`fun(val: any): any`) Optional, compare the return
value instead of the {val} itself if provided.
value instead of the {val} itself if provided. Key results
are memoized per call.
• {lo}? (`integer`, default: `1`) Start index of the list.
Return: ~
@@ -1831,10 +1832,10 @@ vim.list.unique({t}, {key}) *vim.list.unique()*
performed in-place and the input table is modified.
Accepts an optional `key` argument, which if provided is called for each
value in the list to compute a hash key for uniqueness comparison. This is
useful for deduplicating table values or complex objects. If `key` returns
`nil` for a value, that value will be considered unique, even if multiple
values return `nil`.
value in the list to compute a hash key for uniqueness comparison. Key
results are memoized per call. This is useful for deduplicating table
values or complex objects. If `key` returns `nil` for a value, that value
will be considered unique, even if multiple values return `nil`.
Example: >lua

View File

@@ -192,6 +192,8 @@ PERFORMANCE
thus reducing GC and memory reallocation during each data reset.
• When parsing the received Content-Length messages,
the RPC client will no longer allocate extra strings.
• |vim.list.unique()| and |vim.list.bisect()| now memoize key function results,
which can speed up calls with expensive key functions.
PLUGINS

View File

@@ -376,13 +376,36 @@ end
vim.list = {}
---TODO(ofseed): memoize, string value support, type alias.
---TODO(ofseed): string value support, type alias.
--- Returns a `key` function with per-call memoization.
---@generic T
---@param v T
---@param key? fun(v: T): any
---@return any
local function key_fn(v, key)
return key and key(v) or v
---@return fun(v: T): any
local function make_key_fn(key)
  -- Wraps the optional `key` function so that, within one list operation,
  -- `key` runs at most once per cacheable value. Without `key`, the returned
  -- function is the identity.
  if not key then
    return function(v)
      return v
    end
  end

  -- Keep memoized keys local to one list operation to avoid stale results.
  local cache = {} --- @type table<any,any>
  -- Private marker meaning "`key(v)` returned `nil`". A fresh table cannot
  -- collide with anything `key` may return (unlike a shared public value).
  local nil_result = {}

  return function(v)
    -- `nil` and NaN are not valid table keys, so they bypass the cache.
    if v == nil or v ~= v then
      return key(v)
    end

    local cached = cache[v]
    if cached == nil then
      -- First time this value is seen: compute and remember the key.
      local result = key(v)
      if result == nil then
        cache[v] = nil_result
      else
        cache[v] = result
      end
      return result
    end

    -- An explicit branch is required here: `cond and nil or x` always
    -- evaluates to `x`, so it cannot be used to map the marker back to `nil`.
    if rawequal(cached, nil_result) then
      return nil
    end
    return cached
  end
end
--- Removes duplicate values from a |lua-list| in-place.
@@ -392,6 +415,7 @@ end
---
--- Accepts an optional `key` argument, which if provided is called for each
--- value in the list to compute a hash key for uniqueness comparison.
--- Results of `key` are cached for the duration of a single call to this function.
--- This is useful for deduplicating table values or complex objects.
--- If `key` returns `nil` for a value, that value will be considered unique,
--- even if multiple values return `nil`.
@@ -416,6 +440,7 @@ end
--- @see |Iter:unique()|
function vim.list.unique(t, key)
vim.validate('t', t, 'table')
local key_fn = make_key_fn(key)
local seen = {} --- @type table<any,boolean>
local finish = #t
@@ -423,7 +448,7 @@ function vim.list.unique(t, key)
local j = 1
for i = 1, finish do
local v = t[i]
local vh = key_fn(v, key)
local vh = key_fn(v)
if not seen[vh] then
t[j] = v
if vh ~= nil then
@@ -452,6 +477,7 @@ end
---@field hi? integer
---
--- Optional, compare the return value instead of the {val} itself if provided.
--- Results of {key} are cached for the duration of a single |vim.list.bisect()| call.
---@field key? fun(val: any): any
---
--- Specifies the search variant.
@@ -465,18 +491,18 @@ end
---@generic T
---@param t T[]
---@param val T
---@param key? fun(val: any): any
---@param lo integer
---@param hi integer
---@param key_fn fun(val: any): any
---@return integer i in range such that `t[j]` < {val} for all j < i,
--- and `t[j]` >= {val} for all j >= i,
--- or return {hi} if no such index is found.
local function lower_bound(t, val, lo, hi, key)
local function lower_bound(t, val, lo, hi, key_fn)
local bit = require('bit') -- Load bitop on demand
local val_key = key_fn(val, key)
local val_key = key_fn(val)
while lo < hi do
local mid = bit.rshift(lo + hi, 1) -- Equivalent to floor((lo + hi) / 2)
if key_fn(t[mid], key) < val_key then
if key_fn(t[mid]) < val_key then
lo = mid + 1
else
hi = mid
@@ -488,18 +514,18 @@ end
---@generic T
---@param t T[]
---@param val T
---@param key? fun(val: any): any
---@param lo integer
---@param hi integer
---@param key_fn fun(val: any): any
---@return integer i in range such that `t[j]` <= {val} for all j < i,
--- and `t[j]` > {val} for all j >= i,
--- or return {hi} if no such index is found.
local function upper_bound(t, val, lo, hi, key)
local function upper_bound(t, val, lo, hi, key_fn)
local bit = require('bit') -- Load bitop on demand
local val_key = key_fn(val, key)
local val_key = key_fn(val)
while lo < hi do
local mid = bit.rshift(lo + hi, 1) -- Equivalent to floor((lo + hi) / 2)
if val_key < key_fn(t[mid], key) then
if val_key < key_fn(t[mid]) then
hi = mid
else
lo = mid + 1
@@ -553,12 +579,12 @@ function vim.list.bisect(t, val, opts)
opts = opts or {}
local lo = opts.lo or 1
local hi = opts.hi or #t + 1
local key = opts.key
local key_fn = make_key_fn(opts.key)
if opts.bound == 'upper' then
return upper_bound(t, val, lo, hi, key)
return upper_bound(t, val, lo, hi, key_fn)
else
return lower_bound(t, val, lo, hi, key)
return lower_bound(t, val, lo, hi, key_fn)
end
end