diff --git a/core/math/big/basic.odin b/core/math/big/basic.odin index 5d0a50218..4129f44d5 100644 --- a/core/math/big/basic.odin +++ b/core/math/big/basic.odin @@ -773,8 +773,43 @@ int_factorial :: proc(res: ^Int, n: DIGIT) -> (err: Error) { } factorial :: proc { int_factorial, }; +/* + Number of ways to choose `k` items from `n` items. + Also known as the binomial coefficient. + TODO: Speed up. + Could be done faster by reusing code from factorial and reusing the common "prefix" results for n!, k! and n-k! + We know that n >= k, otherwise we early out with res = 0. + + So: + n-k, keep result + n, start from previous result + k, start from previous result + +*/ +int_choose_digit :: proc(res: ^Int, n, k: DIGIT) -> (err: Error) { + if res == nil { return .Invalid_Pointer; } + if err = clear_if_uninitialized(res); err != .None { return err; } + + if k > n { return zero(res); } + + /* + res = n! / (k! * (n - k)!) + */ + n_fac, k_fac, n_minus_k_fac := &Int{}, &Int{}, &Int{}; + defer destroy(n_fac, k_fac, n_minus_k_fac); + + if err = factorial(n_minus_k_fac, n - k); err != .None { return err; } + if err = factorial(k_fac, k); err != .None { return err; } + if err = mul(k_fac, k_fac, n_minus_k_fac); err != .None { return err; } + + if err = factorial(n_fac, n); err != .None { return err; } + if err = div(res, n_fac, k_fac); err != .None { return err; } + + return err; +} +choose :: proc { int_choose_digit, }; /* ========================== diff --git a/core/math/big/example.odin b/core/math/big/example.odin index 5452dd711..14531a184 100644 --- a/core/math/big/example.odin +++ b/core/math/big/example.odin @@ -54,7 +54,7 @@ print_timings :: proc() { case avg < time.Millisecond: avg_s = fmt.tprintf("%v µs", time.duration_microseconds(avg)); case: - avg_s = fmt.tprintf("%v", time.duration_milliseconds(avg)); + avg_s = fmt.tprintf("%v ms", time.duration_milliseconds(avg)); } total_s: string; @@ -64,7 +64,7 @@ print_timings :: proc() { case v.t < time.Millisecond: total_s = fmt.tprintf("%v µs", time.duration_microseconds(v.t)); case: - total_s = fmt.tprintf("%v", time.duration_milliseconds(v.t)); + total_s = fmt.tprintf("%v ms", time.duration_milliseconds(v.t)); } fmt.printf("\t%v: %s (avg), %s (total, %v calls)\n", i, avg_s, total_s, v.c); @@ -76,6 +76,7 @@ Category :: enum { itoa, atoi, factorial, + choose, lsb, ctz, }; @@ -114,30 +115,11 @@ demo :: proc() { a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}; defer destroy(a, b, c, d, e, f); - set(a, 125); - set(b, 75); - - err = gcd_lcm(c, d, a, b); - fmt.printf("gcd_lcm("); - print("a =", a, 10, false, true, false); - print(", b =", b, 10, false, true, false); - print("), gcd =", c, 10, false, true, false); - print(", lcm =", d, 10, false, true, false); - fmt.printf(" (err = %v)\n", err); - - err = gcd(c, a, b); - fmt.printf("gcd("); - print("a =", a, 10, false, true, false); - print(", b =", b, 10, false, true, false); - print(") =", c, 10, false, true, false); - fmt.printf(" (err = %v)\n", err); - - err = lcm(c, a, b); - fmt.printf("lcm("); - print("a =", a, 10, false, true, false); - print(", b =", b, 10, false, true, false); - print(") =", c, 10, false, true, false); - fmt.printf(" (err = %v)\n", err); + s := time.tick_now(); + err = choose(a, 65535, 255); + Timings[.choose].t += time.tick_since(s); Timings[.choose].c += 1; + print("choose", a); + fmt.println(err); } main :: proc() {