fix(vim.iter): enable optimizations for arrays (lists with holes) (#28781)

The optimizations that vim.iter uses for array-like tables don't require
that the source table has no holes. The only thing that needs to change
is the determination if a table is "list-like": rather than requiring
consecutive, integer keys, we can simply test for (positive) integer
keys only, and remove any holes in the original array when we make a
copy for the iterator.
This commit is contained in:
Gregory Anders
2024-05-17 14:17:25 -05:00
committed by GitHub
parent aec4938a21
commit 4c0d18c197
3 changed files with 82 additions and 79 deletions

View File

@@ -3830,6 +3830,7 @@ chained to create iterator "pipelines": the output of each pipeline stage is
input to the next stage. The first stage depends on the type passed to input to the next stage. The first stage depends on the type passed to
`vim.iter()`: `vim.iter()`:
• List tables (arrays, |lua-list|) yield only the value of each element. • List tables (arrays, |lua-list|) yield only the value of each element.
• Holes (nil values) are allowed.
• Use |Iter:enumerate()| to also pass the index to the next stage. • Use |Iter:enumerate()| to also pass the index to the next stage.
• Or initialize with ipairs(): `vim.iter(ipairs(…))`. • Or initialize with ipairs(): `vim.iter(ipairs(…))`.
• Non-list tables (|lua-dict|) yield both the key and value of each element. • Non-list tables (|lua-dict|) yield both the key and value of each element.
@@ -4287,9 +4288,9 @@ Iter:totable() *Iter:totable()*
Collect the iterator into a table. Collect the iterator into a table.
The resulting table depends on the initial source in the iterator The resulting table depends on the initial source in the iterator
pipeline. List-like tables and function iterators will be collected into a pipeline. Array-like tables and function iterators will be collected into
list-like table. If multiple values are returned from the final stage in an array-like table. If multiple values are returned from the final stage
the iterator pipeline, each value will be included in a table. in the iterator pipeline, each value will be included in a table.
Examples: >lua Examples: >lua
vim.iter(string.gmatch('100 20 50', '%d+')):map(tonumber):totable() vim.iter(string.gmatch('100 20 50', '%d+')):map(tonumber):totable()
@@ -4302,7 +4303,7 @@ Iter:totable() *Iter:totable()*
-- { { 'a', 1 }, { 'c', 3 } } -- { { 'a', 1 }, { 'c', 3 } }
< <
The generated table is a list-like table with consecutive, numeric The generated table is an array-like table with consecutive, numeric
indices. To create a map-like table with arbitrary keys, use indices. To create a map-like table with arbitrary keys, use
|Iter:fold()|. |Iter:fold()|.

View File

@@ -7,6 +7,7 @@
--- `vim.iter()`: --- `vim.iter()`:
--- ---
--- - List tables (arrays, |lua-list|) yield only the value of each element. --- - List tables (arrays, |lua-list|) yield only the value of each element.
--- - Holes (nil values) are allowed.
--- - Use |Iter:enumerate()| to also pass the index to the next stage. --- - Use |Iter:enumerate()| to also pass the index to the next stage.
--- - Or initialize with ipairs(): `vim.iter(ipairs(…))`. --- - Or initialize with ipairs(): `vim.iter(ipairs(…))`.
--- - Non-list tables (|lua-dict|) yield both the key and value of each element. --- - Non-list tables (|lua-dict|) yield both the key and value of each element.
@@ -80,13 +81,13 @@ end
--- Special case implementations for iterators on list tables. --- Special case implementations for iterators on list tables.
---@nodoc ---@nodoc
---@class ListIter : Iter ---@class ArrayIter : Iter
---@field _table table Underlying table data ---@field _table table Underlying table data
---@field _head number Index to the front of a table iterator ---@field _head number Index to the front of a table iterator
---@field _tail number Index to the end of a table iterator (exclusive) ---@field _tail number Index to the end of a table iterator (exclusive)
local ListIter = {} local ArrayIter = {}
ListIter.__index = setmetatable(ListIter, Iter) ArrayIter.__index = setmetatable(ArrayIter, Iter)
ListIter.__call = function(self) ArrayIter.__call = function(self)
return self:next() return self:next()
end end
@@ -110,36 +111,34 @@ end
local function sanitize(t) local function sanitize(t)
if type(t) == 'table' and getmetatable(t) == packedmt then if type(t) == 'table' and getmetatable(t) == packedmt then
-- Remove length tag -- Remove length tag and metatable
t.n = nil t.n = nil
setmetatable(t, nil)
end end
return t return t
end end
--- Flattens a single list-like table. Errors if it attempts to flatten a --- Flattens a single array-like table. Errors if it attempts to flatten a
--- dict-like table --- dict-like table
---@param v table table which should be flattened ---@param t table table which should be flattened
---@param max_depth number depth to which the table should be flattened ---@param max_depth number depth to which the table should be flattened
---@param depth number current iteration depth ---@param depth number current iteration depth
---@param result table output table that contains flattened result ---@param result table output table that contains flattened result
---@return table|nil flattened table if it can be flattened, otherwise nil ---@return table|nil flattened table if it can be flattened, otherwise nil
local function flatten(v, max_depth, depth, result) local function flatten(t, max_depth, depth, result)
if depth < max_depth and type(v) == 'table' then if depth < max_depth and type(t) == 'table' then
local i = 0 for k, v in pairs(t) do
for _ in pairs(v) do if type(k) ~= 'number' or k <= 0 or math.floor(k) ~= k then
i = i + 1
if v[i] == nil then
-- short-circuit: this is not a list like table -- short-circuit: this is not a list like table
return nil return nil
end end
if flatten(v[i], max_depth, depth + 1, result) == nil then if flatten(v, max_depth, depth + 1, result) == nil then
return nil return nil
end end
end end
else elseif t ~= nil then
result[#result + 1] = v result[#result + 1] = t
end end
return result return result
@@ -198,7 +197,7 @@ function Iter:filter(f)
end end
---@private ---@private
function ListIter:filter(f) function ArrayIter:filter(f)
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and 1 or -1
local n = self._head local n = self._head
for i = self._head, self._tail - inc, inc do for i = self._head, self._tail - inc, inc do
@@ -233,11 +232,11 @@ end
---@return Iter ---@return Iter
---@diagnostic disable-next-line:unused-local ---@diagnostic disable-next-line:unused-local
function Iter:flatten(depth) -- luacheck: no unused args function Iter:flatten(depth) -- luacheck: no unused args
error('flatten() requires a list-like table') error('flatten() requires an array-like table')
end end
---@private ---@private
function ListIter:flatten(depth) function ArrayIter:flatten(depth)
depth = depth or 1 depth = depth or 1
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and 1 or -1
local target = {} local target = {}
@@ -247,7 +246,7 @@ function ListIter:flatten(depth)
-- exit early if we try to flatten a dict-like table -- exit early if we try to flatten a dict-like table
if flattened == nil then if flattened == nil then
error('flatten() requires a list-like table') error('flatten() requires an array-like table')
end end
for _, v in pairs(flattened) do for _, v in pairs(flattened) do
@@ -327,7 +326,7 @@ function Iter:map(f)
end end
---@private ---@private
function ListIter:map(f) function ArrayIter:map(f)
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and 1 or -1
local n = self._head local n = self._head
for i = self._head, self._tail - inc, inc do for i = self._head, self._tail - inc, inc do
@@ -360,7 +359,7 @@ function Iter:each(f)
end end
---@private ---@private
function ListIter:each(f) function ArrayIter:each(f)
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail - inc, inc do for i = self._head, self._tail - inc, inc do
f(unpack(self._table[i])) f(unpack(self._table[i]))
@@ -371,7 +370,7 @@ end
--- Collect the iterator into a table. --- Collect the iterator into a table.
--- ---
--- The resulting table depends on the initial source in the iterator pipeline. --- The resulting table depends on the initial source in the iterator pipeline.
--- List-like tables and function iterators will be collected into a list-like --- Array-like tables and function iterators will be collected into an array-like
--- table. If multiple values are returned from the final stage in the iterator --- table. If multiple values are returned from the final stage in the iterator
--- pipeline, each value will be included in a table. --- pipeline, each value will be included in a table.
--- ---
@@ -388,7 +387,7 @@ end
--- -- { { 'a', 1 }, { 'c', 3 } } --- -- { { 'a', 1 }, { 'c', 3 } }
--- ``` --- ```
--- ---
--- The generated table is a list-like table with consecutive, numeric indices. --- The generated table is an array-like table with consecutive, numeric indices.
--- To create a map-like table with arbitrary keys, use |Iter:fold()|. --- To create a map-like table with arbitrary keys, use |Iter:fold()|.
--- ---
--- ---
@@ -408,12 +407,12 @@ function Iter:totable()
end end
---@private ---@private
function ListIter:totable() function ArrayIter:totable()
if self.next ~= ListIter.next or self._head >= self._tail then if self.next ~= ArrayIter.next or self._head >= self._tail then
return Iter.totable(self) return Iter.totable(self)
end end
local needs_sanitize = getmetatable(self._table[1]) == packedmt local needs_sanitize = getmetatable(self._table[self._head]) == packedmt
-- Reindex and sanitize. -- Reindex and sanitize.
local len = self._tail - self._head local len = self._tail - self._head
@@ -493,7 +492,7 @@ function Iter:fold(init, f)
end end
---@private ---@private
function ListIter:fold(init, f) function ArrayIter:fold(init, f)
local acc = init local acc = init
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail - inc, inc do for i = self._head, self._tail - inc, inc do
@@ -525,7 +524,7 @@ function Iter:next()
end end
---@private ---@private
function ListIter:next() function ArrayIter:next()
if self._head ~= self._tail then if self._head ~= self._tail then
local v = self._table[self._head] local v = self._table[self._head]
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and 1 or -1
@@ -548,11 +547,11 @@ end
--- ---
---@return Iter ---@return Iter
function Iter:rev() function Iter:rev()
error('rev() requires a list-like table') error('rev() requires an array-like table')
end end
---@private ---@private
function ListIter:rev() function ArrayIter:rev()
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and 1 or -1
self._head, self._tail = self._tail - inc, self._head - inc self._head, self._tail = self._tail - inc, self._head - inc
return self return self
@@ -576,11 +575,11 @@ end
--- ---
---@return any ---@return any
function Iter:peek() function Iter:peek()
error('peek() requires a list-like table') error('peek() requires an array-like table')
end end
---@private ---@private
function ListIter:peek() function ArrayIter:peek()
if self._head ~= self._tail then if self._head ~= self._tail then
return self._table[self._head] return self._table[self._head]
end end
@@ -657,11 +656,11 @@ end
---@return any ---@return any
---@diagnostic disable-next-line: unused-local ---@diagnostic disable-next-line: unused-local
function Iter:rfind(f) -- luacheck: no unused args function Iter:rfind(f) -- luacheck: no unused args
error('rfind() requires a list-like table') error('rfind() requires an array-like table')
end end
---@private ---@private
function ListIter:rfind(f) function ArrayIter:rfind(f)
if type(f) ~= 'function' then if type(f) ~= 'function' then
local val = f local val = f
f = function(v) f = function(v)
@@ -709,10 +708,10 @@ function Iter:take(n)
end end
---@private ---@private
function ListIter:take(n) function ArrayIter:take(n)
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and n or -n
local cmp = self._head < self._tail and math.min or math.max local cmp = self._head < self._tail and math.min or math.max
self._tail = cmp(self._tail, self._head + n * inc) self._tail = cmp(self._tail, self._head + inc)
return self return self
end end
@@ -730,11 +729,11 @@ end
--- ---
---@return any ---@return any
function Iter:pop() function Iter:pop()
error('pop() requires a list-like table') error('pop() requires an array-like table')
end end
--- @nodoc --- @nodoc
function ListIter:pop() function ArrayIter:pop()
if self._head ~= self._tail then if self._head ~= self._tail then
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and 1 or -1
self._tail = self._tail - inc self._tail = self._tail - inc
@@ -760,11 +759,11 @@ end
--- ---
---@return any ---@return any
function Iter:rpeek() function Iter:rpeek()
error('rpeek() requires a list-like table') error('rpeek() requires an array-like table')
end end
---@nodoc ---@nodoc
function ListIter:rpeek() function ArrayIter:rpeek()
if self._head ~= self._tail then if self._head ~= self._tail then
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and 1 or -1
return self._table[self._tail - inc] return self._table[self._tail - inc]
@@ -793,7 +792,7 @@ function Iter:skip(n)
end end
---@private ---@private
function ListIter:skip(n) function ArrayIter:skip(n)
local inc = self._head < self._tail and n or -n local inc = self._head < self._tail and n or -n
self._head = self._head + inc self._head = self._head + inc
if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then
@@ -818,11 +817,11 @@ end
---@return Iter ---@return Iter
---@diagnostic disable-next-line: unused-local ---@diagnostic disable-next-line: unused-local
function Iter:rskip(n) -- luacheck: no unused args function Iter:rskip(n) -- luacheck: no unused args
error('rskip() requires a list-like table') error('rskip() requires an array-like table')
end end
---@private ---@private
function ListIter:rskip(n) function ArrayIter:rskip(n)
local inc = self._head < self._tail and n or -n local inc = self._head < self._tail and n or -n
self._tail = self._tail - inc self._tail = self._tail - inc
if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then
@@ -870,11 +869,11 @@ end
---@return Iter ---@return Iter
---@diagnostic disable-next-line: unused-local ---@diagnostic disable-next-line: unused-local
function Iter:slice(first, last) -- luacheck: no unused args function Iter:slice(first, last) -- luacheck: no unused args
error('slice() requires a list-like table') error('slice() requires an array-like table')
end end
---@private ---@private
function ListIter:slice(first, last) function ArrayIter:slice(first, last)
return self:skip(math.max(0, first - 1)):rskip(math.max(0, self._tail - last - 1)) return self:skip(math.max(0, first - 1)):rskip(math.max(0, self._tail - last - 1))
end end
@@ -955,7 +954,7 @@ function Iter:last()
end end
---@private ---@private
function ListIter:last() function ArrayIter:last()
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and 1 or -1
local v = self._table[self._tail - inc] local v = self._table[self._tail - inc]
self._head = self._tail self._head = self._tail
@@ -1000,7 +999,7 @@ function Iter:enumerate()
end end
---@private ---@private
function ListIter:enumerate() function ArrayIter:enumerate()
local inc = self._head < self._tail and 1 or -1 local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail - inc, inc do for i = self._head, self._tail - inc, inc do
local v = self._table[i] local v = self._table[i]
@@ -1030,17 +1029,14 @@ function Iter.new(src, ...)
local t = {} local t = {}
-- O(n): scan the source table to decide if it is a list (consecutive integer indices 1…n). -- O(n): scan the source table to decide if it is an array (only positive integer indices).
local count = 0 for k, v in pairs(src) do
for _ in pairs(src) do if type(k) ~= 'number' or k <= 0 or math.floor(k) ~= k then
count = count + 1
local v = src[count]
if v == nil then
return Iter.new(pairs(src)) return Iter.new(pairs(src))
end end
t[count] = v t[#t + 1] = v
end end
return ListIter.new(t) return ArrayIter.new(t)
end end
if type(src) == 'function' then if type(src) == 'function' then
@@ -1068,17 +1064,18 @@ function Iter.new(src, ...)
return it return it
end end
--- Create a new ListIter --- Create a new ArrayIter
--- ---
---@param t table List-like table. Caller guarantees that this table is a valid list. ---@param t table Array-like table. Caller guarantees that this table is a valid array. Can have
--- holes (nil values).
---@return Iter ---@return Iter
---@private ---@private
function ListIter.new(t) function ArrayIter.new(t)
local it = {} local it = {}
it._table = t it._table = t
it._head = 1 it._head = 1
it._tail = #t + 1 it._tail = #t + 1
setmetatable(it, ListIter) setmetatable(it, ArrayIter)
return it return it
end end

View File

@@ -117,6 +117,9 @@ describe('vim.iter', function()
eq({ { 1, 1 }, { 2, 4 }, { 3, 9 } }, it:totable()) eq({ { 1, 1 }, { 2, 4 }, { 3, 9 } }, it:totable())
end end
-- Holes in array-like tables are removed
eq({ 1, 2, 3 }, vim.iter({ 1, nil, 2, nil, 3 }):totable())
do do
local it = vim.iter(string.gmatch('1,4,lol,17,blah,2,9,3', '%d+')):map(tonumber) local it = vim.iter(string.gmatch('1,4,lol,17,blah,2,9,3', '%d+')):map(tonumber)
eq({ 1, 4, 17, 2, 9, 3 }, it:totable()) eq({ 1, 4, 17, 2, 9, 3 }, it:totable())
@@ -142,7 +145,7 @@ describe('vim.iter', function()
eq({ 3, 2, 1 }, vim.iter({ 1, 2, 3 }):rev():totable()) eq({ 3, 2, 1 }, vim.iter({ 1, 2, 3 }):rev():totable())
local it = vim.iter(string.gmatch('abc', '%w')) local it = vim.iter(string.gmatch('abc', '%w'))
matches('rev%(%) requires a list%-like table', pcall_err(it.rev, it)) matches('rev%(%) requires an array%-like table', pcall_err(it.rev, it))
end) end)
it('skip()', function() it('skip()', function()
@@ -181,7 +184,7 @@ describe('vim.iter', function()
end end
local it = vim.iter(vim.gsplit('a|b|c|d', '|')) local it = vim.iter(vim.gsplit('a|b|c|d', '|'))
matches('rskip%(%) requires a list%-like table', pcall_err(it.rskip, it, 0)) matches('rskip%(%) requires an array%-like table', pcall_err(it.rskip, it, 0))
end) end)
it('slice()', function() it('slice()', function()
@@ -195,7 +198,7 @@ describe('vim.iter', function()
eq({ 8, 9, 10 }, vim.iter(q):slice(8, 11):totable()) eq({ 8, 9, 10 }, vim.iter(q):slice(8, 11):totable())
local it = vim.iter(vim.gsplit('a|b|c|d', '|')) local it = vim.iter(vim.gsplit('a|b|c|d', '|'))
matches('slice%(%) requires a list%-like table', pcall_err(it.slice, it, 1, 3)) matches('slice%(%) requires an array%-like table', pcall_err(it.slice, it, 1, 3))
end) end)
it('nth()', function() it('nth()', function()
@@ -234,7 +237,7 @@ describe('vim.iter', function()
end end
local it = vim.iter(vim.gsplit('a|b|c|d', '|')) local it = vim.iter(vim.gsplit('a|b|c|d', '|'))
matches('rskip%(%) requires a list%-like table', pcall_err(it.nth, it, -1)) matches('rskip%(%) requires an array%-like table', pcall_err(it.nth, it, -1))
end) end)
it('take()', function() it('take()', function()
@@ -356,7 +359,7 @@ describe('vim.iter', function()
do do
local it = vim.iter(vim.gsplit('hi', '')) local it = vim.iter(vim.gsplit('hi', ''))
matches('peek%(%) requires a list%-like table', pcall_err(it.peek, it)) matches('peek%(%) requires an array%-like table', pcall_err(it.peek, it))
end end
end) end)
@@ -417,7 +420,7 @@ describe('vim.iter', function()
do do
local it = vim.iter(vim.gsplit('AbCdE', '')) local it = vim.iter(vim.gsplit('AbCdE', ''))
matches('rfind%(%) requires a list%-like table', pcall_err(it.rfind, it, 'E')) matches('rfind%(%) requires an array%-like table', pcall_err(it.rfind, it, 'E'))
end end
end) end)
@@ -434,7 +437,7 @@ describe('vim.iter', function()
do do
local it = vim.iter(vim.gsplit('hi', '')) local it = vim.iter(vim.gsplit('hi', ''))
matches('pop%(%) requires a list%-like table', pcall_err(it.pop, it)) matches('pop%(%) requires an array%-like table', pcall_err(it.pop, it))
end end
end) end)
@@ -448,7 +451,7 @@ describe('vim.iter', function()
do do
local it = vim.iter(vim.gsplit('hi', '')) local it = vim.iter(vim.gsplit('hi', ''))
matches('rpeek%(%) requires a list%-like table', pcall_err(it.rpeek, it)) matches('rpeek%(%) requires an array%-like table', pcall_err(it.rpeek, it))
end end
end) end)
@@ -482,18 +485,20 @@ describe('vim.iter', function()
local m = { a = 1, b = { 2, 3 }, d = { 4 } } local m = { a = 1, b = { 2, 3 }, d = { 4 } }
local it = vim.iter(m) local it = vim.iter(m)
local flat_err = 'flatten%(%) requires a list%-like table' local flat_err = 'flatten%(%) requires an array%-like table'
matches(flat_err, pcall_err(it.flatten, it)) matches(flat_err, pcall_err(it.flatten, it))
-- cases from the documentation -- cases from the documentation
local simple_example = { 1, { 2 }, { { 3 } } } local simple_example = { 1, { 2 }, { { 3 } } }
eq({ 1, 2, { 3 } }, vim.iter(simple_example):flatten():totable()) eq({ 1, 2, { 3 } }, vim.iter(simple_example):flatten():totable())
local not_list_like = vim.iter({ [2] = 2 }) local not_list_like = { [2] = 2 }
matches(flat_err, pcall_err(not_list_like.flatten, not_list_like)) eq({ 2 }, vim.iter(not_list_like):flatten():totable())
local also_not_list_like = vim.iter({ nil, 2 }) local also_not_list_like = { nil, 2 }
matches(flat_err, pcall_err(not_list_like.flatten, also_not_list_like)) eq({ 2 }, vim.iter(also_not_list_like):flatten():totable())
eq({ 1, 2, 3 }, vim.iter({ nil, { 1, nil, 2 }, 3 }):flatten():totable())
local nested_non_lists = vim.iter({ 1, { { a = 2 } }, { { nil } }, { 3 } }) local nested_non_lists = vim.iter({ 1, { { a = 2 } }, { { nil } }, { 3 } })
eq({ 1, { a = 2 }, { nil }, 3 }, nested_non_lists:flatten():totable()) eq({ 1, { a = 2 }, { nil }, 3 }, nested_non_lists:flatten():totable())