Add _mul_comba path.

This commit is contained in:
Jeroen van Rijn
2021-08-02 21:01:46 +02:00
parent 491e4ecc74
commit a27612ec6a
3 changed files with 120 additions and 23 deletions

View File

@@ -499,8 +499,7 @@ mod_bits :: proc { int_mod_bits, };
Multiply by a DIGIT.
*/
int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT) -> (err: Error) {
if err = clear_if_uninitialized(src ); err != .None { return err; }
if err = clear_if_uninitialized(dest); err != .None { return err; }
if err = clear_if_uninitialized(src, dest); err != .None { return err; }
if multiplier == 0 {
return zero(dest);
@@ -576,9 +575,7 @@ int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT) -> (err: Error) {
High level multiplication (handles sign).
*/
int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
if err = clear_if_uninitialized(src); err != .None { return err; }
if err = clear_if_uninitialized(dest); err != .None { return err; }
if err = clear_if_uninitialized(multiplier); err != .None { return err; }
if err = clear_if_uninitialized(dest, src, multiplier); err != .None { return err; }
/*
Early out for `multiplier` is zero; Set `dest` to zero.
@@ -587,11 +584,6 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
return zero(dest);
}
min_used := min(src.used, multiplier.used);
max_used := max(src.used, multiplier.used);
digits := src.used + multiplier.used + 1;
neg := src.sign != multiplier.sign;
if src == multiplier {
/*
Do we need to square?
@@ -619,6 +611,11 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
* Using it to cut the input into slices small enough for _mul_comba
* was actually slower on the author's machine, but YMMV.
*/
min_used := min(src.used, multiplier.used);
max_used := max(src.used, multiplier.used);
digits := src.used + multiplier.used + 1;
if false && min_used >= _MUL_KARATSUBA_CUTOFF &&
max_used / 2 >= _MUL_KARATSUBA_CUTOFF &&
/*
@@ -630,18 +627,19 @@ int_mul :: proc(dest, src, multiplier: ^Int) -> (err: Error) {
// err = s_mp_mul_toom(a, b, c);
} else if false && min_used >= _MUL_KARATSUBA_CUTOFF {
// err = s_mp_mul_karatsuba(a, b, c);
} else if false && digits < _WARRAY && min_used <= _MAX_COMBA {
} else if digits < _WARRAY && min_used <= _MAX_COMBA {
/*
Can we use the fast multiplier?
* The fast multiplier can be used if the output will
* have less than MP_WARRAY digits and the number of
* digits won't affect carry propagation
*/
// err = s_mp_mul_comba(a, b, c, digs);
err = _int_mul_comba(dest, src, multiplier, digits);
} else {
err = _int_mul(dest, src, multiplier, digits);
}
}
neg := src.sign != multiplier.sign;
dest.sign = .Negative if dest.used > 0 && neg else .Zero_or_Positive;
return err;
}
@@ -1033,14 +1031,11 @@ _int_sub :: proc(dest, number, decrease: ^Int) -> (err: Error) {
many digits of output are created.
*/
_int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
/*
Can we use the fast multiplier?
*/
when false { // Have Comba?
if digits < _WARRAY && min(a.used, b.used) < _MAX_COMBA {
return _int_mul_comba(dest, a, b, digits);
}
if digits < _WARRAY && min(a.used, b.used) < _MAX_COMBA {
return _int_mul_comba(dest, a, b, digits);
}
/*
@@ -1095,6 +1090,108 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
return clamp(dest);
}
/*
Fast (comba) multiplier
This is the fast column-array [comba] multiplier. It is
designed to compute the columns of the product first
then handle the carries afterwards. This has the effect
of making the nested loops that compute the columns very
simple and schedulable on super-scalar processors.
This has been modified to produce a variable number of
digits of output so if say only a half-product is required
you don't have to compute the upper half (a feature
required for fast Barrett reduction).
Based on Algorithm 14.12 on p. 595 of HAC.
*/
_int_mul_comba :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
	/*
		Scratch array for the finished output columns.
		Left uninitialized on purpose: every slot read further down
		(indices `0 ..< column_count`) is written by the column loop first.
	*/
	columns: [_WARRAY]DIGIT = ---;

	/*
		`dest` has to be able to hold the requested number of digits.
	*/
	err = grow(dest, digits);
	if err != .None {
		return err;
	}

	/*
		Produce at most `digits` columns, and never more than `a.used + b.used`.
	*/
	column_count := min(digits, a.used + b.used);

	/*
		Running column sum. After a column's low digit has been stored,
		the remaining high part is the carry into the next column.
	*/
	acc := _WORD(0);

	for col := 0; col < column_count; col += 1 {
		/*
			Starting offsets for this column:
			`b` is walked downwards from `b_hi`, `a` upwards from `a_lo`,
			so that the two indices always add up to `col`.
		*/
		b_hi := min(b.used - 1, col);
		a_lo := col - b_hi;

		/*
			How many digit pairs contribute to this column, i.e. how far we
			can step before running off the top of `a` or the bottom of `b`.
		*/
		terms := min(a.used - a_lo, b_hi + 1);

		/*
			Accumulate the partial products of the column.
		*/
		for k := 0; k < terms; k += 1 {
			acc += _WORD(a.digit[a_lo + k]) * _WORD(b.digit[b_hi - k]);
		}

		/*
			The masked low part is this column's output digit...
		*/
		columns[col] = DIGIT(acc) & _MASK;

		/*
			...and what is left above it carries over.
		*/
		acc = acc >> _WORD(_DIGIT_BITS);
	}

	/*
		Publish the result. The previous digit count is remembered first,
		so that leftovers from the old value of `dest` can be wiped below.
	*/
	previously_used := dest.used;
	dest.used = column_count;

	for col := 0; col < column_count; col += 1 {
		dest.digit[col] = columns[col];
	}

	/*
		If the old value of `dest` was longer than the new one,
		zero the digits that are no longer in use.
	*/
	stale := previously_used - column_count;
	if stale > 0 {
		mem.zero_slice(dest.digit[column_count:][:stale]);
	}

	/*
		Let `clamp` adjust `dest.used` for leading zeroes.
	*/
	return clamp(dest);
}
/*
Low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16
*/

View File

@@ -1,10 +1,10 @@
@echo off
odin run . -vet
:odin run . -vet
: -o:size -no-bounds-check
:odin build . -build-mode:shared -show-timings -o:minimal -use-separate-modules
:odin build . -build-mode:shared -show-timings -o:size -use-separate-modules -no-bounds-check
:odin build . -build-mode:shared -show-timings -o:size -use-separate-modules
:odin build . -build-mode:shared -show-timings -o:speed -use-separate-modules -no-bounds-check
odin build . -build-mode:shared -show-timings -o:speed -use-separate-modules -no-bounds-check
:odin build . -build-mode:shared -show-timings -o:speed -use-separate-modules
:python test.py
python test.py

View File

@@ -87,7 +87,7 @@ Event :: struct {
}
Timings := [Category]Event{};
print :: proc(name: string, a: ^Int, base := i8(10), print_extra_info := false, print_name := false, newline := true) {
print :: proc(name: string, a: ^Int, base := i8(10), print_name := false, newline := true, print_extra_info := false) {
s := time.tick_now();
as, err := itoa(a, base);
Timings[.itoa].t += time.tick_since(s); Timings[.itoa].c += 1;
@@ -117,10 +117,10 @@ demo :: proc() {
defer destroy(a, b, c, d, e, f);
s := time.tick_now();
err = choose(a, 1024, 255);
err = choose(a, 65535, 255);
Timings[.choose].t += time.tick_since(s); Timings[.choose].c += 1;
print("1024 choose 255", a, 10, true, true, true);
print("65535 choose 255", a, 10, true, true, true);
fmt.printf("Error: %v\n", err);
}