From 83947752412d55fc089a45549bb51ec0863f55fa Mon Sep 17 00:00:00 2001 From: Yi Ming Date: Sun, 3 May 2026 04:19:52 +0800 Subject: [PATCH] perf(lua): memoize `key_fn` results --- runtime/doc/lua.txt | 11 +++--- runtime/doc/news.txt | 2 ++ runtime/lua/vim/_core/shared.lua | 60 +++++++++++++++++++++++--------- 3 files changed, 51 insertions(+), 22 deletions(-) diff --git a/runtime/doc/lua.txt b/runtime/doc/lua.txt index 249e9f4f91..dfac7b9c6c 100644 --- a/runtime/doc/lua.txt +++ b/runtime/doc/lua.txt @@ -1817,7 +1817,8 @@ vim.list.bisect({t}, {val}, {opts}) *vim.list.bisect()* • {hi}? (`integer`, default: `#t + 1`) End index of the list, exclusive. • {key}? (`fun(val: any): any`) Optional, compare the return - value instead of the {val} itself if provided. + value instead of the {val} itself if provided. Key results + are memoized per call. • {lo}? (`integer`, default: `1`) Start index of the list. Return: ~ @@ -1831,10 +1832,10 @@ vim.list.unique({t}, {key}) *vim.list.unique()* performed in-place and the input table is modified. Accepts an optional `key` argument, which if provided is called for each - value in the list to compute a hash key for uniqueness comparison. This is - useful for deduplicating table values or complex objects. If `key` returns - `nil` for a value, that value will be considered unique, even if multiple - values return `nil`. + value in the list to compute a hash key for uniqueness comparison. Key + results are memoized per call. This is useful for deduplicating table + values or complex objects. If `key` returns `nil` for a value, that value + will be considered unique, even if multiple values return `nil`. Example: >lua diff --git a/runtime/doc/news.txt b/runtime/doc/news.txt index 6857cd4eae..1a482f343b 100644 --- a/runtime/doc/news.txt +++ b/runtime/doc/news.txt @@ -192,6 +192,8 @@ PERFORMANCE thus reducing GC and memory reallocation during each data reset. 
• When parsing the received Content-Length messages, the RPC client will no longer allocate extra strings. +• |vim.list.unique()| and |vim.list.bisect()| now memoize key function results, + which can speed up calls with expensive key functions. PLUGINS diff --git a/runtime/lua/vim/_core/shared.lua b/runtime/lua/vim/_core/shared.lua index 6df4c85ae7..719ee283e8 100644 --- a/runtime/lua/vim/_core/shared.lua +++ b/runtime/lua/vim/_core/shared.lua @@ -376,13 +376,36 @@ end vim.list = {} ----TODO(ofseed): memoize, string value support, type alias. +---TODO(ofseed): string value support, type alias. +--- Returns a `key` function with per-call memoization. ---@generic T ----@param v T ---@param key? fun(v: T): any ----@return any -local function key_fn(v, key) - return key and key(v) or v +---@return fun(v: T): any +local function make_key_fn(key) + if not key then + return function(v) + return v + end + end + + -- Keep memoized keys local to one list operation to avoid stale results. + local cache = {} --- @type table + local NIL = {} -- Private marker for a memoized `nil`; `key` cannot return it. + + return function(v) + if v == nil or v ~= v then -- `nil` and NaN cannot be table keys. + return key(v) + end + local cached = cache[v] + if cached == NIL then + return nil + elseif cached ~= nil then + return cached + end + local result = key(v) + cache[v] = result == nil and NIL or result + return result + end end --- Removes duplicate values from a |lua-list| in-place. @@ -392,6 +415,7 @@ end --- --- Accepts an optional `key` argument, which if provided is called for each --- value in the list to compute a hash key for uniqueness comparison. +--- Key results are memoized per call. --- This is useful for deduplicating table values or complex objects. --- If `key` returns `nil` for a value, that value will be considered unique, --- even if multiple values return `nil`. 
@@ -416,6 +440,7 @@ end --- @see |Iter:unique()| function vim.list.unique(t, key) vim.validate('t', t, 'table') + local key_fn = make_key_fn(key) local seen = {} --- @type table local finish = #t @@ -423,7 +448,7 @@ function vim.list.unique(t, key) local j = 1 for i = 1, finish do local v = t[i] - local vh = key_fn(v, key) + local vh = key_fn(v) if not seen[vh] then t[j] = v if vh ~= nil then @@ -452,6 +477,7 @@ end ---@field hi? integer --- --- Optional, compare the return value instead of the {val} itself if provided. +--- Key results are memoized per call. ---@field key? fun(val: any): any --- --- Specifies the search variant. @@ -465,18 +491,18 @@ end ---@generic T ---@param t T[] ---@param val T ----@param key? fun(val: any): any ---@param lo integer ---@param hi integer +---@param key_fn fun(val: any): any ---@return integer i in range such that `t[j]` < {val} for all j < i, --- and `t[j]` >= {val} for all j >= i, --- or return {hi} if no such index is found. -local function lower_bound(t, val, lo, hi, key) +local function lower_bound(t, val, lo, hi, key_fn) local bit = require('bit') -- Load bitop on demand - local val_key = key_fn(val, key) + local val_key = key_fn(val) while lo < hi do local mid = bit.rshift(lo + hi, 1) -- Equivalent to floor((lo + hi) / 2) - if key_fn(t[mid], key) < val_key then + if key_fn(t[mid]) < val_key then lo = mid + 1 else hi = mid @@ -488,18 +514,18 @@ end ---@generic T ---@param t T[] ---@param val T ----@param key? fun(val: any): any ---@param lo integer ---@param hi integer +---@param key_fn fun(val: any): any ---@return integer i in range such that `t[j]` <= {val} for all j < i, --- and `t[j]` > {val} for all j >= i, --- or return {hi} if no such index is found. 
-local function upper_bound(t, val, lo, hi, key) +local function upper_bound(t, val, lo, hi, key_fn) local bit = require('bit') -- Load bitop on demand - local val_key = key_fn(val, key) + local val_key = key_fn(val) while lo < hi do local mid = bit.rshift(lo + hi, 1) -- Equivalent to floor((lo + hi) / 2) - if val_key < key_fn(t[mid], key) then + if val_key < key_fn(t[mid]) then hi = mid else lo = mid + 1 @@ -553,12 +579,12 @@ function vim.list.bisect(t, val, opts) opts = opts or {} local lo = opts.lo or 1 local hi = opts.hi or #t + 1 - local key = opts.key + local key_fn = make_key_fn(opts.key) if opts.bound == 'upper' then - return upper_bound(t, val, lo, hi, key) + return upper_bound(t, val, lo, hi, key_fn) else - return lower_bound(t, val, lo, hi, key) + return lower_bound(t, val, lo, hi, key_fn) end end