Merge pull request #1108 from Kelimion/bigint

big: Add two more asymptotically optimal multiplication methods.
This commit is contained in:
Jeroen van Rijn
2021-08-28 18:19:55 +02:00
committed by GitHub
4 changed files with 227 additions and 40 deletions

View File

@@ -205,15 +205,6 @@ int_to_byte_little :: proc(v: ^Int) {
demo :: proc() {
a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
defer destroy(a, b, c, d, e, f);
foo := "92232459121502451677697058974826760244863271517919321608054113675118660929276431348516553336313179167211015633639725554914519355444316239500734169769447134357534241879421978647995614218985202290368055757891124109355450669008628757662409138767505519391883751112010824030579849970582074544353971308266211776494228299586414907715854328360867232691292422194412634523666770452490676515117702116926803826546868467146319938818238521874072436856528051486567230096290549225463582766830777324099589751817442141036031904145041055454639783559905920619197290800070679733841430619962318433709503256637256772215111521321630777950145713049902839937043785039344243357384899099910837463164007565230287809026956254332260375327814271845678201";
set(a, foo);
print("a: ", a);
is_sqr, _ := internal_int_is_square(a);
fmt.printf("is_square: %v\n", is_sqr);
}
main :: proc() {

View File

@@ -659,8 +659,7 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
Can we use the balance method? Check sizes.
* The smaller one needs to be larger than the Karatsuba cut-off.
* The bigger one needs to be at least about one `_MUL_KARATSUBA_CUTOFF` bigger
* to make some sense, but it depends on architecture, OS, position of the
* stars... so YMMV.
* to make some sense, but it depends on architecture, OS, position of the stars... so YMMV.
* Using it to cut the input into slices small enough for _mul_comba
* was actually slower on the author's machine, but YMMV.
*/
@@ -669,13 +668,11 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
max_used := max(src.used, multiplier.used);
digits := src.used + multiplier.used + 1;
if false && min_used >= MUL_KARATSUBA_CUTOFF &&
max_used / 2 >= MUL_KARATSUBA_CUTOFF &&
if min_used >= MUL_KARATSUBA_CUTOFF && (max_used / 2) >= MUL_KARATSUBA_CUTOFF && max_used >= (2 * min_used) {
/*
Not much effect was observed below a ratio of 1:2, but again: YMMV.
*/
max_used >= 2 * min_used {
// err = s_mp_mul_balance(a,b,c);
err = _private_int_mul_balance(dest, src, multiplier);
} else if min_used >= MUL_TOOM_CUTOFF {
/*
Toom path commented out until it no longer fails Factorial 10k or 100k,
@@ -914,7 +911,7 @@ internal_int_factorial :: proc(res: ^Int, n: int, allocator := context.allocator
context.allocator = allocator;
if n >= FACTORIAL_BINARY_SPLIT_CUTOFF {
return #force_inline _private_int_factorial_binary_split(res, n);
return _private_int_factorial_binary_split(res, n);
}
i := len(_factorial_table);

View File

@@ -113,7 +113,7 @@ _private_int_mul_toom :: proc(dest, a, b: ^Int, allocator := context.allocator)
context.allocator = allocator;
S1, S2, T1, a0, a1, a2, b0, b1, b2 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
defer destroy(S1, S2, T1, a0, a1, a2, b0, b1, b2);
defer internal_destroy(S1, S2, T1, a0, a1, a2, b0, b1, b2);
/*
Init temps.
@@ -258,7 +258,7 @@ _private_int_mul_karatsuba :: proc(dest, a, b: ^Int, allocator := context.alloca
context.allocator = allocator;
x0, x1, y0, y1, t1, x0y0, x1y1 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
defer destroy(x0, x1, y0, y1, t1, x0y0, x1y1);
defer internal_destroy(x0, x1, y0, y1, t1, x0y0, x1y1);
/*
min # of digits, divided by two.
@@ -426,6 +426,195 @@ _private_int_mul_comba :: proc(dest, a, b: ^Int, digits: int, allocator := conte
return internal_clamp(dest);
}
/*
	Multiplies |a| * |b| and does not compute the digits below position `digits`
	[meant to get the higher part of the product].

	`dest`   receives the result.
	`a`, `b` are the factors; only their digits are read.
	`digits` is the index of the lowest output digit that is computed.

	Small inputs are handed to `_private_int_mul_high_comba`. Larger ones use the
	schoolbook loop below, which accumulates into a scratch `Int` and swaps it into
	`dest` at the end. Accumulating directly into `dest` would add whatever digits
	`dest` already held into the product, and would overwrite an operand while it
	is still being read when `dest` is the same `Int` as `a` or `b`.

	NOTE(review): because of the swap, `dest` ends up with the scratch value's sign
	rather than keeping its previous one. Confirm no caller relies on the old sign.
*/
_private_int_mul_high :: proc(dest, a, b: ^Int, digits: int, allocator := context.allocator) -> (err: Error) {
	context.allocator = allocator;

	/*
		Can we use the fast multiplier?
	*/
	if a.used + b.used + 1 < _WARRAY && min(a.used, b.used) < _MAX_COMBA {
		return _private_int_mul_high_comba(dest, a, b, digits);
	}

	t := &Int{};
	defer internal_destroy(t);

	internal_grow(t, a.used + b.used + 1) or_return;
	t.used = a.used + b.used + 1;

	/*
		The inner loop reads each output digit before adding to it, so start from zero.
		Done explicitly rather than relying on how `internal_grow` initializes storage.
	*/
	for i := 0; i < t.used; i += 1 {
		t.digit[i] = 0;
	}

	pa := a.used;
	pb := b.used;

	for ix := 0; ix < pa; ix += 1 {
		carry := DIGIT(0);

		/*
			Skip the terms of this row that land below `digits`.
			Once `ix` exceeds `digits` the whole row belongs to the high part, so the
			start index is clamped at zero instead of going negative.
		*/
		for iy := max(digits - ix, 0); iy < pb; iy += 1 {
			/*
				Calculate the double precision result.
			*/
			r := _WORD(t.digit[ix + iy]) + _WORD(a.digit[ix]) * _WORD(b.digit[iy]) + _WORD(carry);

			/*
				Get the lower part.
			*/
			t.digit[ix + iy] = DIGIT(r & _WORD(_MASK));

			/*
				Carry the carry.
			*/
			carry = DIGIT(r >> _WORD(_DIGIT_BITS));
		}
		t.digit[ix + pb] = carry;
	}

	internal_clamp(t) or_return;

	/*
		Hand the result to `dest`; the deferred destroy releases what `dest` held before.
	*/
	internal_swap(dest, t);
	return;
}
/*
	Column-wise (comba) multiplication that only produces the output digits at
	position `digits` and above. It is a modified `_private_int_mul_comba`; see the
	comments there for how the column accumulation works.

	Barrett reduction uses this for the multiplication where only the higher digits
	are needed, which essentially halves the work.

	Based on Algorithm 14.12 on p. 595 of HAC.
*/
_private_int_mul_high_comba :: proc(dest, a, b: ^Int, digits: int, allocator := context.allocator) -> (err: Error) {
	context.allocator = allocator;

	columns: [_WARRAY]DIGIT = ---;
	acc := _WORD(0);

	/*
		Total number of output digits. Make sure the destination can hold them.
	*/
	total := a.used + b.used;
	internal_grow(dest, total) or_return;

	for col := digits; col < total; col += 1 {
		/*
			Starting offsets into `b` and `a` for this column.
		*/
		b_idx := min(b.used - 1, col);
		a_idx := col - b_idx;

		/*
			Number of products in this column: walk `a` upwards and `b` downwards
			until either one runs out of digits.
		*/
		terms := min(a.used - a_idx, b_idx + 1);

		for k := 0; k < terms; k += 1 {
			acc += _WORD(a.digit[a_idx + k]) * _WORD(b.digit[b_idx - k]);
		}

		/*
			Keep the low digit of the accumulator for this column,
			the rest carries over into the next one.
		*/
		columns[col] = DIGIT(acc) & DIGIT(_MASK);
		acc >>= _WORD(_DIGIT_BITS);
	}

	/*
		Copy the computed columns into the destination.
	*/
	previously_used := dest.used;
	dest.used = total;

	for col := digits; col < total; col += 1 {
		dest.digit[col] = columns[col];
	}

	/*
		Clear digits left over from the destination's old value,
		then trim leading zeroes.
	*/
	internal_zero_unused(dest, previously_used);
	return internal_clamp(dest);
}
/*
	Balance multiplication for operands of very different sizes.

	The smaller operand is treated as a single "digit": the larger operand is cut
	into slices of the smaller one's length, each slice is multiplied by the smaller
	operand, shifted into position and summed into the result.
*/
_private_int_mul_balance :: proc(dest, a, b: ^Int, allocator := context.allocator) -> (err: Error) {
	context.allocator = allocator;

	slice, product, sum := &Int{}, &Int{}, &Int{};
	defer internal_destroy(slice, product, sum);

	block_size := min(a.used, b.used);
	n_blocks   := max(a.used, b.used) / block_size;

	internal_grow(slice, block_size + 2) or_return;
	internal_init_multi(product, sum) or_return;

	/*
		`big` is the operand with more digits, `small` the other one.
	*/
	big, small := a, b;
	if a.used < b.used {
		big, small = b, a;
	}
	assert(big.used >= small.used);

	/*
		Full-sized slices.
	*/
	offset := 0;
	for block := 0; block < n_blocks; block += 1 {
		/*
			Cut the next slice off of `big`.
		*/
		slice.used = block_size;
		internal_copy_digits(slice, big, slice.used, offset);
		offset += slice.used;
		internal_clamp(slice);

		/*
			Multiply it with `small`, move the partial product to its position
			and add it to the running total. No carry needed.
		*/
		internal_mul(product, slice, small) or_return;
		internal_shl_digit(product, block_size * block) or_return;
		internal_add(sum, sum, product) or_return;
	}

	/*
		Whatever is left of `big` after the full-sized slices.
	*/
	if offset < big.used {
		slice.used = big.used - offset;
		internal_copy_digits(slice, big, slice.used, offset);
		internal_clamp(slice);

		internal_mul(product, slice, small) or_return;
		internal_shl_digit(product, block_size * n_blocks) or_return;
		internal_add(sum, sum, product) or_return;
	}

	internal_swap(dest, sum);
	return;
}
/*
Low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16
Assumes `dest` and `src` to not be `nil`, and `src` to have been initialized.
@@ -1188,7 +1377,7 @@ _private_int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int
ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{};
c: int;
defer destroy(ta, tb, tq, q);
defer internal_destroy(ta, tb, tq, q);
for {
internal_one(tq) or_return;
@@ -1241,31 +1430,34 @@ _private_int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int
Binary split factorial algo due to: http://www.luschny.de/math/factorial/binarysplitfact.html
*/
_private_int_factorial_binary_split :: proc(res: ^Int, n: int, allocator := context.allocator) -> (err: Error) {
context.allocator = allocator;
inner, outer, start, stop, temp := &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
defer internal_destroy(inner, outer, start, stop, temp);
internal_one(inner, false, allocator) or_return;
internal_one(outer, false, allocator) or_return;
internal_one(inner, false) or_return;
internal_one(outer, false) or_return;
bits_used := int(_DIGIT_TYPE_BITS - intrinsics.count_leading_zeros(n));
for i := bits_used; i >= 0; i -= 1 {
start := (n >> (uint(i) + 1)) + 1 | 1;
stop := (n >> uint(i)) + 1 | 1;
_private_int_recursive_product(temp, start, stop, 0, allocator) or_return;
internal_mul(inner, inner, temp, allocator) or_return;
internal_mul(outer, outer, inner, allocator) or_return;
_private_int_recursive_product(temp, start, stop, 0) or_return;
internal_mul(inner, inner, temp) or_return;
internal_mul(outer, outer, inner) or_return;
}
shift := n - intrinsics.count_ones(n);
return internal_shl(res, outer, int(shift), allocator);
return internal_shl(res, outer, int(shift));
}
/*
Recursive product used by binary split factorial algorithm.
*/
_private_int_recursive_product :: proc(res: ^Int, start, stop: int, level := int(0), allocator := context.allocator) -> (err: Error) {
context.allocator = allocator;
t1, t2 := &Int{}, &Int{};
defer internal_destroy(t1, t2);
@@ -1275,28 +1467,28 @@ _private_int_recursive_product :: proc(res: ^Int, start, stop: int, level := int
num_factors := (stop - start) >> 1;
if num_factors == 2 {
internal_set(t1, start, false, allocator) or_return;
internal_set(t1, start, false) or_return;
when true {
internal_grow(t2, t1.used + 1, false, allocator) or_return;
internal_add(t2, t1, 2, allocator) or_return;
internal_grow(t2, t1.used + 1, false) or_return;
internal_add(t2, t1, 2) or_return;
} else {
add(t2, t1, 2) or_return;
internal_add(t2, t1, 2) or_return;
}
return internal_mul(res, t1, t2, allocator);
return internal_mul(res, t1, t2);
}
if num_factors > 1 {
mid := (start + num_factors) | 1;
_private_int_recursive_product(t1, start, mid, level + 1, allocator) or_return;
_private_int_recursive_product(t2, mid, stop, level + 1, allocator) or_return;
return internal_mul(res, t1, t2, allocator);
_private_int_recursive_product(t1, start, mid, level + 1) or_return;
_private_int_recursive_product(t2, mid, stop, level + 1) or_return;
return internal_mul(res, t1, t2);
}
if num_factors == 1 {
return #force_inline internal_set(res, start, true, allocator);
return #force_inline internal_set(res, start, true);
}
return #force_inline internal_one(res, true, allocator);
return #force_inline internal_one(res, true);
}
/*

View File

@@ -403,14 +403,21 @@ def test_shr_signed(a = 0, bits = 0, expected_error = Error.Okay):
return test("test_shr_signed", res, [a, bits], expected_error, expected_result)
def test_factorial(n = 0, expected_error = Error.Okay):
args = [n]
res = int_factorial(*args)
def test_factorial(number = 0, expected_error = Error.Okay):
print("Factorial:", number)
args = [number]
try:
res = int_factorial(*args)
except OSError as e:
print("{} while trying to factorial {}.".format(e, number))
if EXIT_ON_FAIL: exit(3)
return False
expected_result = None
if expected_error == Error.Okay:
expected_result = math.factorial(n)
expected_result = math.factorial(number)
return test("test_factorial", res, [n], expected_error, expected_result)
return test("test_factorial", res, [number], expected_error, expected_result)
def test_gcd(a = 0, b = 0, expected_error = Error.Okay):
args = [arg_to_odin(a), arg_to_odin(b)]