big: Add multiplication.

This commit is contained in:
Jeroen van Rijn
2021-07-23 19:04:55 +02:00
parent 0254057f1b
commit b4a29844e9
2 changed files with 207 additions and 50 deletions

View File

@@ -13,6 +13,7 @@ package big
import "core:mem"
import "core:intrinsics"
import "core:fmt"
/*
===========================
@@ -467,13 +468,8 @@ shl1 :: double;
remainder = numerator % (1 << bits)
*/
int_mod_bits :: proc(remainder, numerator: ^Int, bits: int) -> (err: Error) {
remainder := remainder; numerator := numerator;
if err = clear_if_uninitialized(remainder); err != .None {
return err;
}
if err = clear_if_uninitialized(numerator); err != .None {
return err;
}
if err = clear_if_uninitialized(remainder); err != .None { return err; }
if err = clear_if_uninitialized(numerator); err != .None { return err; }
if bits < 0 { return .Invalid_Argument; }
if bits == 0 { return zero(remainder); }
@@ -505,6 +501,161 @@ int_mod_bits :: proc(remainder, numerator: ^Int, bits: int) -> (err: Error) {
}
mod_bits :: proc { int_mod_bits, };
/*
Multiply by a DIGIT.
*/
int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT) -> (err: Error) {
if err = clear_if_uninitialized(src ); err != .None { return err; }
if err = clear_if_uninitialized(dest); err != .None { return err; }
if multiplier == 0 {
return zero(dest);
}
if multiplier == 1 {
return copy(dest, src);
}
/*
Power of two?
*/
if multiplier == 2 {
return double(dest, src);
}
if is_power_of_two(int(multiplier)) {
ix: int;
if ix, err = log_n(multiplier, 2); err != .None { return err; }
return shl(dest, src, ix);
}
/*
Ensure `dest` is big enough to hold `src` * `multiplier`.
*/
if err = grow(dest, max(src.used + 1, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
/*
Save the original used count.
*/
old_used := dest.used;
/*
Set the sign.
*/
dest.sign = src.sign;
/*
Set up carry.
*/
carry := _WORD(0);
/*
Compute columns.
*/
ix := 0;
for ; ix < src.used; ix += 1 {
/*
Compute product and carry sum for this term
*/
product := carry + _WORD(src.digit[ix]) * _WORD(multiplier);
/*
Mask off higher bits to get a single DIGIT.
*/
dest.digit[ix] = DIGIT(product & _WORD(_MASK));
/*
Send carry into next iteration
*/
carry = product >> _DIGIT_BITS;
}
/*
Store final carry [if any] and increment used.
*/
dest.digit[ix] = DIGIT(carry);
dest.used = src.used + 1;
/*
Zero unused digits.
*/
zero_count := old_used - dest.used;
if zero_count > 0 {
mem.zero_slice(dest.digit[zero_count:]);
}
return clamp(dest);
}
/*
High level multiplication (handles sign).
*/
int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
if err = clear_if_uninitialized(src); err != .None { return err; }
if err = clear_if_uninitialized(dest); err != .None { return err; }
if err = clear_if_uninitialized(multiplier); err != .None { return err; }
/*
Early out for `multiplier` is zero; Set `dest` to zero.
*/
if z, _ := is_zero(multiplier); z {
return zero(dest);
}
min_used := min(src.used, multiplier.used);
max_used := max(src.used, multiplier.used);
digits := src.used + multiplier.used + 1;
neg := src.sign != multiplier.sign;
if false && src == multiplier {
/*
Do we need to square?
*/
if false && src.used >= _SQR_TOOM_CUTOFF {
/* Use Toom-Cook? */
// err = s_mp_sqr_toom(a, c);
} else if false && src.used >= _SQR_KARATSUBA_CUTOFF {
/* Karatsuba? */
// err = s_mp_sqr_karatsuba(a, c);
} else if false && ((src.used * 2) + 1) < _WARRAY &&
src.used < (_MAX_COMBA / 2) {
/* Fast comba? */
// err = s_mp_sqr_comba(a, c);
} else {
// err = s_mp_sqr(a, c);
}
} else {
/*
Can we use the balance method? Check sizes.
* The smaller one needs to be larger than the Karatsuba cut-off.
* The bigger one needs to be at least about one `_MUL_KARATSUBA_CUTOFF` bigger
* to make some sense, but it depends on architecture, OS, position of the
* stars... so YMMV.
* Using it to cut the input into slices small enough for _mul_comba
* was actually slower on the author's machine, but YMMV.
*/
if false && min_used >= _MUL_KARATSUBA_CUTOFF &&
max_used / 2 >= _MUL_KARATSUBA_CUTOFF &&
/*
Not much effect was observed below a ratio of 1:2, but again: YMMV.
*/
max_used >= 2 * min_used {
// err = s_mp_mul_balance(a,b,c);
} else if false && min_used >= _MUL_TOOM_CUTOFF {
// err = s_mp_mul_toom(a, b, c);
} else if false && min_used >= _MUL_KARATSUBA_CUTOFF {
// err = s_mp_mul_karatsuba(a, b, c);
} else if false && digits < _WARRAY && min_used <= _MAX_COMBA {
/*
Can we use the fast multiplier?
* The fast multiplier can be used if the output will
* have less than MP_WARRAY digits and the number of
* digits won't affect carry propagation
*/
// err = s_mp_mul_comba(a, b, c, digs);
} else {
fmt.println("Hai");
err = _int_mul(dest, src, multiplier, digits);
}
}
dest.sign = .Negative if dest.used > 0 && neg else .Zero_or_Positive;
return err;
}
mul :: proc { int_mul, int_mul_digit, };
/*
==========================
Low-level routines
@@ -688,46 +839,54 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
}
}
if err = grow(dest, digits); err != .None { return err; }
dest.used = digits;
/*
Set up temporary output `Int`, which we'll swap for `dest` when done.
*/
t := &Int{};
if err = grow(t, max(digits, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
t.used = digits;
/*
Compute the digits of the product directly.
*/
pa := a.used;
for ix := 0; ix < pa; ix += 1 {
/*
Limit ourselves to `digits` DIGITs of output.
*/
pb := min(b.used, digits - ix);
carry := DIGIT(0);
iy := 0;
/*
Compute the column of the output and propagate the carry.
*/
for iy = 0; iy < pb; iy += 1 {
/*
Compute the column as a _WORD.
*/
column := t.digit[ix + iy] + a.digit[ix] * b.digit[iy] + carry;
/* compute the digits of the product directly */
pa = a->used;
for (ix = 0; ix < pa; ix++) {
int iy, pb;
mp_digit u = 0;
/*
The new column is the lower part of the result.
*/
t.digit[ix + iy] = column & _MASK;
/* limit ourselves to making digs digits of output */
pb = MP_MIN(b->used, digs - ix);
/* compute the columns of the output and propagate the carry */
for (iy = 0; iy < pb; iy++) {
/* compute the column as a mp_word */
mp_word r = (mp_word)t.dp[ix + iy] +
((mp_word)a->dp[ix] * (mp_word)b->dp[iy]) +
(mp_word)u;
/* the new column is the lower part of the result */
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
/* get the carry word from the result */
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
/*
Get the carry word from the result.
*/
carry = column >> _DIGIT_BITS;
}
/* set carry if it is placed below digs */
if ((ix + iy) < digs) {
t.dp[ix + pb] = u;
/*
Set carry if it is placed below digits
*/
if ix + iy < digits {
t.digit[ix + pb] = carry;
}
}
mp_clamp(&t);
mp_exch(&t, c);
mp_clear(&t);
return MP_OKAY;
}
*/
return .None;
swap(dest, t);
destroy(t);
return clamp(dest);
}

View File

@@ -57,17 +57,15 @@ demo :: proc() {
a, b, c := &Int{}, &Int{}, &Int{};
defer destroy(a, b, c);
err = set(a, -512);
err = set(b, 1024);
err = set(a, -1024);
err = set(b, -1024);
print("a", a, 16);
print("b", b, 16);
print("a", a, 10);
print("b", b, 10);
fmt.println("--- swap ---");
foo(a, b);
print("a", a, 16);
print("b", b, 16);
fmt.println("--- mul ---");
mul(c, a, b);
print("c", c, 10);
}