Merge pull request #23382 from gpanders/iter-benchmark

Add vim.iter benchmark to benchmark test suite
This commit is contained in:
Gregory Anders
2023-04-29 20:33:27 -06:00
committed by GitHub
3 changed files with 320 additions and 65 deletions

View File

@@ -1,13 +1,14 @@
---@defgroup lua-iter
---
--- The \*vim.iter\* module provides a generic "iterator" interface over tables and iterator
--- functions.
--- The \*vim.iter\* module provides a generic "iterator" interface over tables
--- and iterator functions.
---
--- \*vim.iter()\* wraps its table or function argument into an \*Iter\* object with methods (such
--- as |Iter:filter()| and |Iter:map()|) that transform the underlying source data. These methods
--- can be chained together to create iterator "pipelines". Each pipeline stage receives as input
--- the output values from the prior stage. The values used in the first stage of the pipeline
--- depend on the type passed to this function:
--- \*vim.iter()\* wraps its table or function argument into an \*Iter\* object
--- with methods (such as |Iter:filter()| and |Iter:map()|) that transform the
--- underlying source data. These methods can be chained together to create
--- iterator "pipelines". Each pipeline stage receives as input the output
--- values from the prior stage. The values used in the first stage of the
--- pipeline depend on the type passed to this function:
---
--- - List tables pass only the value of each element
--- - Non-list tables pass both the key and value of each element
@@ -47,8 +48,8 @@
--- -- true
--- </pre>
---
--- In addition to the |vim.iter()| function, the |vim.iter| module provides convenience functions
--- like |vim.iter.filter()| and |vim.iter.totable()|.
--- In addition to the |vim.iter()| function, the |vim.iter| module provides
--- convenience functions like |vim.iter.filter()| and |vim.iter.totable()|.
local M = {}
@@ -61,9 +62,9 @@ end
--- Special case implementations for iterators on list tables.
---@class ListIter : Iter
---@field _table table Underlying table data (table iterators only)
---@field _head number Index to the front of a table iterator (table iterators only)
---@field _tail number Index to the end of a table iterator (table iterators only)
---@field _table table Underlying table data
---@field _head number Index to the front of a table iterator
---@field _tail number Index to the end of a table iterator
local ListIter = {}
ListIter.__index = setmetatable(ListIter, Iter)
ListIter.__call = function(self)
@@ -75,7 +76,7 @@ local packedmt = {}
---@private
local function unpack(t)
if getmetatable(t) == packedmt then
if type(t) == 'table' and getmetatable(t) == packedmt then
return _G.unpack(t, 1, t.n)
end
return t
@@ -92,13 +93,47 @@ end
---@private
local function sanitize(t)
if getmetatable(t) == packedmt then
if type(t) == 'table' and getmetatable(t) == packedmt then
-- Remove length tag
t.n = nil
end
return t
end
--- Determine if the current iterator stage should continue.
---
--- If any arguments are passed to this function, then return those arguments
--- and stop the current iterator stage. Otherwise, return true to signal that
--- the current stage should continue.
---
---@param ... any Function arguments.
---@return boolean True if the iterator stage should continue, false otherwise
---@return any Function arguments.
---@private
local function continue(...)
if select('#', ...) > 0 then
return false, ...
end
return true
end
--- If no input arguments are given return false, indicating the current
--- iterator stage should stop. Otherwise, apply the arguments to the function
--- f. If that function returns no values, the current iterator stage continues.
--- Otherwise, those values are returned.
---
---@param f function Function to call with the given arguments
---@param ... any Arguments to apply to f
---@return boolean True if the iterator pipeline should continue, false otherwise
---@return any Return values of f
---@private
local function apply(f, ...)
if select('#', ...) > 0 then
return continue(f(...))
end
return false
end
--- Add a filter step to the iterator pipeline.
---
--- Example:
@@ -106,33 +141,16 @@ end
--- local bufs = vim.iter(vim.api.nvim_list_bufs()):filter(vim.api.nvim_buf_is_loaded)
--- </pre>
---
---@param f function(...):bool Takes all values returned from the previous stage in the pipeline and
--- returns false or nil if the current iterator element should be
--- removed.
---@param f function(...):bool Takes all values returned from the previous stage
--- in the pipeline and returns false or nil if the
--- current iterator element should be removed.
---@return Iter
function Iter.filter(self, f)
---@private
local function fn(...)
local result = nil
if select(1, ...) ~= nil then
if not f(...) then
return true, nil
else
result = pack(...)
end
return self:map(function(...)
if f(...) then
return ...
end
return false, result
end
local next = self.next
self.next = function(this)
local cont, result
repeat
cont, result = fn(next(this))
until not cont
return unpack(result)
end
return self
end)
end
---@private
@@ -165,31 +183,52 @@ end
--- -- { 6, 12 }
--- </pre>
---
---@param f function(...):any Mapping function. Takes all values returned from the previous stage
--- in the pipeline as arguments and returns one or more new values,
--- which are used in the next pipeline stage. Nil return values returned
--- are filtered from the output.
---@param f function(...):any Mapping function. Takes all values returned from
--- the previous stage in the pipeline as arguments
--- and returns one or more new values, which are used
--- in the next pipeline stage. Nil return values
--- are filtered from the output.
---@return Iter
function Iter.map(self, f)
---@private
local function fn(...)
local result = nil
if select(1, ...) ~= nil then
result = pack(f(...))
if result == nil then
return true, nil
end
end
return false, result
end
-- Implementation note: the reader may be forgiven for observing that this
-- function appears excessively convoluted. The problem to solve is that each
-- stage of the iterator pipeline can return any number of values, and the
-- number of values could even change per iteration. And the return values
-- must be checked to determine if the pipeline has ended, so we cannot
-- naively forward them along to the next stage.
--
-- A simple approach is to pack all of the return values into a table, check
-- for nil, then unpack the table for the next stage. However, packing and
-- unpacking tables is quite slow. There is no other way in Lua to handle an
-- unknown number of function return values than to simply forward those
-- values along to another function. Hence the intricate function passing you
-- see here.
local next = self.next
self.next = function(this)
local cont, result
repeat
cont, result = fn(next(this))
until not cont
return unpack(result)
--- Drain values from the upstream iterator source until a value can be
--- returned.
---
--- This is a recursive function. The base case is when the first argument is
--- false, which indicates that the rest of the arguments should be returned
--- as the values for the current iteration stage.
---
---@param cont boolean If true, the current iterator stage should continue to
--- pull values from its upstream pipeline stage.
--- Otherwise, this stage is complete and returns the
--- values passed.
---@param ... any Values to return if cont is false.
---@return any
---@private
local function fn(cont, ...)
if cont then
return fn(apply(f, next(self)))
end
return ...
end
self.next = function()
return fn(apply(f, next(self)))
end
return self
end
@@ -211,17 +250,18 @@ end
--- Call a function once for each item in the pipeline.
---
--- This is used for functions which have side effects. To modify the values in the iterator, use
--- |Iter:map()|.
--- This is used for functions which have side effects. To modify the values in
--- the iterator, use |Iter:map()|.
---
--- This function drains the iterator.
---
---@param f function(...) Function to execute for each item in the pipeline. Takes all of the
--- values returned by the previous stage in the pipeline as arguments.
---@param f function(...) Function to execute for each item in the pipeline.
--- Takes all of the values returned by the previous stage
--- in the pipeline as arguments.
function Iter.each(self, f)
---@private
local function fn(...)
if select(1, ...) ~= nil then
if select('#', ...) > 0 then
f(...)
return true
end