feat(runtime): accept predicates in take and skip (#34657)

Make `vim.iter():take()` and `vim.iter():skip()`
optionally accept predicates to enable takewhile
and skipwhile patterns used in functional
programming.
This commit is contained in:
Mart-Mihkel Aun
2025-07-03 16:12:24 +03:00
committed by GitHub
parent 715c28d67f
commit f01419f3d5
4 changed files with 149 additions and 13 deletions

View File

@@ -4545,17 +4545,24 @@ Iter:rskip({n}) *Iter:rskip()*
(`Iter`) (`Iter`)
Iter:skip({n}) *Iter:skip()* Iter:skip({n}) *Iter:skip()*
Skips `n` values of an iterator pipeline. Skips `n` values of an iterator pipeline, or all values satisfying a
predicate of a |list-iterator|.
Example: >lua Example: >lua
local it = vim.iter({ 3, 6, 9, 12 }):skip(2) local it = vim.iter({ 3, 6, 9, 12 }):skip(2)
it:next() it:next()
-- 9 -- 9
local function pred(x) return x < 10 end
local it2 = vim.iter({ 3, 6, 9, 12 }):skip(pred)
it2:next()
-- 12
< <
Parameters: ~ Parameters: ~
• {n} (`number`) Number of values to skip. • {n} (`integer|fun(...):boolean`) Number of values to skip or a
predicate.
Return: ~ Return: ~
(`Iter`) (`Iter`)
@@ -4573,7 +4580,8 @@ Iter:slice({first}, {last}) *Iter:slice()*
(`Iter`) (`Iter`)
Iter:take({n}) *Iter:take()* Iter:take({n}) *Iter:take()*
Transforms an iterator to yield only the first n values. Transforms an iterator to yield only the first n values, or all values
satisfying a predicate.
Example: >lua Example: >lua
local it = vim.iter({ 1, 2, 3, 4 }):take(2) local it = vim.iter({ 1, 2, 3, 4 }):take(2)
@@ -4583,10 +4591,18 @@ Iter:take({n}) *Iter:take()*
-- 2 -- 2
it:next() it:next()
-- nil -- nil
local function pred(x) return x < 2 end
local it2 = vim.iter({ 1, 2, 3, 4 }):take(pred)
it2:next()
-- 1
it2:next()
-- nil
< <
Parameters: ~ Parameters: ~
• {n} (`integer`) • {n} (`integer|fun(...):boolean`) Number of values to take or a
predicate.
Return: ~ Return: ~
(`Iter`) (`Iter`)

View File

@@ -201,6 +201,7 @@ LUA
• |vim.fs.root()| can define "equal priority" via nested lists. • |vim.fs.root()| can define "equal priority" via nested lists.
• |vim.version.range()| output can be converted to human-readable string with |tostring()|. • |vim.version.range()| output can be converted to human-readable string with |tostring()|.
• |vim.version.intersect()| computes intersection of two version ranges. • |vim.version.intersect()| computes intersection of two version ranges.
• |Iter:take()| and |Iter:skip()| now optionally accept predicates.
OPTIONS OPTIONS

View File

@@ -681,7 +681,8 @@ function ArrayIter:rfind(f)
self._head = self._tail self._head = self._tail
end end
--- Transforms an iterator to yield only the first n values. --- Transforms an iterator to yield only the first n values, or all values
--- satisfying a predicate.
--- ---
--- Example: --- Example:
--- ---
@@ -693,24 +694,56 @@ end
--- -- 2 --- -- 2
--- it:next() --- it:next()
--- -- nil --- -- nil
---
--- local function pred(x) return x < 2 end
--- local it2 = vim.iter({ 1, 2, 3, 4 }):take(pred)
--- it2:next()
--- -- 1
--- it2:next()
--- -- nil
--- ``` --- ```
--- ---
---@param n integer ---@param n integer|fun(...):boolean Number of values to take or a predicate.
---@return Iter ---@return Iter
function Iter:take(n) function Iter:take(n)
local next = self.next
local i = 0 local i = 0
self.next = function() local f = n
if i < n then if type(n) ~= 'function' then
i = i + 1 f = function()
return next(self) return i < n
end end
end end
local stop = false
local function fn(...)
if not stop and select(1, ...) ~= nil and f(...) then
i = i + 1
return ...
else
stop = true
end
end
local next = self.next
self.next = function()
return fn(next(self))
end
return self return self
end end
---@private ---@private
function ArrayIter:take(n) function ArrayIter:take(n)
if type(n) == 'function' then
local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail, inc do
if not n(unpack(self._table[i])) then
self._tail = i
break
end
end
return self
end
local inc = self._head < self._tail and n or -n local inc = self._head < self._tail and n or -n
local cmp = self._head < self._tail and math.min or math.max local cmp = self._head < self._tail and math.min or math.max
self._tail = cmp(self._tail, self._head + inc) self._tail = cmp(self._tail, self._head + inc)
@@ -772,7 +805,8 @@ function ArrayIter:rpeek()
end end
end end
--- Skips `n` values of an iterator pipeline. --- Skips `n` values of an iterator pipeline, or all values satisfying a
--- predicate of a |list-iterator|.
--- ---
--- Example: --- Example:
--- ---
@@ -782,11 +816,20 @@ end
--- it:next() --- it:next()
--- -- 9 --- -- 9
--- ---
--- local function pred(x) return x < 10 end
--- local it2 = vim.iter({ 3, 6, 9, 12 }):skip(pred)
--- it2:next()
--- -- 12
--- ``` --- ```
--- ---
---@param n number Number of values to skip. ---@param n integer|fun(...):boolean Number of values to skip or a predicate.
---@return Iter ---@return Iter
function Iter:skip(n) function Iter:skip(n)
if type(n) == 'function' then
-- We would need to evaluate the perdicate without advancing iterator
error('skip() with predicate requires an array-like table')
end
for _ = 1, n do for _ = 1, n do
local _ = self:next() local _ = self:next()
end end
@@ -795,6 +838,16 @@ end
---@private ---@private
function ArrayIter:skip(n) function ArrayIter:skip(n)
if type(n) == 'function' then
local inc = self._head < self._tail and 1 or -1
local i = self._head
while n(unpack(self:peek())) and i ~= self._tail do
self:next()
i = i + inc
end
return self
end
local inc = self._head < self._tail and n or -n local inc = self._head < self._tail and n or -n
self._head = self._head + inc self._head = self._head + inc
if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then

View File

@@ -159,6 +159,30 @@ describe('vim.iter', function()
eq({}, vim.iter(q):skip(#q + 1):totable()) eq({}, vim.iter(q):skip(#q + 1):totable())
end end
do
local function wrong()
return false
end
local function correct()
return true
end
local q = { 4, 3, 2, 1 }
eq({ 4, 3, 2, 1 }, vim.iter(q):skip(wrong):totable())
eq(
{ 2, 1 },
vim
.iter(q)
:skip(function(x)
return x > 2
end)
:totable()
)
eq({}, vim.iter(q):skip(correct):totable())
end
do do
local function skip(n) local function skip(n)
return vim.iter(vim.gsplit('a|b|c|d', '|')):skip(n):totable() return vim.iter(vim.gsplit('a|b|c|d', '|')):skip(n):totable()
@@ -241,6 +265,14 @@ describe('vim.iter', function()
end) end)
it('take()', function() it('take()', function()
local function correct()
return true
end
local function wrong()
return false
end
do do
local q = { 4, 3, 2, 1 } local q = { 4, 3, 2, 1 }
eq({}, vim.iter(q):take(0):totable()) eq({}, vim.iter(q):take(0):totable())
@@ -251,6 +283,22 @@ describe('vim.iter', function()
eq({ 4, 3, 2, 1 }, vim.iter(q):take(5):totable()) eq({ 4, 3, 2, 1 }, vim.iter(q):take(5):totable())
end end
do
local q = { 4, 3, 2, 1 }
eq({}, vim.iter(q):take(wrong):totable())
eq(
{ 4, 3 },
vim
.iter(q)
:take(function(x)
return x > 2
end)
:totable()
)
eq({ 4, 3, 2, 1 }, vim.iter(q):take(correct):totable())
end
do do
local q = { 4, 3, 2, 1 } local q = { 4, 3, 2, 1 }
eq({ 1, 2, 3 }, vim.iter(q):rev():take(3):totable()) eq({ 1, 2, 3 }, vim.iter(q):rev():take(3):totable())
@@ -271,6 +319,24 @@ describe('vim.iter', function()
-- non-array iterators are consumed by take() -- non-array iterators are consumed by take()
eq({}, it:take(2):totable()) eq({}, it:take(2):totable())
end end
do
eq({ 'a', 'b', 'c', 'd' }, vim.iter(vim.gsplit('a|b|c|d', '|')):take(correct):totable())
eq(
{ 'a', 'b', 'c' },
vim
.iter(vim.gsplit('a|b|c|d', '|'))
:enumerate()
:take(function(i, x)
return i < 3 or x == 'c'
end)
:map(function(_, x)
return x
end)
:totable()
)
eq({}, vim.iter(vim.gsplit('a|b|c|d', '|')):take(wrong):totable())
end
end) end)
it('any()', function() it('any()', function()