From b4a29844e96756043933d38cf0e71cc86fe27f65 Mon Sep 17 00:00:00 2001
From: Jeroen van Rijn
Date: Fri, 23 Jul 2021 19:04:55 +0200
Subject: [PATCH] big: Add multiplication.

---
 core/math/big/basic.odin   | 239 ++++++++++++++++++++++++++++++-------
 core/math/big/example.odin |  16 ++-
 2 files changed, 205 insertions(+), 50 deletions(-)

diff --git a/core/math/big/basic.odin b/core/math/big/basic.odin
index b16635849..2e04c3c44 100644
--- a/core/math/big/basic.odin
+++ b/core/math/big/basic.odin
@@ -467,13 +467,8 @@ shl1 :: double;
 	remainder = numerator % (1 << bits)
 */
 int_mod_bits :: proc(remainder, numerator: ^Int, bits: int) -> (err: Error) {
-	remainder := remainder; numerator := numerator;
-	if err = clear_if_uninitialized(remainder); err != .None {
-		return err;
-	}
-	if err = clear_if_uninitialized(numerator); err != .None {
-		return err;
-	}
+	if err = clear_if_uninitialized(remainder); err != .None { return err; }
+	if err = clear_if_uninitialized(numerator); err != .None { return err; }
 
 	if bits < 0 { return .Invalid_Argument; }
 	if bits == 0 { return zero(remainder); }
@@ -505,6 +500,160 @@ int_mod_bits :: proc(remainder, numerator: ^Int, bits: int) -> (err: Error) {
 }
 mod_bits :: proc { int_mod_bits, };
 
+/*
+	Multiply `src` by a single DIGIT `multiplier` and store the product in `dest`; `dest` takes the sign of `src`.
+*/
+int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT) -> (err: Error) {
+	if err = clear_if_uninitialized(src ); err != .None { return err; }
+	if err = clear_if_uninitialized(dest); err != .None { return err; }
+
+	if multiplier == 0 {
+		return zero(dest);
+	}
+	if multiplier == 1 {
+		return copy(dest, src);
+	}
+
+	/*
+		Power of two? Then the product is just `src` doubled or shifted left.
+	*/
+	if multiplier == 2 {
+		return double(dest, src);
+	}
+	if is_power_of_two(int(multiplier)) {
+		ix: int;
+		if ix, err = log_n(multiplier, 2); err != .None { return err; }
+		return shl(dest, src, ix);
+	}
+
+	/*
+		Ensure `dest` is big enough to hold `src` * `multiplier`.
+	*/
+	if err = grow(dest, max(src.used + 1, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
+
+	/*
+		Save the original used count.
+	*/
+	old_used := dest.used;
+	/*
+		Set the sign.
+	*/
+	dest.sign = src.sign;
+	/*
+		Set up carry.
+	*/
+	carry := _WORD(0);
+	/*
+		Compute columns.
+	*/
+	ix := 0;
+	for ; ix < src.used; ix += 1 {
+		/*
+			Compute product and carry sum for this term.
+		*/
+		product := carry + _WORD(src.digit[ix]) * _WORD(multiplier);
+		/*
+			Mask off higher bits to get a single DIGIT.
+		*/
+		dest.digit[ix] = DIGIT(product & _WORD(_MASK));
+		/*
+			Send carry into next iteration.
+		*/
+		carry = product >> _DIGIT_BITS;
+	}
+
+	/*
+		Store final carry [if any] and increment used.
+	*/
+	dest.digit[ix] = DIGIT(carry);
+	dest.used = src.used + 1;
+
+	/*
+		Zero the `zero_count` digits that were in use before but now lie at or past the new `used` count.
+	*/
+	zero_count := old_used - dest.used;
+	if zero_count > 0 {
+		mem.zero_slice(dest.digit[dest.used:][:zero_count]);
+	}
+	return clamp(dest);
+}
+
+/*
+	High level multiplication: `dest` = `src` * `multiplier` (handles sign).
+*/
+int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
+	if err = clear_if_uninitialized(src); err != .None { return err; }
+	if err = clear_if_uninitialized(dest); err != .None { return err; }
+	if err = clear_if_uninitialized(multiplier); err != .None { return err; }
+
+	/*
+		Early out if `multiplier` is zero: set `dest` to zero.
+	*/
+	if z, _ := is_zero(multiplier); z {
+		return zero(dest);
+	}
+
+	min_used := min(src.used, multiplier.used);
+	max_used := max(src.used, multiplier.used);
+	digits := src.used + multiplier.used + 1;
+	neg := src.sign != multiplier.sign;
+
+	if false && src == multiplier {
+		/*
+			Do we need to square?
+		*/
+		if false && src.used >= _SQR_TOOM_CUTOFF {
+			/* Use Toom-Cook? */
+			// err = s_mp_sqr_toom(a, c);
+		} else if false && src.used >= _SQR_KARATSUBA_CUTOFF {
+			/* Karatsuba? */
+			// err = s_mp_sqr_karatsuba(a, c);
+		} else if false && ((src.used * 2) + 1) < _WARRAY &&
+		                   src.used < (_MAX_COMBA / 2) {
+			/* Fast comba? */
+			// err = s_mp_sqr_comba(a, c);
+		} else {
+			// err = s_mp_sqr(a, c);
+		}
+	} else {
+		/*
+			Can we use the balance method? Check sizes.
+			* The smaller one needs to be larger than the Karatsuba cut-off.
+			* The bigger one needs to be at least about one `_MUL_KARATSUBA_CUTOFF` bigger
+			* to make some sense, but it depends on architecture, OS, position of the
+			* stars... so YMMV.
+			* Using it to cut the input into slices small enough for _mul_comba
+			* was actually slower on the author's machine, but YMMV.
+		*/
+		if false && min_used >= _MUL_KARATSUBA_CUTOFF &&
+		            max_used / 2 >= _MUL_KARATSUBA_CUTOFF &&
+			/*
+				Not much effect was observed below a ratio of 1:2, but again: YMMV.
+			*/
+			max_used >= 2 * min_used {
+			// err = s_mp_mul_balance(a,b,c);
+		} else if false && min_used >= _MUL_TOOM_CUTOFF {
+			// err = s_mp_mul_toom(a, b, c);
+		} else if false && min_used >= _MUL_KARATSUBA_CUTOFF {
+			// err = s_mp_mul_karatsuba(a, b, c);
+		} else if false && digits < _WARRAY && min_used <= _MAX_COMBA {
+			/*
+				Can we use the fast multiplier?
+				* The fast multiplier can be used if the output will
+				* have less than `_WARRAY` digits and the number of
+				* digits won't affect carry propagation
+			*/
+			// err = s_mp_mul_comba(a, b, c, digs);
+		} else {
+			err = _int_mul(dest, src, multiplier, digits);
+		}
+	}
+	dest.sign = .Negative if dest.used > 0 && neg else .Zero_or_Positive;
+	return err;
+}
+
+mul :: proc { int_mul, int_mul_digit, };
+
 /*
 	==========================
 	Low-level routines
@@ -688,46 +837,54 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
 		}
 	}
 
-	if err = grow(dest, digits); err != .None { return err; }
-	dest.used = digits;
+	/*
+		Set up temporary output `Int`, which we'll swap for `dest` when done.
+	*/
+
+	t := &Int{};
+
+	if err = grow(t, max(digits, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
+	t.used = digits;
 
 	/*
+		Compute the digits of the product directly.
+	*/
+	pa := a.used;
+	for ix := 0; ix < pa; ix += 1 {
+		/*
+			Limit ourselves to `digits` DIGITs of output.
+		*/
+		pb := min(b.used, digits - ix);
+		carry := DIGIT(0);
+		iy := 0;
+		/*
+			Compute the column of the output and propagate the carry.
+		*/
+		for iy = 0; iy < pb; iy += 1 {
+			/*
+				Compute the column as a _WORD; every operand is widened first, as in `int_mul_digit`.
+			*/
+			column := _WORD(t.digit[ix + iy]) + _WORD(a.digit[ix]) * _WORD(b.digit[iy]) + _WORD(carry);
 
-	/* compute the digits of the product directly */
-	pa = a->used;
-	for (ix = 0; ix < pa; ix++) {
-		int iy, pb;
-		mp_digit u = 0;
+			/*
+				The new column is the lower part of the result.
+			*/
+			t.digit[ix + iy] = DIGIT(column & _WORD(_MASK));
 
-		/* limit ourselves to making digs digits of output */
-		pb = MP_MIN(b->used, digs - ix);
-
-		/* compute the columns of the output and propagate the carry */
-		for (iy = 0; iy < pb; iy++) {
-			/* compute the column as a mp_word */
-			mp_word r = (mp_word)t.dp[ix + iy] +
-			            ((mp_word)a->dp[ix] * (mp_word)b->dp[iy]) +
-			            (mp_word)u;
-
-			/* the new column is the lower part of the result */
-			t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
-
-			/* get the carry word from the result */
-			u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
+			/*
+				Get the carry word from the result.
+			*/
+			carry = DIGIT(column >> _DIGIT_BITS);
 		}
-		/* set carry if it is placed below digs */
-		if ((ix + iy) < digs) {
-			t.dp[ix + pb] = u;
+		/*
+			Set carry if it is placed below `digits`.
+		*/
+		if ix + iy < digits {
+			t.digit[ix + pb] = carry;
 		}
 	}
 
-	mp_clamp(&t);
-	mp_exch(&t, c);
-
-	mp_clear(&t);
-	return MP_OKAY;
-}
-
-*/
-	return .None;
+	swap(dest, t);
+	destroy(t);
+	return clamp(dest);
 }
\ No newline at end of file
diff --git a/core/math/big/example.odin b/core/math/big/example.odin
index 6da1a292c..e6eddf826 100644
--- a/core/math/big/example.odin
+++ b/core/math/big/example.odin
@@ -57,17 +57,15 @@ demo :: proc() {
 	a, b, c := &Int{}, &Int{}, &Int{};
 	defer destroy(a, b, c);
 
-	err = set(a, -512);
-	err = set(b, 1024);
+	err = set(a, -1024);
+	err = set(b, -1024);
 
-	print("a", a, 16);
-	print("b", b, 16);
+	print("a", a, 10);
+	print("b", b, 10);
 
-	fmt.println("--- swap ---");
-	foo(a, b);
-
-	print("a", a, 16);
-	print("b", b, 16);
+	fmt.println("--- mul ---");
+	mul(c, a, b);
+	print("c", c, 10);
 }