diff --git a/runtime/doc/lua.txt b/runtime/doc/lua.txt index 94831e0123..3df04c4845 100644 --- a/runtime/doc/lua.txt +++ b/runtime/doc/lua.txt @@ -4545,17 +4545,24 @@ Iter:rskip({n}) *Iter:rskip()* (`Iter`) Iter:skip({n}) *Iter:skip()* - Skips `n` values of an iterator pipeline. + Skips `n` values of an iterator pipeline, or all values satisfying a + predicate of a |list-iterator|. Example: >lua local it = vim.iter({ 3, 6, 9, 12 }):skip(2) it:next() -- 9 + + local function pred(x) return x < 10 end + local it2 = vim.iter({ 3, 6, 9, 12 }):skip(pred) + it2:next() + -- 12 < Parameters: ~ - • {n} (`number`) Number of values to skip. + • {n} (`integer|fun(...):boolean`) Number of values to skip or a + predicate. Return: ~ (`Iter`) @@ -4573,7 +4580,8 @@ Iter:slice({first}, {last}) *Iter:slice()* (`Iter`) Iter:take({n}) *Iter:take()* - Transforms an iterator to yield only the first n values. + Transforms an iterator to yield only the first n values, or all values + satisfying a predicate. Example: >lua local it = vim.iter({ 1, 2, 3, 4 }):take(2) @@ -4583,10 +4591,18 @@ Iter:take({n}) *Iter:take()* -- 2 it:next() -- nil + + local function pred(x) return x < 2 end + local it2 = vim.iter({ 1, 2, 3, 4 }):take(pred) + it2:next() + -- 1 + it2:next() + -- nil < Parameters: ~ - • {n} (`integer`) + • {n} (`integer|fun(...):boolean`) Number of values to take or a + predicate. Return: ~ (`Iter`) diff --git a/runtime/doc/news.txt b/runtime/doc/news.txt index 79ef80187c..a9524eb35b 100644 --- a/runtime/doc/news.txt +++ b/runtime/doc/news.txt @@ -201,6 +201,7 @@ LUA • |vim.fs.root()| can define "equal priority" via nested lists. • |vim.version.range()| output can be converted to human-readable string with |tostring()|. • |vim.version.intersect()| computes intersection of two version ranges. +• |Iter:take()| and |Iter:skip()| now optionally accept predicates. OPTIONS diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua index bdbe2be95a..c1dd20f745 100644 --- a/runtime/lua/vim/iter.lua +++ b/runtime/lua/vim/iter.lua @@ -681,7 +681,8 @@ function ArrayIter:rfind(f) self._head = self._tail end ---- Transforms an iterator to yield only the first n values. +--- Transforms an iterator to yield only the first n values, or all values +--- satisfying a predicate. --- --- Example: --- @@ -693,24 +694,56 @@ end --- -- 2 --- it:next() --- -- nil +--- +--- local function pred(x) return x < 2 end +--- local it2 = vim.iter({ 1, 2, 3, 4 }):take(pred) +--- it2:next() +--- -- 1 +--- it2:next() +--- -- nil --- ``` --- ----@param n integer +---@param n integer|fun(...):boolean Number of values to take or a predicate. ---@return Iter function Iter:take(n) - local next = self.next local i = 0 - self.next = function() - if i < n then - i = i + 1 - return next(self) + local f = n + if type(n) ~= 'function' then + f = function() + return i < n end end + + local stop = false + local function fn(...) + if not stop and select(1, ...) ~= nil and f(...) then + i = i + 1 + return ... + else + stop = true + end + end + + local next = self.next + self.next = function() + return fn(next(self)) + end return self end ---@private function ArrayIter:take(n) + if type(n) == 'function' then + local inc = self._head < self._tail and 1 or -1 + for i = self._head, self._tail, inc do + if not n(unpack(self._table[i])) then + self._tail = i + break + end + end + return self + end + local inc = self._head < self._tail and n or -n local cmp = self._head < self._tail and math.min or math.max self._tail = cmp(self._tail, self._head + inc) @@ -772,7 +805,8 @@ function ArrayIter:rpeek() end end ---- Skips `n` values of an iterator pipeline. +--- Skips `n` values of an iterator pipeline, or all values satisfying a +--- predicate of a |list-iterator|. --- --- Example: --- @@ -782,11 +816,20 @@ end --- it:next() --- -- 9 --- +--- local function pred(x) return x < 10 end +--- local it2 = vim.iter({ 3, 6, 9, 12 }):skip(pred) +--- it2:next() +--- -- 12 --- ``` --- ----@param n number Number of values to skip. +---@param n integer|fun(...):boolean Number of values to skip or a predicate. ---@return Iter function Iter:skip(n) + if type(n) == 'function' then + -- We would need to evaluate the perdicate without advancing iterator + error('skip() with predicate requires an array-like table') + end + for _ = 1, n do local _ = self:next() end @@ -795,6 +838,16 @@ end ---@private function ArrayIter:skip(n) + if type(n) == 'function' then + local inc = self._head < self._tail and 1 or -1 + local i = self._head + while n(unpack(self:peek())) and i ~= self._tail do + self:next() + i = i + inc + end + return self + end + local inc = self._head < self._tail and n or -n self._head = self._head + inc if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then diff --git a/test/functional/lua/iter_spec.lua b/test/functional/lua/iter_spec.lua index 79e92e6a7d..64f8d5df79 100644 --- a/test/functional/lua/iter_spec.lua +++ b/test/functional/lua/iter_spec.lua @@ -159,6 +159,30 @@ describe('vim.iter', function() eq({}, vim.iter(q):skip(#q + 1):totable()) end + do + local function wrong() + return false + end + + local function correct() + return true + end + + local q = { 4, 3, 2, 1 } + + eq({ 4, 3, 2, 1 }, vim.iter(q):skip(wrong):totable()) + eq( + { 2, 1 }, + vim + .iter(q) + :skip(function(x) + return x > 2 + end) + :totable() + ) + eq({}, vim.iter(q):skip(correct):totable()) + end + do local function skip(n) return vim.iter(vim.gsplit('a|b|c|d', '|')):skip(n):totable() @@ -241,6 +265,14 @@ describe('vim.iter', function() end) it('take()', function() + local function correct() + return true + end + + local function wrong() + return false + end + do local q = { 4, 3, 2, 1 } eq({}, vim.iter(q):take(0):totable()) @@ -251,6 +283,22 @@ describe('vim.iter', function() eq({ 4, 3, 2, 1 }, vim.iter(q):take(5):totable()) end + do + local q = { 4, 3, 2, 1 } + + eq({}, vim.iter(q):take(wrong):totable()) + eq( + { 4, 3 }, + vim + .iter(q) + :take(function(x) + return x > 2 + end) + :totable() + ) + eq({ 4, 3, 2, 1 }, vim.iter(q):take(correct):totable()) + end + do local q = { 4, 3, 2, 1 } eq({ 1, 2, 3 }, vim.iter(q):rev():take(3):totable()) @@ -271,6 +319,24 @@ describe('vim.iter', function() -- non-array iterators are consumed by take() eq({}, it:take(2):totable()) end + + do + eq({ 'a', 'b', 'c', 'd' }, vim.iter(vim.gsplit('a|b|c|d', '|')):take(correct):totable()) + eq( + { 'a', 'b', 'c' }, + vim + .iter(vim.gsplit('a|b|c|d', '|')) + :enumerate() + :take(function(i, x) + return i < 3 or x == 'c' + end) + :map(function(_, x) + return x + end) + :totable() + ) + eq({}, vim.iter(vim.gsplit('a|b|c|d', '|')):take(wrong):totable()) + end end) it('any()', function()