From 8b49bbb0fca317a02cf6f14fa5c7c8784ea4076d Mon Sep 17 00:00:00 2001 From: Jeroen van Rijn Date: Mon, 16 Aug 2021 16:10:10 +0200 Subject: [PATCH] big: Add `_private_mul_karatsuba`. --- core/math/big/build.bat | 4 +- core/math/big/example.odin | 12 ++-- core/math/big/helpers.odin | 6 +- core/math/big/internal.odin | 14 ++--- core/math/big/private.odin | 106 +++++++++++++++++++++++++++++++++++- 5 files changed, 116 insertions(+), 26 deletions(-) diff --git a/core/math/big/build.bat b/core/math/big/build.bat index eb6f581aa..540907a3a 100644 --- a/core/math/big/build.bat +++ b/core/math/big/build.bat @@ -1,8 +1,8 @@ @echo off -:odin run . -vet +odin run . -vet : -o:size :odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check -define:MATH_BIG_EXE=false && python test.py -fast-tests :odin build . -build-mode:shared -show-timings -o:size -no-bounds-check -define:MATH_BIG_EXE=false && python test.py -fast-tests :odin build . -build-mode:shared -show-timings -o:size -define:MATH_BIG_EXE=false && python test.py -fast-tests -odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check -define:MATH_BIG_EXE=false && python test.py -fast-tests +:odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check -define:MATH_BIG_EXE=false && python test.py -fast-tests :odin build . -build-mode:shared -show-timings -o:speed -define:MATH_BIG_EXE=false && python test.py -fast-tests \ No newline at end of file diff --git a/core/math/big/example.odin b/core/math/big/example.odin index 4fbf44664..2a66251c9 100644 --- a/core/math/big/example.odin +++ b/core/math/big/example.odin @@ -206,16 +206,12 @@ demo :: proc() { a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}; defer destroy(a, b, c, d, e, f); - atoi(a, "12980742146337069150589594264770969721", 10); + power_of_two(a, 312); print("a: ", a, 10, true, true, true); - atoi(b, "4611686018427387904", 10); + power_of_two(b, 314); print("b: ", b, 10, true, true, true); - - if err := internal_divmod(c, d, a, b); err != nil { - fmt.printf("Error: %v\n", err); - } - print("c: ", c); - print("c: ", d); + _private_mul_karatsuba(c, a, b); + print("c: ", c, 10, true, true, true); } main :: proc() { diff --git a/core/math/big/helpers.odin b/core/math/big/helpers.odin index ab686b914..e50579ac0 100644 --- a/core/math/big/helpers.odin +++ b/core/math/big/helpers.odin @@ -432,18 +432,16 @@ int_init_multi :: proc(integers: ..^Int, allocator := context.allocator) -> (err init_multi :: proc { int_init_multi, }; -copy_digits :: proc(dest, src: ^Int, digits: int, allocator := context.allocator) -> (err: Error) { +copy_digits :: proc(dest, src: ^Int, digits: int, offset := int(0), allocator := context.allocator) -> (err: Error) { context.allocator = allocator; - digits := digits; /* Check that `src` is usable and `dest` isn't immutable. */ assert_if_nil(dest, src); #force_inline internal_clear_if_uninitialized(src) or_return; - digits = min(digits, len(src.digit), len(dest.digit)); - return #force_inline internal_copy_digits(dest, src, digits); + return #force_inline internal_copy_digits(dest, src, digits, offset); } /* diff --git a/core/math/big/internal.odin b/core/math/big/internal.odin index 2c988f91e..d5cb03cc4 100644 --- a/core/math/big/internal.odin +++ b/core/math/big/internal.odin @@ -36,8 +36,6 @@ import "core:mem" import "core:intrinsics" import rnd "core:math/rand" -import "core:fmt" - /* Low-level addition, unsigned. Handbook of Applied Cryptography, algorithm 14.7. @@ -651,7 +649,6 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc Fast comba? */ err = #force_inline _private_int_sqr_comba(dest, src); - //err = #force_inline _private_int_sqr(dest, src); } else { err = #force_inline _private_int_sqr(dest, src); } @@ -679,8 +676,8 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc // err = s_mp_mul_balance(a,b,c); } else if false && min_used >= MUL_TOOM_CUTOFF { // err = s_mp_mul_toom(a, b, c); - } else if false && min_used >= MUL_KARATSUBA_CUTOFF { - // err = s_mp_mul_karatsuba(a, b, c); + } else if min_used >= MUL_KARATSUBA_CUTOFF { + err = #force_inline _private_mul_karatsuba(dest, src, multiplier); } else if digits < _WARRAY && min_used <= _MAX_COMBA { /* Can we use the fast multiplier? @@ -1628,16 +1625,13 @@ internal_int_set_from_integer :: proc(dest: ^Int, src: $T, minimize := false, al internal_set :: proc { internal_int_set_from_integer, internal_int_copy }; -internal_copy_digits :: #force_inline proc(dest, src: ^Int, digits: int) -> (err: Error) { +internal_copy_digits :: #force_inline proc(dest, src: ^Int, digits: int, offset := int(0)) -> (err: Error) { #force_inline internal_error_if_immutable(dest) or_return; /* If dest == src, do nothing */ - if (dest == src) { return nil; } - - #force_inline mem.copy_non_overlapping(&dest.digit[0], &src.digit[0], size_of(DIGIT) * digits); - return nil; + return #force_inline _private_copy_digits(dest, src, digits, offset); } /* diff --git a/core/math/big/private.odin b/core/math/big/private.odin index a99d6119f..50a6f9c9c 100644 --- a/core/math/big/private.odin +++ b/core/math/big/private.odin @@ -89,6 +89,108 @@ _private_int_mul :: proc(dest, a, b: ^Int, digits: int, allocator := context.all return internal_clamp(dest); } +/* + product = |a| * |b| using Karatsuba Multiplication using three half size multiplications. + + Let `B` represent the radix [e.g. 2**_DIGIT_BITS] and let `n` represent + half of the number of digits in the min(a,b) + + `a` = `a1` * `B`**`n` + `a0` + `b` = `b`1 * `B`**`n` + `b0` + + Then, a * b => 1b1 * B**2n + ((a1 + a0)(b1 + b0) - (a0b0 + a1b1)) * B + a0b0 + + Note that a1b1 and a0b0 are used twice and only need to be computed once. + So in total three half size (half # of digit) multiplications are performed, + a0b0, a1b1 and (a1+b1)(a0+b0) + + Note that a multiplication of half the digits requires 1/4th the number of + single precision multiplications, so in total after one call 25% of the + single precision multiplications are saved. + + Note also that the call to `internal_mul` can end up back in this function + if the a0, a1, b0, or b1 are above the threshold. + + This is known as divide-and-conquer and leads to the famous O(N**lg(3)) or O(N**1.584) + work which is asymptopically lower than the standard O(N**2) that the + baseline/comba methods use. Generally though, the overhead of this method doesn't pay off + until a certain size is reached, of around 80 used DIGITs. +*/ +_private_mul_karatsuba :: proc(dest, a, b: ^Int, allocator := context.allocator) -> (err: Error) { + context.allocator = allocator; + + x0, x1, y0, y1, t1, x0y0, x1y1 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}; + defer destroy(x0, x1, y0, y1, t1, x0y0, x1y1); + + /* + min # of digits, divided by two. + */ + B := min(a.used, b.used) >> 1; + + /* + Init all the temps. + */ + internal_grow(x0, B) or_return; + internal_grow(x1, a.used - B) or_return; + internal_grow(y0, B) or_return; + internal_grow(y1, b.used - B) or_return; + internal_grow(t1, B * 2) or_return; + internal_grow(x0y0, B * 2) or_return; + internal_grow(x1y1, B * 2) or_return; + + /* + Now shift the digits. + */ + x0.used, y0.used = B, B; + x1.used = a.used - B; + y1.used = b.used - B; + + /* + We copy the digits directly instead of using higher level functions + since we also need to shift the digits. + */ + internal_copy_digits(x0, a, x0.used); + internal_copy_digits(y0, b, y0.used); + internal_copy_digits(x1, a, x1.used, B); + internal_copy_digits(y1, b, y1.used, B); + + /* + Only need to clamp the lower words since by definition the + upper words x1/y1 must have a known number of digits. + */ + clamp(x0); + clamp(y0); + + /* + Now calc the products x0y0 and x1y1, + after this x0 is no longer required, free temp [x0==t2]! + */ + internal_mul(x0y0, x0, y0) or_return; /* x0y0 = x0*y0 */ + internal_mul(x1y1, x1, y1) or_return; /* x1y1 = x1*y1 */ + internal_add(t1, x1, x0) or_return; /* now calc x1+x0 and */ + internal_add(x0, y1, y0) or_return; /* t2 = y1 + y0 */ + internal_mul(t1, t1, x0) or_return; /* t1 = (x1 + x0) * (y1 + y0) */ + + /* + Add x0y0. + */ + internal_add(x0, x0y0, x1y1) or_return; /* t2 = x0y0 + x1y1 */ + internal_sub(t1, t1, x0) or_return; /* t1 = (x1+x0)*(y1+y0) - (x1y1 + x0y0) */ + + /* + shift by B. + */ + internal_shl_digit(t1, B) or_return; /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))< (log: int, err: Error Copies DIGITs from `src` to `dest`. Assumes `src` and `dest` to not be `nil` and have been initialized. */ -_private_copy_digits :: proc(dest, src: ^Int, digits: int) -> (err: Error) { +_private_copy_digits :: proc(dest, src: ^Int, digits: int, offset := int(0)) -> (err: Error) { digits := digits; /* If dest == src, do nothing @@ -1639,7 +1741,7 @@ _private_copy_digits :: proc(dest, src: ^Int, digits: int) -> (err: Error) { } digits = min(digits, len(src.digit), len(dest.digit)); - mem.copy_non_overlapping(&dest.digit[0], &src.digit[0], size_of(DIGIT) * digits); + mem.copy_non_overlapping(&dest.digit[0], &src.digit[offset], size_of(DIGIT) * digits); return nil; }