diff --git a/core/math/big/basic.odin b/core/math/big/basic.odin index 2e04c3c44..fbb533ddf 100644 --- a/core/math/big/basic.odin +++ b/core/math/big/basic.odin @@ -8,12 +8,11 @@ package big For the theoretical underpinnings, see Knuth's The Art of Computer Programming, Volume 2, section 4.3. The code started out as an idiomatic source port of libTomMath, which is in the public domain, with thanks. - This file contains basic arithmetic operations like `add`, `sub`, `div`, ... + This file contains basic arithmetic operations like `add`, `sub`, `mul`, `div`, ... */ import "core:mem" import "core:intrinsics" -import "core:fmt" /* =========================== @@ -26,15 +25,9 @@ import "core:fmt" */ int_add :: proc(dest, a, b: ^Int) -> (err: Error) { dest := dest; x := a; y := b; - if err = clear_if_uninitialized(a); err != .None { - return err; - } - if err = clear_if_uninitialized(b); err != .None { - return err; - } - if err = clear_if_uninitialized(dest); err != .None { - return err; - } + if err = clear_if_uninitialized(a); err != .None { return err; } + if err = clear_if_uninitialized(b); err != .None { return err; } + if err = clear_if_uninitialized(dest); err != .None { return err; } /* All parameters have been initialized. We can now safely ignore errors from comparison routines. @@ -599,7 +592,7 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) { digits := src.used + multiplier.used + 1; neg := src.sign != multiplier.sign; - if false && src == multiplier { + if src == multiplier { /* Do we need to square? */ @@ -614,7 +607,7 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) { /* Fast comba? */ // err = s_mp_sqr_comba(a, c); } else { - // err = s_mp_sqr(a, c); + err = _int_sqr(dest, src); } } else { /* @@ -646,7 +639,6 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) { */ // err = s_mp_mul_comba(a, b, c, digs); } else { - fmt.println("Hai"); err = _int_mul(dest, src, multiplier, digits); } } @@ -889,4 +881,72 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) { swap(dest, t); destroy(t); return clamp(dest); +} + +/* + Low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 +*/ +_int_sqr :: proc(dest, src: ^Int) -> (err: Error) { + pa := src.used; + + t := &Int{}; ix, iy: int; + /* + Grow `t` to maximum needed size, or `_DEFAULT_DIGIT_COUNT`, whichever is bigger. + */ + if err = grow(t, min((2 * pa) + 1, _DEFAULT_DIGIT_COUNT)); err != .None { return err; } + t.used = (2 * pa) + 1; + + for ix = 0; ix < pa; ix += 1 { + carry := DIGIT(0); + /* + First calculate the digit at 2*ix; calculate double precision result. + */ + r := _WORD(t.digit[ix+ix]) + _WORD(src.digit[ix] * src.digit[ix]); + + /* + Store lower part in result. + */ + t.digit[ix+ix] = DIGIT(r & _WORD(_MASK)); + + /* + Get the carry. + */ + carry = DIGIT(r >> _DIGIT_BITS); + + for iy = ix + 1; iy < pa; iy += 1 { + /* + First calculate the product. + */ + r = _WORD(src.digit[ix]) * _WORD(src.digit[iy]); + + /* Now calculate the double precision result. Nte we use + * addition instead of *2 since it's easier to optimize + */ + r = _WORD(t.digit[ix+iy]) + r + r + _WORD(carry); + + /* + Store lower part. + */ + t.digit[ix+iy] = DIGIT(r & _WORD(_MASK)); + + /* + Get carry. + */ + carry = DIGIT(r >> _DIGIT_BITS); + } + /* + Propagate upwards. + */ + for carry != 0 { + r = _WORD(t.digit[ix+iy]) + _WORD(carry); + t.digit[ix+iy] = DIGIT(r & _WORD(_MASK)); + carry = DIGIT(r >> _WORD(_DIGIT_BITS)); + iy += 1; + } + } + + err = clamp(t); + swap(dest, t); + destroy(t); + return err; } \ No newline at end of file diff --git a/core/math/big/example.odin b/core/math/big/example.odin index e6eddf826..a9dde60f2 100644 --- a/core/math/big/example.odin +++ b/core/math/big/example.odin @@ -64,9 +64,8 @@ demo :: proc() { print("b", b, 10); fmt.println("--- mul ---"); - mul(c, a, b); + mul(c, a, a); print("c", c, 10); - } main :: proc() { diff --git a/core/math/big/helpers.odin b/core/math/big/helpers.odin index e36f0614b..20193e061 100644 --- a/core/math/big/helpers.odin +++ b/core/math/big/helpers.odin @@ -57,9 +57,7 @@ set :: proc { int_set_from_integer, int_copy }; Copy one `Int` to another. */ int_copy :: proc(dest, src: ^Int, allocator := context.allocator) -> (err: Error) { - if err = clear_if_uninitialized(src); err != .None { - return err; - } + if err = clear_if_uninitialized(src); err != .None { return err; } /* If dest == src, do nothing */ @@ -535,6 +533,22 @@ clear_if_uninitialized :: proc(dest: ^Int, minimize := false) -> (err: Error) { return .None; } +_copy_digits :: proc(dest, src: ^Int, digits: int) -> (err: Error) { + digits := digits; + if err = clear_if_uninitialized(src); err != .None { return err; } + if err = clear_if_uninitialized(dest); err != .None { return err; } + /* + If dest == src, do nothing + */ + if (dest == src) { + return .None; + } + + digits = min(digits, len(src.digit), len(dest.digit)); + mem.copy_non_overlapping(&dest.digit[0], &src.digit[0], size_of(DIGIT) * digits); + return .None; +} + /* Trim unused digits.