diff --git a/runtime/doc/lua.txt b/runtime/doc/lua.txt index f546404f85..794df90d97 100644 --- a/runtime/doc/lua.txt +++ b/runtime/doc/lua.txt @@ -1755,13 +1755,66 @@ vim.islist({t}) *vim.islist()* See also: ~ • |vim.isarray()| +vim.list.bisect({t}, {val}, {opts}) *vim.list.bisect()* + Search for a position in a sorted list {t} where {val} can be inserted + while keeping the list sorted. + + Use {bound} to determine whether to return the first or the last position, + defaults to "lower", i.e., the first position. + + NOTE: Behavior is undefined on unsorted lists! + + Example: >lua + + local t = { 1, 2, 2, 3, 3, 3 } + local first = vim.list.bisect(t, 3) + -- `first` is `val`'s first index if found, + -- useful for existence checks. + print(t[first]) -- 3 + + local last = vim.list.bisect(t, 3, { bound = 'upper' }) + -- Note that `last` is 7, not 6, + -- this is suitable for insertion. + + table.insert(t, last, 4) + -- t is now { 1, 2, 2, 3, 3, 3, 4 } + + -- You can use lower bound and upper bound together + -- to obtain the range of occurrences of `val`. + + -- 3 is in [first, last) + for i = first, last - 1 do + print(t[i]) -- { 3, 3, 3 } + end +< + + Parameters: ~ + • {t} (`any[]`) A comparable list. + • {val} (`any`) The value to search. + • {opts} (`table?`) A table with the following fields: + • {lo}? (`integer`, default: `1`) Start index of the list. + • {hi}? (`integer`, default: `#t + 1`) End index of the list, + exclusive. + • {key}? (`fun(val: any): any`) Optional, compare the return + value instead of the {val} itself if provided. + • {bound}? (`'lower'|'upper'`, default: `'lower'`) Specifies + the search variant. + • "lower": returns the first position where inserting {val} + keeps the list sorted. + • "upper": returns the last position where inserting {val} + keeps the list sorted.. + + Return: ~ + (`integer`) index serves as either the lower bound or the upper bound + position. + vim.list.unique({t}, {key}) *vim.list.unique()* Removes duplicate values from a list-like table in-place. Only the first occurrence of each value is kept. The operation is performed in-place and the input table is modified. - Accepts an optional `hash` argument that if provided is called for each + Accepts an optional `key` argument that if provided is called for each value in the list to compute a hash key for uniqueness comparison. This is useful for deduplicating table values or complex objects. @@ -1778,7 +1831,7 @@ vim.list.unique({t}, {key}) *vim.list.unique()* Parameters: ~ • {t} (`any[]`) - • {key} (`fun(x: T): any??`) Optional hash function to determine + • {key} (`fun(x: T): any?`) Optional hash function to determine uniqueness of values Return: ~ diff --git a/runtime/doc/news.txt b/runtime/doc/news.txt index b7d098b956..7393f8ef04 100644 --- a/runtime/doc/news.txt +++ b/runtime/doc/news.txt @@ -234,6 +234,7 @@ LUA • |Iter:take()| and |Iter:skip()| now optionally accept predicates. • Built-in plugin manager |vim.pack| • |vim.list.unique()| to deduplicate lists. +• |vim.list.bisect()| for binary search. OPTIONS diff --git a/runtime/lua/vim/lsp/semantic_tokens.lua b/runtime/lua/vim/lsp/semantic_tokens.lua index 5246f93ef7..24bf5d46c1 100644 --- a/runtime/lua/vim/lsp/semantic_tokens.lua +++ b/runtime/lua/vim/lsp/semantic_tokens.lua @@ -45,38 +45,6 @@ local STHighlighter = { name = 'Semantic Tokens', active = {} } STHighlighter.__index = STHighlighter setmetatable(STHighlighter, Capability) ---- Do a binary search of the tokens in the half-open range [lo, hi). ---- ---- Return the index i in range such that tokens[j].line < line for all j < i, and ---- tokens[j].line >= line for all j >= i, or return hi if no such index is found. -local function lower_bound(tokens, line, lo, hi) - while lo < hi do - local mid = bit.rshift(lo + hi, 1) -- Equivalent to floor((lo + hi) / 2). - if tokens[mid].end_line < line then - lo = mid + 1 - else - hi = mid - end - end - return lo -end - ---- Do a binary search of the tokens in the half-open range [lo, hi). ---- ---- Return the index i in range such that tokens[j].line <= line for all j < i, and ---- tokens[j].line > line for all j >= i, or return hi if no such index is found. -local function upper_bound(tokens, line, lo, hi) - while lo < hi do - local mid = bit.rshift(lo + hi, 1) -- Equivalent to floor((lo + hi) / 2). - if line < tokens[mid].line then - hi = mid - else - lo = mid + 1 - end - end - return lo -end - --- Extracts modifier strings from the encoded number in the token array --- ---@param x integer @@ -488,8 +456,18 @@ function STHighlighter:on_win(topline, botline) local ft = vim.bo[self.bufnr].filetype local highlights = assert(current_result.highlights) - local first = lower_bound(highlights, topline, 1, #highlights + 1) - local last = upper_bound(highlights, botline, first, #highlights + 1) - 1 + local first = vim.list.bisect(highlights, { end_line = topline }, { + key = function(highlight) + return highlight.end_line + end, + }) + local last = vim.list.bisect(highlights, { line = botline }, { + lo = first, + bound = 'upper', + key = function(highlight) + return highlight.line + end, + }) - 1 --- @type boolean?, integer? local is_folded, foldend @@ -761,7 +739,11 @@ function M.get_at_pos(bufnr, row, col) for client_id, client in pairs(highlighter.client_state) do local highlights = client.current_result.highlights if highlights then - local idx = lower_bound(highlights, row, 1, #highlights + 1) + local idx = vim.list.bisect(highlights, { end_line = row }, { + key = function(highlight) + return highlight.end_line + end, + }) for i = idx, #highlights do local token = highlights[i] --- @cast token STTokenRangeInspect diff --git a/runtime/lua/vim/shared.lua b/runtime/lua/vim/shared.lua index 9ce5d5b72d..014b64c18b 100644 --- a/runtime/lua/vim/shared.lua +++ b/runtime/lua/vim/shared.lua @@ -350,12 +350,21 @@ end vim.list = {} +---TODO(ofseed): memoize, string value support, type alias. +---@generic T +---@param v T +---@param key? fun(v: T): any +---@return any +local function key_fn(v, key) + return key and key(v) or v +end + --- Removes duplicate values from a list-like table in-place. --- --- Only the first occurrence of each value is kept. --- The operation is performed in-place and the input table is modified. --- ---- Accepts an optional `hash` argument that if provided is called for each +--- Accepts an optional `key` argument that if provided is called for each --- value in the list to compute a hash key for uniqueness comparison. --- This is useful for deduplicating table values or complex objects. --- @@ -373,21 +382,18 @@ vim.list = {} --- --- @generic T --- @param t T[] ---- @param key? fun(x: T): any? Optional hash function to determine uniqueness of values +--- @param key? fun(x: T): any Optional hash function to determine uniqueness of values --- @return T[] : The deduplicated list function vim.list.unique(t, key) vim.validate('t', t, 'table') local seen = {} --- @type table local finish = #t - key = key or function(a) - return a - end local j = 1 for i = 1, finish do local v = t[i] - local vh = key(v) + local vh = key_fn(v, key) if not seen[vh] then t[j] = v if vh ~= nil then @@ -404,6 +410,127 @@ function vim.list.unique(t, key) return t end +---@class vim.list.bisect.Opts +---@inlinedoc +--- +--- Start index of the list. +--- (default: `1`) +---@field lo? integer +--- +--- End index of the list, exclusive. +--- (default: `#t + 1`) +---@field hi? integer +--- +--- Optional, compare the return value instead of the {val} itself if provided. +---@field key? fun(val: any): any +--- +--- Specifies the search variant. +--- - "lower": returns the first position +--- where inserting {val} keeps the list sorted. +--- - "upper": returns the last position +--- where inserting {val} keeps the list sorted.. +--- (default: `'lower'`) +---@field bound? 'lower' | 'upper' + +---@generic T +---@param t T[] +---@param val T +---@param key? fun(val: any): any +---@param lo integer +---@param hi integer +---@return integer i in range such that `t[j]` < {val} for all j < i, +--- and `t[j]` >= {val} for all j >= i, +--- or return {hi} if no such index is found. +local function lower_bound(t, val, lo, hi, key) + local bit = require('bit') -- Load bitop on demand + local val_key = key_fn(val, key) + while lo < hi do + local mid = bit.rshift(lo + hi, 1) -- Equivalent to floor((lo + hi) / 2) + if key_fn(t[mid], key) < val_key then + lo = mid + 1 + else + hi = mid + end + end + return lo +end + +---@generic T +---@param t T[] +---@param val T +---@param key? fun(val: any): any +---@param lo integer +---@param hi integer +---@return integer i in range such that `t[j]` <= {val} for all j < i, +--- and `t[j]` > {val} for all j >= i, +--- or return {hi} if no such index is found. +local function upper_bound(t, val, lo, hi, key) + local bit = require('bit') -- Load bitop on demand + local val_key = key_fn(val, key) + while lo < hi do + local mid = bit.rshift(lo + hi, 1) -- Equivalent to floor((lo + hi) / 2) + if val_key < key_fn(t[mid], key) then + hi = mid + else + lo = mid + 1 + end + end + return lo +end + +--- Search for a position in a sorted list {t} +--- where {val} can be inserted while keeping the list sorted. +--- +--- Use {bound} to determine whether to return the first or the last position, +--- defaults to "lower", i.e., the first position. +--- +--- NOTE: Behavior is undefined on unsorted lists! +--- +--- Example: +--- ```lua +--- +--- local t = { 1, 2, 2, 3, 3, 3 } +--- local first = vim.list.bisect(t, 3) +--- -- `first` is `val`'s first index if found, +--- -- useful for existence checks. +--- print(t[first]) -- 3 +--- +--- local last = vim.list.bisect(t, 3, { bound = 'upper' }) +--- -- Note that `last` is 7, not 6, +--- -- this is suitable for insertion. +--- +--- table.insert(t, last, 4) +--- -- t is now { 1, 2, 2, 3, 3, 3, 4 } +--- +--- -- You can use lower bound and upper bound together +--- -- to obtain the range of occurrences of `val`. +--- +--- -- 3 is in [first, last) +--- for i = first, last - 1 do +--- print(t[i]) -- { 3, 3, 3 } +--- end +--- ``` +---@generic T +---@param t T[] A comparable list. +---@param val T The value to search. +---@param opts? vim.list.bisect.Opts +---@return integer index serves as either the lower bound or the upper bound position. +function vim.list.bisect(t, val, opts) + vim.validate('t', t, 'table') + vim.validate('opts', opts, 'table', true) + + opts = opts or {} + local lo = opts.lo or 1 + local hi = opts.hi or #t + 1 + local key = opts.key + + if opts.bound == 'upper' then + return upper_bound(t, val, lo, hi, key) + else + return lower_bound(t, val, lo, hi, key) + end +end + --- Checks if a table is empty. --- ---@see https://github.com/premake/premake-core/blob/master/src/base/table.lua diff --git a/test/functional/lua/list_spec.lua b/test/functional/lua/list_spec.lua new file mode 100644 index 0000000000..7534505ef2 --- /dev/null +++ b/test/functional/lua/list_spec.lua @@ -0,0 +1,65 @@ +-- Test suite for vim.list +local t = require('test.testutil') +local eq = t.eq + +describe('vim.list', function() + it('vim.list.unique()', function() + eq({ 1, 2, 3, 4, 5 }, vim.list.unique({ 1, 2, 2, 3, 4, 4, 5 })) + eq({ 1, 2, 3, 4, 5 }, vim.list.unique({ 1, 2, 3, 4, 4, 5, 1, 2, 3, 2, 1, 2, 3, 4, 5 })) + eq({ 1, 2, 3, 4, 5, field = 1 }, vim.list.unique({ 1, 2, 2, 3, 4, 4, 5, field = 1 })) + + -- Not properly defined, but test anyway + -- luajit evaluates #t as 7, whereas Lua 5.1 evaluates it as 12 + local r = vim.list.unique({ 1, 2, 2, 3, 4, 4, 5, nil, 6, 6, 7, 7 }) + if jit then + eq({ 1, 2, 3, 4, 5, nil, nil, nil, 6, 6, 7, 7 }, r) + else + eq({ 1, 2, 3, 4, 5, nil, 6, 7 }, r) + end + + eq( + { { 1 }, { 2 }, { 3 } }, + vim.list.unique({ { 1 }, { 1 }, { 2 }, { 2 }, { 3 }, { 3 } }, function(x) + return x[1] + end) + ) + end) + + --- Generate a list like { 1, 2, 2, 3, 3, 3, 4, 4, 4, 4, ...}. + ---@param num integer + local function gen_list(num) + ---@type integer[] + local list = {} + for i = 1, num do + for _ = 1, i do + list[#list + 1] = i + end + end + return list + end + + --- Index of the last {num}. + --- Mathematically, a triangular number. + ---@param num integer + local function index(num) + return math.floor((math.pow(num, 2) + num) / 2) + end + + it("vim.list.bisect(..., { bound = 'lower' })", function() + local num = math.random(100) + local list = gen_list(num) + + local target = math.random(num) + eq(vim.list.bisect(list, target, { bound = 'lower' }), index(target - 1) + 1) + eq(vim.list.bisect(list, num + 1, { bound = 'lower' }), index(num) + 1) + end) + + it("vim.list.bisect(..., bound = { 'upper' })", function() + local num = math.random(100) + local list = gen_list(num) + + local target = math.random(num) + eq(vim.list.bisect(list, target, { bound = 'upper' }), index(target) + 1) + eq(vim.list.bisect(list, num + 1, { bound = 'upper' }), index(num) + 1) + end) +end) diff --git a/test/functional/lua/vim_spec.lua b/test/functional/lua/vim_spec.lua index 14e08cffba..9372f5c7b7 100644 --- a/test/functional/lua/vim_spec.lua +++ b/test/functional/lua/vim_spec.lua @@ -1260,28 +1260,6 @@ describe('lua stdlib', function() eq({ 2 }, exec_lua [[ return vim.list_extend({}, {2;a=1}, -1, 2) ]]) end) - it('vim.list.unique', function() - eq({ 1, 2, 3, 4, 5 }, vim.list.unique({ 1, 2, 2, 3, 4, 4, 5 })) - eq({ 1, 2, 3, 4, 5 }, vim.list.unique({ 1, 2, 3, 4, 4, 5, 1, 2, 3, 2, 1, 2, 3, 4, 5 })) - eq({ 1, 2, 3, 4, 5, field = 1 }, vim.list.unique({ 1, 2, 2, 3, 4, 4, 5, field = 1 })) - - -- Not properly defined, but test anyway - -- luajit evaluates #t as 7, whereas Lua 5.1 evaluates it as 12 - local r = vim.list.unique({ 1, 2, 2, 3, 4, 4, 5, nil, 6, 6, 7, 7 }) - if jit then - eq({ 1, 2, 3, 4, 5, nil, nil, nil, 6, 6, 7, 7 }, r) - else - eq({ 1, 2, 3, 4, 5, nil, 6, 7 }, r) - end - - eq( - { { 1 }, { 2 }, { 3 } }, - vim.list.unique({ { 1 }, { 1 }, { 2 }, { 2 }, { 3 }, { 3 } }, function(x) - return x[1] - end) - ) - end) - it('vim.tbl_add_reverse_lookup', function() eq( true,