mirror of
https://github.com/odin-lang/Odin.git
synced 2026-02-22 10:56:41 +00:00
big: Add multiplication.
This commit is contained in:
@@ -13,6 +13,7 @@ package big
|
||||
|
||||
import "core:mem"
|
||||
import "core:intrinsics"
|
||||
import "core:fmt"
|
||||
|
||||
/*
|
||||
===========================
|
||||
@@ -467,13 +468,8 @@ shl1 :: double;
|
||||
remainder = numerator % (1 << bits)
|
||||
*/
|
||||
int_mod_bits :: proc(remainder, numerator: ^Int, bits: int) -> (err: Error) {
|
||||
remainder := remainder; numerator := numerator;
|
||||
if err = clear_if_uninitialized(remainder); err != .None {
|
||||
return err;
|
||||
}
|
||||
if err = clear_if_uninitialized(numerator); err != .None {
|
||||
return err;
|
||||
}
|
||||
if err = clear_if_uninitialized(remainder); err != .None { return err; }
|
||||
if err = clear_if_uninitialized(numerator); err != .None { return err; }
|
||||
|
||||
if bits < 0 { return .Invalid_Argument; }
|
||||
if bits == 0 { return zero(remainder); }
|
||||
@@ -505,6 +501,161 @@ int_mod_bits :: proc(remainder, numerator: ^Int, bits: int) -> (err: Error) {
|
||||
}
|
||||
mod_bits :: proc { int_mod_bits, };
|
||||
|
||||
/*
|
||||
Multiply by a DIGIT.
|
||||
*/
|
||||
int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT) -> (err: Error) {
|
||||
if err = clear_if_uninitialized(src ); err != .None { return err; }
|
||||
if err = clear_if_uninitialized(dest); err != .None { return err; }
|
||||
|
||||
if multiplier == 0 {
|
||||
return zero(dest);
|
||||
}
|
||||
if multiplier == 1 {
|
||||
return copy(dest, src);
|
||||
}
|
||||
|
||||
/*
|
||||
Power of two?
|
||||
*/
|
||||
if multiplier == 2 {
|
||||
return double(dest, src);
|
||||
}
|
||||
if is_power_of_two(int(multiplier)) {
|
||||
ix: int;
|
||||
if ix, err = log_n(multiplier, 2); err != .None { return err; }
|
||||
return shl(dest, src, ix);
|
||||
}
|
||||
|
||||
/*
|
||||
Ensure `dest` is big enough to hold `src` * `multiplier`.
|
||||
*/
|
||||
if err = grow(dest, max(src.used + 1, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
|
||||
|
||||
/*
|
||||
Save the original used count.
|
||||
*/
|
||||
old_used := dest.used;
|
||||
/*
|
||||
Set the sign.
|
||||
*/
|
||||
dest.sign = src.sign;
|
||||
/*
|
||||
Set up carry.
|
||||
*/
|
||||
carry := _WORD(0);
|
||||
/*
|
||||
Compute columns.
|
||||
*/
|
||||
ix := 0;
|
||||
for ; ix < src.used; ix += 1 {
|
||||
/*
|
||||
Compute product and carry sum for this term
|
||||
*/
|
||||
product := carry + _WORD(src.digit[ix]) * _WORD(multiplier);
|
||||
/*
|
||||
Mask off higher bits to get a single DIGIT.
|
||||
*/
|
||||
dest.digit[ix] = DIGIT(product & _WORD(_MASK));
|
||||
/*
|
||||
Send carry into next iteration
|
||||
*/
|
||||
carry = product >> _DIGIT_BITS;
|
||||
}
|
||||
|
||||
/*
|
||||
Store final carry [if any] and increment used.
|
||||
*/
|
||||
dest.digit[ix] = DIGIT(carry);
|
||||
dest.used = src.used + 1;
|
||||
|
||||
/*
|
||||
Zero unused digits.
|
||||
*/
|
||||
zero_count := old_used - dest.used;
|
||||
if zero_count > 0 {
|
||||
mem.zero_slice(dest.digit[zero_count:]);
|
||||
}
|
||||
return clamp(dest);
|
||||
}
|
||||
|
||||
/*
|
||||
High level multiplication (handles sign).
|
||||
*/
|
||||
int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
|
||||
if err = clear_if_uninitialized(src); err != .None { return err; }
|
||||
if err = clear_if_uninitialized(dest); err != .None { return err; }
|
||||
if err = clear_if_uninitialized(multiplier); err != .None { return err; }
|
||||
|
||||
/*
|
||||
Early out for `multiplier` is zero; Set `dest` to zero.
|
||||
*/
|
||||
if z, _ := is_zero(multiplier); z {
|
||||
return zero(dest);
|
||||
}
|
||||
|
||||
min_used := min(src.used, multiplier.used);
|
||||
max_used := max(src.used, multiplier.used);
|
||||
digits := src.used + multiplier.used + 1;
|
||||
neg := src.sign != multiplier.sign;
|
||||
|
||||
if false && src == multiplier {
|
||||
/*
|
||||
Do we need to square?
|
||||
*/
|
||||
if false && src.used >= _SQR_TOOM_CUTOFF {
|
||||
/* Use Toom-Cook? */
|
||||
// err = s_mp_sqr_toom(a, c);
|
||||
} else if false && src.used >= _SQR_KARATSUBA_CUTOFF {
|
||||
/* Karatsuba? */
|
||||
// err = s_mp_sqr_karatsuba(a, c);
|
||||
} else if false && ((src.used * 2) + 1) < _WARRAY &&
|
||||
src.used < (_MAX_COMBA / 2) {
|
||||
/* Fast comba? */
|
||||
// err = s_mp_sqr_comba(a, c);
|
||||
} else {
|
||||
// err = s_mp_sqr(a, c);
|
||||
}
|
||||
} else {
|
||||
/*
|
||||
Can we use the balance method? Check sizes.
|
||||
* The smaller one needs to be larger than the Karatsuba cut-off.
|
||||
* The bigger one needs to be at least about one `_MUL_KARATSUBA_CUTOFF` bigger
|
||||
* to make some sense, but it depends on architecture, OS, position of the
|
||||
* stars... so YMMV.
|
||||
* Using it to cut the input into slices small enough for _mul_comba
|
||||
* was actually slower on the author's machine, but YMMV.
|
||||
*/
|
||||
if false && min_used >= _MUL_KARATSUBA_CUTOFF &&
|
||||
max_used / 2 >= _MUL_KARATSUBA_CUTOFF &&
|
||||
/*
|
||||
Not much effect was observed below a ratio of 1:2, but again: YMMV.
|
||||
*/
|
||||
max_used >= 2 * min_used {
|
||||
// err = s_mp_mul_balance(a,b,c);
|
||||
} else if false && min_used >= _MUL_TOOM_CUTOFF {
|
||||
// err = s_mp_mul_toom(a, b, c);
|
||||
} else if false && min_used >= _MUL_KARATSUBA_CUTOFF {
|
||||
// err = s_mp_mul_karatsuba(a, b, c);
|
||||
} else if false && digits < _WARRAY && min_used <= _MAX_COMBA {
|
||||
/*
|
||||
Can we use the fast multiplier?
|
||||
* The fast multiplier can be used if the output will
|
||||
* have less than MP_WARRAY digits and the number of
|
||||
* digits won't affect carry propagation
|
||||
*/
|
||||
// err = s_mp_mul_comba(a, b, c, digs);
|
||||
} else {
|
||||
fmt.println("Hai");
|
||||
err = _int_mul(dest, src, multiplier, digits);
|
||||
}
|
||||
}
|
||||
dest.sign = .Negative if dest.used > 0 && neg else .Zero_or_Positive;
|
||||
return err;
|
||||
}
|
||||
|
||||
mul :: proc { int_mul, int_mul_digit, };
|
||||
|
||||
/*
|
||||
==========================
|
||||
Low-level routines
|
||||
@@ -688,46 +839,54 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err = grow(dest, digits); err != .None { return err; }
|
||||
dest.used = digits;
|
||||
/*
|
||||
Set up temporary output `Int`, which we'll swap for `dest` when done.
|
||||
*/
|
||||
|
||||
t := &Int{};
|
||||
|
||||
if err = grow(t, max(digits, _DEFAULT_DIGIT_COUNT)); err != .None { return err; }
|
||||
t.used = digits;
|
||||
|
||||
/*
|
||||
Compute the digits of the product directly.
|
||||
*/
|
||||
pa := a.used;
|
||||
for ix := 0; ix < pa; ix += 1 {
|
||||
/*
|
||||
Limit ourselves to `digits` DIGITs of output.
|
||||
*/
|
||||
pb := min(b.used, digits - ix);
|
||||
carry := DIGIT(0);
|
||||
iy := 0;
|
||||
/*
|
||||
Compute the column of the output and propagate the carry.
|
||||
*/
|
||||
for iy = 0; iy < pb; iy += 1 {
|
||||
/*
|
||||
Compute the column as a _WORD.
|
||||
*/
|
||||
column := t.digit[ix + iy] + a.digit[ix] * b.digit[iy] + carry;
|
||||
|
||||
/* compute the digits of the product directly */
|
||||
pa = a->used;
|
||||
for (ix = 0; ix < pa; ix++) {
|
||||
int iy, pb;
|
||||
mp_digit u = 0;
|
||||
/*
|
||||
The new column is the lower part of the result.
|
||||
*/
|
||||
t.digit[ix + iy] = column & _MASK;
|
||||
|
||||
/* limit ourselves to making digs digits of output */
|
||||
pb = MP_MIN(b->used, digs - ix);
|
||||
|
||||
/* compute the columns of the output and propagate the carry */
|
||||
for (iy = 0; iy < pb; iy++) {
|
||||
/* compute the column as a mp_word */
|
||||
mp_word r = (mp_word)t.dp[ix + iy] +
|
||||
((mp_word)a->dp[ix] * (mp_word)b->dp[iy]) +
|
||||
(mp_word)u;
|
||||
|
||||
/* the new column is the lower part of the result */
|
||||
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
|
||||
|
||||
/* get the carry word from the result */
|
||||
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
|
||||
/*
|
||||
Get the carry word from the result.
|
||||
*/
|
||||
carry = column >> _DIGIT_BITS;
|
||||
}
|
||||
/* set carry if it is placed below digs */
|
||||
if ((ix + iy) < digs) {
|
||||
t.dp[ix + pb] = u;
|
||||
/*
|
||||
Set carry if it is placed below digits
|
||||
*/
|
||||
if ix + iy < digits {
|
||||
t.digit[ix + pb] = carry;
|
||||
}
|
||||
}
|
||||
|
||||
mp_clamp(&t);
|
||||
mp_exch(&t, c);
|
||||
|
||||
mp_clear(&t);
|
||||
return MP_OKAY;
|
||||
}
|
||||
|
||||
*/
|
||||
return .None;
|
||||
swap(dest, t);
|
||||
destroy(t);
|
||||
return clamp(dest);
|
||||
}
|
||||
@@ -57,17 +57,15 @@ demo :: proc() {
|
||||
a, b, c := &Int{}, &Int{}, &Int{};
|
||||
defer destroy(a, b, c);
|
||||
|
||||
err = set(a, -512);
|
||||
err = set(b, 1024);
|
||||
err = set(a, -1024);
|
||||
err = set(b, -1024);
|
||||
|
||||
print("a", a, 16);
|
||||
print("b", b, 16);
|
||||
print("a", a, 10);
|
||||
print("b", b, 10);
|
||||
|
||||
fmt.println("--- swap ---");
|
||||
foo(a, b);
|
||||
|
||||
print("a", a, 16);
|
||||
print("b", b, 16);
|
||||
fmt.println("--- mul ---");
|
||||
mul(c, a, b);
|
||||
print("c", c, 10);
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user