diff --git a/lib/pure/collections/sequtils.nim b/lib/pure/collections/sequtils.nim index b2f72ee146..f5db9d3fab 100644 --- a/lib/pure/collections/sequtils.nim +++ b/lib/pure/collections/sequtils.nim @@ -88,6 +88,71 @@ proc zip*[S, T](seq1: seq[S], seq2: seq[T]): seq[tuple[a: S, b: T]] = newSeq(result, m) for i in 0 .. m-1: result[i] = (seq1[i], seq2[i]) +proc distribute*[T](s: seq[T], num: int, spread = true): seq[seq[T]] = + ## Splits and distributes a sequence `s` into `num` sub sequences. + ## + ## Returns a sequence of `num` sequences. For some input values this is the + ## inverse of the `concat <#concat>`_ proc. The proc will assert in debug + ## builds if `s` is nil or `num` is less than one, and will likely crash on + ## release builds. The input sequence `s` can be empty, which will produce + ## `num` empty sequences. + ## + ## If `spread` is false and the length of `s` is not a multiple of `num`, the + ## proc will max out the first sub sequences with ``1 + len(s) div num`` + ## entries, leaving the remainder of elements to the last sequence. + ## + ## On the other hand, if `spread` is true, the proc will distribute evenly + ## the remainder of the division across all sequences, which makes the result + ## more suited to multithreading where you are passing equal sized work units + ## to a thread pool and want to maximize core usage. + ## + ## Example: + ## + ## .. code-block:: nimrod + ## let numbers = @[1, 2, 3, 4, 5, 6, 7] + ## assert numbers.distribute(3) == @[@[1, 2, 3], @[4, 5], @[6, 7]] + ## assert numbers.distribute(3, false) == @[@[1, 2, 3], @[4, 5, 6], @[7]] + ## assert numbers.distribute(6)[0] == @[1, 2] + ## assert numbers.distribute(6)[5] == @[7] + assert(not s.isNil, "`s` can't be nil") + assert(num > 0, "`num` has to be greater than zero") + if num < 2: + result = @[s] + return + + # Create the result and calculate the stride size and the remainder if any. + result = newSeq[seq[T]](num) + var + stride = s.len div num + first = 0 + last = 0 + extra = s.len mod num + + if extra == 0 or spread == false: + # Use an algorithm which overcounts the stride and minimizes reading limits. + if extra > 0: inc(stride) + + for i in 0 .. 0: + extra -= 1 + inc(last) + + result[i] = newSeq[T]() + for g in first ..