mirror of
https://github.com/nim-lang/Nim.git
synced 2025-12-28 17:04:41 +00:00
Add ability to sample elements from openArray according to a weight array (#10072)
* Add the ability to sample elements from an openArray according to a parallel array of weights/unnormalized probabilities (any sort of histogram, basically). Also add a non-thread safe version for convenience. * Address Araq comments on https://github.com/nim-lang/Nim/pull/10072 * import at top of file and space after '#'. * Put in a check for non-zero total weight. * Clarify constraint on `w`. * Rename `rand(openArray[T])` to `sample(openArray[T])` to `sample`, deprecating old name and name new (openArray[T], openArray[U]) variants `sample`. * Rename caller-provided state version of rand(openArray[T]) and also clean up doc comments. * Add test for new non-uniform array sampler. 3 sd bound makes it 99% likely that it will still pass in the future if the random number generator changes. We cannot both have a tight bound to check distribution *and* loose check to ensure resilience to RNG changes. (We cannot *guarantee* resilience, anyway. There's always a small chance any test hits a legitimate random fluctuation.)
This commit is contained in:
@@ -14,6 +14,8 @@
|
||||
##
|
||||
## **Do not use this module for cryptographic purposes!**
|
||||
|
||||
import algorithm #For upperBound
|
||||
|
||||
include "system/inclrtl"
|
||||
{.push debugger:off.}
|
||||
|
||||
@@ -155,14 +157,45 @@ proc rand*[T](x: HSlice[T, T]): T =
|
||||
## For a slice `a .. b` returns a value in the range `a .. b`.
|
||||
result = rand(state, x)
|
||||
|
||||
proc rand*[T](r: var Rand; a: openArray[T]): T =
|
||||
proc rand*[T](r: var Rand; a: openArray[T]): T {.deprecated.} =
|
||||
## returns a random element from the openarray `a`.
|
||||
## **Deprecated since v0.20.0:** use ``sample`` instead.
|
||||
result = a[rand(r, a.low..a.high)]
|
||||
|
||||
proc rand*[T](a: openArray[T]): T =
|
||||
proc rand*[T](a: openArray[T]): T {.deprecated.} =
|
||||
## returns a random element from the openarray `a`.
|
||||
## **Deprecated since v0.20.0:** use ``sample`` instead.
|
||||
result = a[rand(a.low..a.high)]
|
||||
|
||||
proc sample*[T](r: var Rand; a: openArray[T]): T =
|
||||
## returns a random element from openArray ``a`` using state in ``r``.
|
||||
result = a[r.rand(a.low..a.high)]
|
||||
|
||||
proc sample*[T](a: openArray[T]): T =
|
||||
## returns a random element from openArray ``a`` using non-thread-safe state.
|
||||
result = a[rand(a.low..a.high)]
|
||||
|
||||
proc sample*[T, U](r: var Rand; a: openArray[T], w: openArray[U], n=1): seq[T] =
|
||||
## Return a sample (with replacement) of size ``n`` from elements of ``a``
|
||||
## according to convertible-to-``float``, not necessarily normalized, and
|
||||
## non-negative weights ``w``. Uses state in ``r``. Must have sum ``w > 0.0``.
|
||||
assert(w.len == a.len)
|
||||
var cdf = newSeq[float](a.len) # The *unnormalized* CDF
|
||||
var tot = 0.0 # Unnormalized is fine if we sample up to tot
|
||||
for i, w in w:
|
||||
assert(w >= 0)
|
||||
tot += float(w)
|
||||
cdf[i] = tot
|
||||
assert(tot > 0.0) # Need at least one non-zero weight
|
||||
for i in 0 ..< n:
|
||||
result.add(a[cdf.upperBound(r.rand(tot))])
|
||||
|
||||
proc sample*[T, U](a: openArray[T], w: openArray[U], n=1): seq[T] =
|
||||
## Return a sample (with replacement) of size ``n`` from elements of ``a``
|
||||
## according to convertible-to-``float``, not necessarily normalized, and
|
||||
## non-negative weights ``w``. Uses default non-thread-safe state.
|
||||
state.sample(a, w, n)
|
||||
|
||||
|
||||
proc initRand*(seed: int64): Rand =
|
||||
## Creates a new ``Rand`` state from ``seed``.
|
||||
|
||||
@@ -4,6 +4,8 @@ discard """
|
||||
|
||||
[Suite] random float
|
||||
|
||||
[Suite] random sample
|
||||
|
||||
[Suite] ^
|
||||
|
||||
'''
|
||||
@@ -11,7 +13,7 @@ discard """
|
||||
|
||||
import math, random, os
|
||||
import unittest
|
||||
import sets
|
||||
import sets, tables
|
||||
|
||||
suite "random int":
|
||||
test "there might be some randomness":
|
||||
@@ -72,6 +74,30 @@ suite "random float":
|
||||
var rand2:float = random(1000000.0)
|
||||
check rand1 != rand2
|
||||
|
||||
suite "random sample":
|
||||
test "non-uniform array sample":
|
||||
let values = [ 10, 20, 30, 40, 50 ] # values
|
||||
let weight = [ 4, 3, 2, 1, 0 ] # weights aka unnormalized probabilities
|
||||
let weightSum = 10.0 # sum of weights
|
||||
var histo = initCountTable[int]()
|
||||
for v in sample(values, weight, 5000):
|
||||
histo.inc(v)
|
||||
check histo.len == 4 # number of non-zero in `weight`
|
||||
# Any one bin is a binomial random var for n samples, each with prob p of
|
||||
# adding a count to k; E[k]=p*n, Var k=p*(1-p)*n, approximately Normal for
|
||||
# big n. So, P(abs(k - p*n)/sqrt(p*(1-p)*n))>3.0) =~ 0.0027, while
|
||||
# P(wholeTestFails) =~ 1 - P(binPasses)^4 =~ 1 - (1-0.0027)^4 =~ 0.01.
|
||||
for i, w in weight:
|
||||
if w == 0:
|
||||
check values[i] notin histo
|
||||
continue
|
||||
let p = float(w) / float(weightSum)
|
||||
let n = 5000.0
|
||||
let expected = p * n
|
||||
let stdDev = sqrt(n * p * (1.0 - p))
|
||||
check abs(float(histo[values[i]]) - expected) <= 3.0 * stdDev
|
||||
|
||||
|
||||
suite "^":
|
||||
test "compiles for valid types":
|
||||
check: compiles(5 ^ 2)
|
||||
|
||||
Reference in New Issue
Block a user