big: Move _mul private functions.

This commit is contained in:
Jeroen van Rijn
2021-08-07 15:27:27 +02:00
parent 6298226238
commit e288a563e1
4 changed files with 257 additions and 229 deletions

View File

@@ -290,168 +290,6 @@ int_choose_digit :: proc(res: ^Int, n, k: int) -> (err: Error) {
}
choose :: proc { int_choose_digit, };
/*
	Baseline multiplier: computes |a| * |b| and produces at most `digits` digits of the result.

	Based on HAC pp. 595, Algorithm 14.12, modified so the caller controls how many
	digits of output are created.

	Small operands are handed to `_int_mul_comba`. Otherwise the product is accumulated
	in a temporary `Int` that is swapped into `dest` once complete.

	Returns any error reported by `grow`, `_int_mul_comba` or `clamp`.
*/
_int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
	/*
		Small enough for the column-array (comba) multiplier? Let it do the work.
	*/
	use_comba := digits < _WARRAY && min(a.used, b.used) < _MAX_COMBA;
	if use_comba {
		return _int_mul_comba(dest, a, b, digits);
	}

	/*
		Accumulate into a scratch `Int`; it trades places with `dest` at the end.
	*/
	product := &Int{};
	err = grow(product, max(digits, _DEFAULT_DIGIT_COUNT));
	if err != nil {
		return err;
	}
	product.used = digits;

	/*
		One pass per digit of `a`: multiply it by the digits of `b` and add into `product`.
	*/
	for row := 0; row < a.used; row += 1 {
		/*
			Never produce more than `digits` DIGITs of output.
		*/
		limit := min(b.used, digits - row);
		carry := _WORD(0);
		col   := 0;

		#no_bounds_check for col = 0; col < limit; col += 1 {
			/*
				Partial product plus what is already in this position plus the incoming carry,
				all held in a _WORD.
			*/
			term := _WORD(a.digit[row]) * _WORD(b.digit[col]);
			sum  := _WORD(product.digit[row + col]) + term + carry;

			/*
				Low part stays in place, the remainder is carried to the next position.
			*/
			product.digit[row + col] = DIGIT(sum & _WORD(_MASK));
			carry = sum >> _DIGIT_BITS;
		}

		/*
			Store the final carry, but only if its position is still below `digits`.
		*/
		if row + col < digits {
			product.digit[row + limit] = DIGIT(carry);
		}
	}

	swap(dest, product);
	destroy(product);
	return clamp(dest);
}
/*
	Fast column-array (comba) multiplier.

	The columns of the product are computed first and the carries are handled afterwards,
	which keeps the nested column loops very simple and schedulable on super-scalar processors.

	It produces a variable number of output digits, so when only a half-product is required
	the upper half does not have to be computed (a feature required for fast Barrett reduction).

	Based on Algorithm 14.12 on pp. 595 of HAC.

	NOTE(review): nothing in this procedure checks the number of output digits against
	`_WARRAY`; `_int_mul` tests `digits < _WARRAY` before calling it.

	Returns any error reported by `grow` or `clamp`.
*/
_int_mul_comba :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
	/*
		Scratch space for the finished columns. Deliberately left uninitialized:
		only the entries written below are read back.
	*/
	columns: [_WARRAY]DIGIT = ---;

	/*
		Make sure the destination can hold the requested digits.
	*/
	err = grow(dest, digits);
	if err != nil {
		return err;
	}

	/*
		Number of output digits to produce.
	*/
	out_digits := min(digits, a.used + b.used);

	/*
		Running column sum; after each column its upper part carries into the next one.
	*/
	acc := _WORD(0);

	for col in 0..<out_digits {
		/*
			Starting offsets into the two operands for this column.
		*/
		b_idx := min(b.used - 1, col);
		a_idx := col - b_idx;

		/*
			Number of digit pairs contributing to this column, i.e. how often
			`while (tx++ < a->used && ty-- >= 0) { ... }` would iterate.
		*/
		terms := min(a.used - a_idx, b_idx + 1);

		#no_bounds_check for k in 0..<terms {
			acc += _WORD(a.digit[a_idx + k]) * _WORD(b.digit[b_idx - k]);
		}

		/*
			Keep the low part as this column's digit, shift the rest down as the next carry.
		*/
		columns[col] = DIGIT(acc) & _MASK;
		acc >>= _WORD(_DIGIT_BITS);
	}

	/*
		Publish the result: set the new length, copy the columns over and
		clear whatever digits the previous value of `dest` had beyond that.
	*/
	previous_used := dest.used;
	dest.used = out_digits;

	copy_slice(dest.digit[0:], columns[:out_digits]);
	zero_unused(dest, previous_used);

	/*
		Trim leading zero digits.
	*/
	return clamp(dest);
}
/*
Low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16

View File

@@ -65,7 +65,7 @@ demo :: proc() {
a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
defer destroy(a, b, c, d, e, f);
n := 50_000;
n := 1_024;
k := 3;
{

View File

@@ -643,9 +643,9 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
* have less than MP_WARRAY digits and the number of
* digits won't affect carry propagation
*/
err = _int_mul_comba(dest, src, multiplier, digits);
err = _private_int_mul_comba(dest, src, multiplier, digits);
} else {
err = _int_mul(dest, src, multiplier, digits);
err = _private_int_mul(dest, src, multiplier, digits);
}
}
neg := src.sign != multiplier.sign;
@@ -848,7 +848,7 @@ internal_sqrmod :: proc { internal_int_sqrmod, };
*/
internal_int_factorial :: proc(res: ^Int, n: int) -> (err: Error) {
if n >= _FACTORIAL_BINARY_SPLIT_CUTOFF {
return #force_inline _int_factorial_binary_split(res, n);
return #force_inline _private_int_factorial_binary_split(res, n);
}
i := len(_factorial_table);
@@ -865,62 +865,6 @@ internal_int_factorial :: proc(res: ^Int, n: int) -> (err: Error) {
return nil;
}
/*
	Recursive product used by the binary split factorial.

	`num_factors` is derived as `(stop - start) >> 1`:
	- 2 factors:  res = start * (start + 2)
	- > 2:        split at an odd midpoint, recurse on both halves and multiply the results
	- 1 factor:   res = start
	- otherwise:  res = 1

	`level` tracks the recursion depth; exceeding `_FACTORIAL_BINARY_SPLIT_MAX_RECURSIONS`
	returns `.Max_Iterations_Reached`. Other errors come from the helpers called.
*/
_int_recursive_product :: proc(res: ^Int, start, stop: int, level := int(0)) -> (err: Error) {
	lhs, rhs := &Int{}, &Int{};
	defer destroy(lhs, rhs);

	if level > _FACTORIAL_BINARY_SPLIT_MAX_RECURSIONS {
		return .Max_Iterations_Reached;
	}

	num_factors := (stop - start) >> 1;

	switch {
	case num_factors == 2:
		err = set(lhs, start);
		if err != nil { return err; }

		/*
			rhs = lhs + 2, using the low-level add on a pre-grown destination.
		*/
		err = grow(rhs, lhs.used + 1);
		if err != nil { return err; }

		err = internal_add(rhs, lhs, 2);
		if err != nil { return err; }

		return internal_mul(res, lhs, rhs);

	case num_factors > 1:
		/*
			Force the split point to be odd, then multiply the two sub-products.
		*/
		pivot := (start + num_factors) | 1;

		err = _int_recursive_product(lhs, start, pivot, level + 1);
		if err != nil { return err; }

		err = _int_recursive_product(rhs, pivot, stop, level + 1);
		if err != nil { return err; }

		return internal_mul(res, lhs, rhs);

	case num_factors == 1:
		return #force_inline set(res, start);
	}

	return #force_inline set(res, 1);
}
/*
	Binary split factorial algo due to: http://www.luschny.de/math/factorial/binarysplitfact.html

	For each bit position `i` of `n`, from high to low, the product over the range
	`start .. stop` (both forced odd) is obtained from `_int_recursive_product`,
	folded into `inner`, and `inner` is folded into `outer`.
	The result is `outer` shifted left by `n - count_ones(n)` bits.

	Returns the first error reported by any helper.
*/
_int_factorial_binary_split :: proc(res: ^Int, n: int) -> (err: Error) {
	/*
		Only three big-integer temporaries are needed. `start` and `stop` below are
		plain `int`s, so no `Int`s of those names are created (they would only be shadowed).
	*/
	inner, outer, temp := &Int{}, &Int{}, &Int{};
	defer destroy(inner, outer, temp);

	if err = set(inner, 1); err != nil { return err; }
	if err = set(outer, 1); err != nil { return err; }

	/*
		NOTE(review): this subtracts the leading-zero count of an `int` from `_DIGIT_TYPE_BITS`.
		That presumably assumes DIGIT and `int` have the same width - confirm for forced 32-bit builds.
	*/
	bits_used := int(_DIGIT_TYPE_BITS - intrinsics.count_leading_zeros(n));

	for i := bits_used; i >= 0; i -= 1 {
		start := (n >> (uint(i) + 1)) + 1 | 1;
		stop  := (n >> uint(i)) + 1 | 1;

		if err = _int_recursive_product(temp, start, stop); err != nil { return err; }
		if err = internal_mul(inner, inner, temp);          err != nil { return err; }
		if err = internal_mul(outer, outer, inner);         err != nil { return err; }
	}

	shift := n - intrinsics.count_ones(n);
	return shl(res, outer, int(shift));
}
internal_int_zero_unused :: #force_inline proc(dest: ^Int, old_used := -1) {
/*
@@ -943,8 +887,250 @@ internal_int_zero_unused :: #force_inline proc(dest: ^Int, old_used := -1) {
internal_zero_unused :: proc { internal_int_zero_unused, };
/*
Tables.
========================== End of low-level routines ==========================
============================= Private procedures =============================
Private procedures used by the above low-level routines follow.
Don't call these yourself unless you really know what you're doing.
They include implementations that are optimal for certain ranges of input only.
These aren't exported for the same reasons.
*/
/*
	Baseline multiplier: computes |a| * |b| and produces at most `digits` digits of the result.

	Based on HAC pp. 595, Algorithm 14.12, modified so the caller controls how many
	digits of output are created.

	Small operands are handed to `_private_int_mul_comba`. Otherwise the product is
	accumulated in a temporary `Int` that is swapped into `dest` once complete.

	Returns any error reported by `grow`, `_private_int_mul_comba` or `clamp`.
*/
_private_int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
	/*
		Small enough for the column-array (comba) multiplier? Let it do the work.
	*/
	use_comba := digits < _WARRAY && min(a.used, b.used) < _MAX_COMBA;
	if use_comba {
		return _private_int_mul_comba(dest, a, b, digits);
	}

	/*
		Accumulate into a scratch `Int`; it trades places with `dest` at the end.
	*/
	product := &Int{};
	err = grow(product, max(digits, _DEFAULT_DIGIT_COUNT));
	if err != nil {
		return err;
	}
	product.used = digits;

	/*
		One pass per digit of `a`: multiply it by the digits of `b` and add into `product`.
	*/
	for row := 0; row < a.used; row += 1 {
		/*
			Never produce more than `digits` DIGITs of output.
		*/
		limit := min(b.used, digits - row);
		carry := _WORD(0);
		col   := 0;

		#no_bounds_check for col = 0; col < limit; col += 1 {
			/*
				Partial product plus what is already in this position plus the incoming carry,
				all held in a _WORD.
			*/
			term := _WORD(a.digit[row]) * _WORD(b.digit[col]);
			sum  := _WORD(product.digit[row + col]) + term + carry;

			/*
				Low part stays in place, the remainder is carried to the next position.
			*/
			product.digit[row + col] = DIGIT(sum & _WORD(_MASK));
			carry = sum >> _DIGIT_BITS;
		}

		/*
			Store the final carry, but only if its position is still below `digits`.
		*/
		if row + col < digits {
			product.digit[row + limit] = DIGIT(carry);
		}
	}

	swap(dest, product);
	destroy(product);
	return clamp(dest);
}
/*
	Fast column-array (comba) multiplier.

	The columns of the product are computed first and the carries are handled afterwards,
	which keeps the nested column loops very simple and schedulable on super-scalar processors.

	It produces a variable number of output digits, so when only a half-product is required
	the upper half does not have to be computed (a feature required for fast Barrett reduction).

	Based on Algorithm 14.12 on pp. 595 of HAC.

	NOTE(review): nothing in this procedure checks the number of output digits against
	`_WARRAY`; `_private_int_mul` tests `digits < _WARRAY` before calling it.

	Returns any error reported by `grow` or `clamp`.
*/
_private_int_mul_comba :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
	/*
		Scratch space for the finished columns. Deliberately left uninitialized:
		only the entries written below are read back.
	*/
	columns: [_WARRAY]DIGIT = ---;

	/*
		Make sure the destination can hold the requested digits.
	*/
	err = grow(dest, digits);
	if err != nil {
		return err;
	}

	/*
		Number of output digits to produce.
	*/
	out_digits := min(digits, a.used + b.used);

	/*
		Running column sum; after each column its upper part carries into the next one.
	*/
	acc := _WORD(0);

	for col in 0..<out_digits {
		/*
			Starting offsets into the two operands for this column.
		*/
		b_idx := min(b.used - 1, col);
		a_idx := col - b_idx;

		/*
			Number of digit pairs contributing to this column, i.e. how often
			`while (tx++ < a->used && ty-- >= 0) { ... }` would iterate.
		*/
		terms := min(a.used - a_idx, b_idx + 1);

		#no_bounds_check for k in 0..<terms {
			acc += _WORD(a.digit[a_idx + k]) * _WORD(b.digit[b_idx - k]);
		}

		/*
			Keep the low part as this column's digit, shift the rest down as the next carry.
		*/
		columns[col] = DIGIT(acc) & _MASK;
		acc >>= _WORD(_DIGIT_BITS);
	}

	/*
		Publish the result: set the new length, copy the columns over and
		clear whatever digits the previous value of `dest` had beyond that.
	*/
	previous_used := dest.used;
	dest.used = out_digits;

	copy_slice(dest.digit[0:], columns[:out_digits]);
	zero_unused(dest, previous_used);

	/*
		Trim leading zero digits.
	*/
	return clamp(dest);
}
/*
	Binary split factorial algo due to: http://www.luschny.de/math/factorial/binarysplitfact.html

	For each bit position `i` of `n`, from high to low, the product over the range
	`start .. stop` (both forced odd) is obtained from `_private_int_recursive_product`,
	folded into `inner`, and `inner` is folded into `outer`.
	The result is `outer` shifted left by `n - count_ones(n)` bits.

	Returns the first error reported by any helper.
*/
_private_int_factorial_binary_split :: proc(res: ^Int, n: int) -> (err: Error) {
	/*
		Only three big-integer temporaries are needed. `start` and `stop` below are
		plain `int`s, so no `Int`s of those names are created (they would only be shadowed).
	*/
	inner, outer, temp := &Int{}, &Int{}, &Int{};
	defer destroy(inner, outer, temp);

	if err = set(inner, 1); err != nil { return err; }
	if err = set(outer, 1); err != nil { return err; }

	/*
		NOTE(review): this subtracts the leading-zero count of an `int` from `_DIGIT_TYPE_BITS`.
		That presumably assumes DIGIT and `int` have the same width - confirm for forced 32-bit builds.
	*/
	bits_used := int(_DIGIT_TYPE_BITS - intrinsics.count_leading_zeros(n));

	for i := bits_used; i >= 0; i -= 1 {
		start := (n >> (uint(i) + 1)) + 1 | 1;
		stop  := (n >> uint(i)) + 1 | 1;

		if err = _private_int_recursive_product(temp, start, stop); err != nil { return err; }
		if err = internal_mul(inner, inner, temp);                  err != nil { return err; }
		if err = internal_mul(outer, outer, inner);                 err != nil { return err; }
	}

	shift := n - intrinsics.count_ones(n);
	return shl(res, outer, int(shift));
}
/*
	Recursive product used by binary split factorial algorithm.

	`num_factors` is derived as `(stop - start) >> 1`:
	- 2 factors:  res = start * (start + 2)
	- > 2:        split at an odd midpoint, recurse on both halves and multiply the results
	- 1 factor:   res = start
	- otherwise:  res = 1

	`level` tracks the recursion depth; exceeding `_FACTORIAL_BINARY_SPLIT_MAX_RECURSIONS`
	returns `.Max_Iterations_Reached`. Other errors come from the helpers called.
*/
_private_int_recursive_product :: proc(res: ^Int, start, stop: int, level := int(0)) -> (err: Error) {
	lhs, rhs := &Int{}, &Int{};
	defer destroy(lhs, rhs);

	if level > _FACTORIAL_BINARY_SPLIT_MAX_RECURSIONS {
		return .Max_Iterations_Reached;
	}

	num_factors := (stop - start) >> 1;

	switch {
	case num_factors == 2:
		err = set(lhs, start);
		if err != nil { return err; }

		/*
			rhs = lhs + 2, using the low-level add on a pre-grown destination.
		*/
		err = grow(rhs, lhs.used + 1);
		if err != nil { return err; }

		err = internal_add(rhs, lhs, 2);
		if err != nil { return err; }

		return internal_mul(res, lhs, rhs);

	case num_factors > 1:
		/*
			Force the split point to be odd, then multiply the two sub-products.
		*/
		pivot := (start + num_factors) | 1;

		err = _private_int_recursive_product(lhs, start, pivot, level + 1);
		if err != nil { return err; }

		err = _private_int_recursive_product(rhs, pivot, stop, level + 1);
		if err != nil { return err; }

		return internal_mul(res, lhs, rhs);

	case num_factors == 1:
		return #force_inline set(res, start);
	}

	return #force_inline set(res, 1);
}
/*
======================== End of private procedures =======================
=============================== Private tables ===============================
Tables used by `internal_*` and `_*`.
*/
when MATH_BIG_FORCE_64_BIT || (!MATH_BIG_FORCE_32_BIT && size_of(rawptr) == 8) {
@@ -1009,4 +1195,8 @@ when MATH_BIG_FORCE_64_BIT || (!MATH_BIG_FORCE_32_BIT && size_of(rawptr) == 8) {
/* f(19): */ 121_645_100_408_832_000,
/* f(20): */ 2_432_902_008_176_640_000,
};
};
};
/*
========================= End of private tables ========================
*/

View File

@@ -96,11 +96,11 @@ class Error(Enum):
# Set up exported procedures
#
try:
l = cdll.LoadLibrary(LIB_PATH)
except:
print("Couldn't find or load " + LIB_PATH + ".")
exit(1)
# try:
l = cdll.LoadLibrary(LIB_PATH)
# except:
# print("Couldn't find or load " + LIB_PATH + ".")
# exit(1)
def load(export_name, args, res):
export_name.argtypes = args