big: Move division internals.

This commit is contained in:
Jeroen van Rijn
2021-08-07 16:52:04 +02:00
parent e288a563e1
commit 62dcccd7ef
4 changed files with 358 additions and 362 deletions

View File

@@ -290,351 +290,6 @@ int_choose_digit :: proc(res: ^Int, n, k: int) -> (err: Error) {
}
choose :: proc { int_choose_digit, };
/*
	Low level squaring, dest = src * src, HAC pp.596-597, Algorithm 14.16

	The product is accumulated in a temporary `Int`, which is clamped and then swapped
	into `dest`. Returns the error from growing the temporary, or the one from `clamp`.
*/
_int_sqr :: proc(dest, src: ^Int) -> (err: Error) {
	digits := src.used;
	needed := (2 * digits) + 1;

	/*
		Grow the temporary to the maximum needed size, or `_DEFAULT_DIGIT_COUNT`, whichever is bigger.
	*/
	tmp := &Int{};
	if err = grow(tmp, max(needed, _DEFAULT_DIGIT_COUNT)); err != nil { return err; }
	tmp.used = needed;

	#no_bounds_check for i := 0; i < digits; i += 1 {
		a := _WORD(src.digit[i]);

		/*
			Diagonal term: add src[i]^2 to column 2*i in double precision,
			keep the low digit and carry the rest.
		*/
		acc := _WORD(tmp.digit[2 * i]) + (a * a);
		tmp.digit[2 * i] = DIGIT(acc & _WORD(_MASK));
		carry := DIGIT(acc >> _DIGIT_BITS);

		/*
			Off-diagonal terms: every src[i] * src[j] with j > i contributes twice.
			Note that the product is added twice rather than multiplied by 2.
			`column` is the destination digit index, i + j.
		*/
		column := (2 * i) + 1;
		#no_bounds_check for j := i + 1; j < digits; j += 1 {
			product := a * _WORD(src.digit[j]);
			acc = _WORD(tmp.digit[column]) + product + product + _WORD(carry);
			tmp.digit[column] = DIGIT(acc & _WORD(_MASK));
			carry = DIGIT(acc >> _DIGIT_BITS);
			column += 1;
		}

		/*
			Propagate the remaining carry upwards.
		*/
		#no_bounds_check for carry != 0 {
			acc = _WORD(tmp.digit[column]) + _WORD(carry);
			tmp.digit[column] = DIGIT(acc & _WORD(_MASK));
			carry = DIGIT(acc >> _WORD(_DIGIT_BITS));
			column += 1;
		}
	}

	err = clamp(tmp);
	swap(dest, tmp);
	destroy(tmp);
	return err;
}
/*
	Divide by three (based on routine from MPI and the GMP manual).

	Returns `numerator mod 3` (by magnitude) as `remainder`.
	If `quotient` is not `nil`, it receives the quotient, carrying the sign of `numerator`.
*/
_int_div_3 :: proc(quotient, numerator: ^Int) -> (remainder: DIGIT, err: Error) {
	/*
		b = 2^_DIGIT_BITS / 3, the scaled reciprocal of three.
	*/
	b := (_WORD(1) << _WORD(_DIGIT_BITS)) / _WORD(3);

	q := &Int{};
	if err = grow(q, numerator.used); err != nil { return 0, err; }
	defer destroy(q);

	q.used = numerator.used;
	q.sign = numerator.sign;

	w, t: _WORD;
	/*
		Walk from the most significant live digit, index `used - 1`, down to digit 0.
		Index `used` itself is not part of the number and must not be touched.
	*/
	for ix := numerator.used - 1; ix >= 0; ix -= 1 {
		w = (w << _WORD(_DIGIT_BITS)) | _WORD(numerator.digit[ix]);

		if w >= 3 {
			/*
				Multiply w by [1/3].
			*/
			t = (w * b) >> _WORD(_DIGIT_BITS);

			/*
				Now subtract 3 * [w/3] from w, to get the remainder.
			*/
			w -= t+t+t;

			/*
				Fix up the remainder as required since the optimization is not exact.
			*/
			for w >= 3 {
				t += 1;
				w -= 3;
			}
		} else {
			t = 0;
		}
		q.digit[ix] = DIGIT(t);
	}
	remainder = DIGIT(w);

	/*
		[optional] store the quotient. A failure to clamp is reported to the caller.
	*/
	if quotient != nil {
		err = clamp(q);
		swap(q, quotient);
	}
	return remainder, err;
}
/*
	Signed Integer Division

	quotient * denominator + remainder == numerator, HAC pp.598 Algorithm 14.20

	Note that the description in HAC is horribly incomplete.
	For example, it doesn't consider the case where digits are removed from 'x' in
	the inner loop.

	It also doesn't consider the case that y has fewer than three digits, etc.
	The overall algorithm is as described as 14.20 from HAC but fixed to treat these cases.

	`quotient` and `remainder` are both optional and may be `nil`.
	Returns `.Max_Iterations_Reached` if a quotient digit estimate needs more than 100 corrections.
*/
_int_div_school :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) {
	if err = error_if_immutable(quotient, remainder); err != nil { return err; }
	if err = clear_if_uninitialized(quotient, numerator, denominator); err != nil { return err; }

	q, x, y, t1, t2 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
	defer destroy(q, x, y, t1, t2);

	if err = grow(q, numerator.used + 2); err != nil { return err; }
	q.used = numerator.used + 2;

	if err = init_multi(t1, t2); err != nil { return err; }
	if err = copy(x, numerator); err != nil { return err; }
	if err = copy(y, denominator); err != nil { return err; }

	/*
		Fix the sign: remember whether the quotient is negative, then work on magnitudes.
	*/
	neg := numerator.sign != denominator.sign;
	x.sign = .Zero_or_Positive;
	y.sign = .Zero_or_Positive;

	/*
		Normalize both x and y, ensure that y >= b/2, [b == 2**_DIGIT_BITS]
		`norm` is the shift applied; the remainder is shifted back by it at the end.
	*/
	norm, _ := count_bits(y);
	norm %= _DIGIT_BITS;

	if norm < _DIGIT_BITS - 1 {
		norm = (_DIGIT_BITS - 1) - norm;
		if err = shl(x, x, norm); err != nil { return err; }
		if err = shl(y, y, norm); err != nil { return err; }
	} else {
		norm = 0;
	}

	/*
		Note: HAC does 0 based, so if used==5 then it's 0,1,2,3,4, i.e. use 4
	*/
	n := x.used - 1;
	t := y.used - 1;

	/*
		while (x >= y*b**{n-t}) do { q[n-t] += 1; x -= y*b**{n-t} }
		y = y*b**{n-t}
	*/
	if err = shl_digit(y, n - t); err != nil { return err; }

	c, _ := cmp(x, y);
	for c != -1 {
		q.digit[n - t] += 1;
		if err = sub(x, x, y); err != nil { return err; }
		c, _ = cmp(x, y);
	}

	/*
		Reset y by shifting it back down.
	*/
	shr_digit(y, n - t);

	/*
		Step 3. for i from n down to (t + 1).
	*/
	for i := n; i >= (t + 1); i -= 1 {
		if (i > x.used) { continue; }

		/*
			Index of the quotient digit being computed in this round.
		*/
		qi := (i - t) - 1;

		/*
			Step 3.1 if xi == yt then set q{i-t-1} to b-1, otherwise set q{i-t-1} to (xi*b + x{i-1})/yt
			b - 1 is `_MASK`. The estimate may only be too high, as the loop below can only lower it.
		*/
		if x.digit[i] == y.digit[t] {
			q.digit[qi] = _MASK;
		} else {
			tmp := _WORD(x.digit[i]) << _DIGIT_BITS;
			tmp |= _WORD(x.digit[i - 1]);
			tmp /= _WORD(y.digit[t]);
			if tmp > _WORD(_MASK) {
				tmp = _WORD(_MASK);
			}
			q.digit[qi] = DIGIT(tmp & _WORD(_MASK));
		}

		/*
			Step 3.2
			while (q{i-t-1} * (yt * b + y{t-1})) > xi * b**2 + xi-1 * b + xi-2
				do q{i-t-1} -= 1;

			The estimate is bumped by one first because the loop decrements before it tests.
		*/
		iter := 0;

		q.digit[qi] = (q.digit[qi] + 1) & _MASK;
		for {
			q.digit[qi] = (q.digit[qi] - 1) & _MASK;

			/*
				Find left hand: the top two digits of y, times the estimate.
			*/
			zero(t1);
			t1.digit[0] = ((t - 1) < 0) ? 0 : y.digit[t - 1];
			t1.digit[1] = y.digit[t];
			t1.used = 2;
			if err = mul(t1, t1, q.digit[qi]); err != nil { return err; }

			/*
				Find right hand: the top three digits of x.
			*/
			t2.digit[0] = ((i - 2) < 0) ? 0 : x.digit[i - 2];
			t2.digit[1] = x.digit[i - 1]; /* i >= 1 always holds */
			t2.digit[2] = x.digit[i];
			t2.used = 3;

			if t1_t2, _ := cmp_mag(t1, t2); t1_t2 != 1 {
				break;
			}
			iter += 1; if iter > 100 { return .Max_Iterations_Reached; }
		}

		/*
			Step 3.3 x = x - q{i-t-1} * y * b**{i-t-1}
		*/
		if err = int_mul_digit(t1, y, q.digit[qi]); err != nil { return err; }
		if err = shl_digit(t1, qi); err != nil { return err; }
		if err = sub(x, x, t1); err != nil { return err; }

		/*
			if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; }
		*/
		if x.sign == .Negative {
			if err = copy(t1, y); err != nil { return err; }
			if err = shl_digit(t1, qi); err != nil { return err; }
			if err = add(x, x, t1); err != nil { return err; }

			q.digit[qi] = (q.digit[qi] - 1) & _MASK;
		}
	}

	/*
		Now q is the quotient and x is the remainder, [which we have to normalize]
		Get the remainder's sign before writing to the outputs.
	*/
	z, _ := is_zero(x);
	x.sign = .Zero_or_Positive if z else numerator.sign;

	if quotient != nil {
		if err = clamp(q); err != nil { return err; }
		swap(q, quotient);
		quotient.sign = .Negative if neg else .Zero_or_Positive;
	}

	if remainder != nil {
		/*
			Undo the normalization shift.
		*/
		if err = shr(x, x, norm); err != nil { return err; }
		swap(x, remainder);
	}

	return nil;
}
/*
	Slower bit-bang division... also smaller.

	Shift-and-subtract division on magnitudes. `quotient` and `remainder` are both
	optional and may be `nil`. A non-zero quotient is negative when the operand signs
	differ; a non-zero remainder takes the sign of `numerator`.
*/
@(deprecated="Use `_int_div_school`, it's 3.5x faster.")
_int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) {
	ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{};
	defer destroy(ta, tb, tq, q);

	c: int;

	if err = one(tq); err != nil { return err; }

	num_bits, _ := count_bits(numerator);
	den_bits, _ := count_bits(denominator);
	n := num_bits - den_bits;

	/*
		ta = |numerator|, tb = |denominator| << n, tq = 1 << n
	*/
	if err = abs(ta, numerator);   err != nil { return err; }
	if err = abs(tb, denominator); err != nil { return err; }
	if err = shl(tb, tb, n);       err != nil { return err; }
	if err = shl(tq, tq, n);       err != nil { return err; }

	for n >= 0 {
		if c, _ = cmp_mag(ta, tb); c == 0 || c == 1 {
			// ta -= tb
			if err = sub(ta, ta, tb); err != nil { return err; }
			//  q += tq
			if err = add( q, q,  tq); err != nil { return err; }
		}
		if err = shr1(tb, tb); err != nil { return err; }
		if err = shr1(tq, tq); err != nil { return err; }

		n -= 1;
	}

	/*
		Now q == quotient and ta == remainder.
		Capture the signs first, in case an output aliases `numerator` and the swap overwrites it.
	*/
	neg      := numerator.sign != denominator.sign;
	num_sign := numerator.sign;

	if quotient != nil {
		swap(quotient, q);
		z, _ := is_zero(quotient);
		quotient.sign = .Negative if neg && !z else .Zero_or_Positive;
	}
	if remainder != nil {
		swap(remainder, ta);
		/*
			Test the remainder itself for zero, so an exact division never yields a negative zero.
		*/
		z, _ := is_zero(remainder);
		remainder.sign = .Zero_or_Positive if z else num_sign;
	}
	return err;
}
/*
Function computing both GCD and (if target isn't `nil`) also LCM.
*/

View File

@@ -1,10 +1,10 @@
@echo off
odin run . -vet
:odin run . -vet
: -o:size
:odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check
:odin build . -build-mode:shared -show-timings -o:size -no-bounds-check
:odin build . -build-mode:shared -show-timings -o:size
:odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check
odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check
:odin build . -build-mode:shared -show-timings -o:speed
:python test.py
python test.py

View File

@@ -608,7 +608,7 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
/* Fast comba? */
// err = s_mp_sqr_comba(a, c);
} else {
err = _int_sqr(dest, src);
err = _private_int_sqr(dest, src);
}
} else {
/*
@@ -680,14 +680,13 @@ internal_int_divmod :: proc(quotient, remainder, numerator, denominator: ^Int, a
// err = _int_div_recursive(quotient, remainder, numerator, denominator);
} else {
when true {
err = _int_div_school(quotient, remainder, numerator, denominator);
err = _private_int_div_school(quotient, remainder, numerator, denominator);
} else {
/*
NOTE(Jeroen): We no longer need or use `_int_div_small`.
NOTE(Jeroen): We no longer need or use `_private_int_div_small`.
We'll keep it around for a bit until we're reasonably certain div_school is bug free.
err = _int_div_small(quotient, remainder, numerator, denominator);
*/
err = _int_div_small(quotient, remainder, numerator, denominator);
err = _private_int_div_small(quotient, remainder, numerator, denominator);
}
}
return;
@@ -744,7 +743,7 @@ internal_int_divmod_digit :: proc(quotient, numerator: ^Int, denominator: DIGIT)
Three?
*/
if denominator == 3 {
return _int_div_3(quotient, numerator);
return _private_int_div_3(quotient, numerator);
}
/*
@@ -1049,8 +1048,6 @@ _private_int_mul_comba :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
/*
Now extract the previous digit [below the carry].
*/
// for ix = 0; ix < pa; ix += 1 { dest.digit[ix] = W[ix]; }
copy_slice(dest.digit[0:], W[:pa]);
/*
@@ -1065,6 +1062,350 @@ _private_int_mul_comba :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
return clamp(dest);
}
/*
	Low level squaring, dest = src * src, HAC pp.596-597, Algorithm 14.16

	The product is accumulated in a temporary `Int`, which is clamped and then swapped
	into `dest`. Returns the error from growing the temporary, or the one from `clamp`.
*/
_private_int_sqr :: proc(dest, src: ^Int) -> (err: Error) {
	digits := src.used;
	needed := (2 * digits) + 1;

	/*
		Grow the temporary to the maximum needed size, or `_DEFAULT_DIGIT_COUNT`, whichever is bigger.
	*/
	tmp := &Int{};
	if err = grow(tmp, max(needed, _DEFAULT_DIGIT_COUNT)); err != nil { return err; }
	tmp.used = needed;

	#no_bounds_check for i := 0; i < digits; i += 1 {
		a := _WORD(src.digit[i]);

		/*
			Diagonal term: add src[i]^2 to column 2*i in double precision,
			keep the low digit and carry the rest.
		*/
		acc := _WORD(tmp.digit[2 * i]) + (a * a);
		tmp.digit[2 * i] = DIGIT(acc & _WORD(_MASK));
		carry := DIGIT(acc >> _DIGIT_BITS);

		/*
			Off-diagonal terms: every src[i] * src[j] with j > i contributes twice.
			Note that the product is added twice rather than multiplied by 2.
			`column` is the destination digit index, i + j.
		*/
		column := (2 * i) + 1;
		#no_bounds_check for j := i + 1; j < digits; j += 1 {
			product := a * _WORD(src.digit[j]);
			acc = _WORD(tmp.digit[column]) + product + product + _WORD(carry);
			tmp.digit[column] = DIGIT(acc & _WORD(_MASK));
			carry = DIGIT(acc >> _DIGIT_BITS);
			column += 1;
		}

		/*
			Propagate the remaining carry upwards.
		*/
		#no_bounds_check for carry != 0 {
			acc = _WORD(tmp.digit[column]) + _WORD(carry);
			tmp.digit[column] = DIGIT(acc & _WORD(_MASK));
			carry = DIGIT(acc >> _WORD(_DIGIT_BITS));
			column += 1;
		}
	}

	err = clamp(tmp);
	swap(dest, tmp);
	destroy(tmp);
	return err;
}
/*
	Divide by three (based on routine from MPI and the GMP manual).

	Returns `numerator mod 3` (by magnitude) as `remainder`.
	If `quotient` is not `nil`, it receives the quotient, carrying the sign of `numerator`.
*/
_private_int_div_3 :: proc(quotient, numerator: ^Int) -> (remainder: DIGIT, err: Error) {
	/*
		b = 2^_DIGIT_BITS / 3, the scaled reciprocal of three.
	*/
	b := (_WORD(1) << _WORD(_DIGIT_BITS)) / _WORD(3);

	q := &Int{};
	if err = grow(q, numerator.used); err != nil { return 0, err; }
	defer destroy(q);

	q.used = numerator.used;
	q.sign = numerator.sign;

	w, t: _WORD;
	/*
		Walk from the most significant live digit, index `used - 1`, down to digit 0.
		Bounds checks are off here, so starting at `used` would silently read and write
		one digit past the number.
	*/
	#no_bounds_check for ix := numerator.used - 1; ix >= 0; ix -= 1 {
		w = (w << _WORD(_DIGIT_BITS)) | _WORD(numerator.digit[ix]);

		if w >= 3 {
			/*
				Multiply w by [1/3].
			*/
			t = (w * b) >> _WORD(_DIGIT_BITS);

			/*
				Now subtract 3 * [w/3] from w, to get the remainder.
			*/
			w -= t+t+t;

			/*
				Fix up the remainder as required since the optimization is not exact.
			*/
			for w >= 3 {
				t += 1;
				w -= 3;
			}
		} else {
			t = 0;
		}
		q.digit[ix] = DIGIT(t);
	}
	remainder = DIGIT(w);

	/*
		[optional] store the quotient. A failure to clamp is reported to the caller.
	*/
	if quotient != nil {
		err = clamp(q);
		swap(q, quotient);
	}
	return remainder, err;
}
/*
	Signed Integer Division

	quotient * denominator + remainder == numerator, HAC pp.598 Algorithm 14.20

	Note that the description in HAC is horribly incomplete.
	For example, it doesn't consider the case where digits are removed from 'x' in
	the inner loop.

	It also doesn't consider the case that y has fewer than three digits, etc.
	The overall algorithm is as described as 14.20 from HAC but fixed to treat these cases.

	`quotient` and `remainder` are both optional and may be `nil`.
	Returns `.Max_Iterations_Reached` if a quotient digit estimate needs more than 100 corrections.

	NOTE(review): the immutability and initialization checks below are disabled;
	presumably the caller performs them - confirm.
*/
_private_int_div_school :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) {
	// if err = error_if_immutable(quotient, remainder); err != nil { return err; }
	// if err = clear_if_uninitialized(quotient, numerator, denominator); err != nil { return err; }

	q, x, y, t1, t2 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
	defer destroy(q, x, y, t1, t2);

	if err = grow(q, numerator.used + 2); err != nil { return err; }
	q.used = numerator.used + 2;

	if err = init_multi(t1, t2); err != nil { return err; }
	if err = copy(x, numerator); err != nil { return err; }
	if err = copy(y, denominator); err != nil { return err; }

	/*
		Fix the sign: remember whether the quotient is negative, then work on magnitudes.
	*/
	neg := numerator.sign != denominator.sign;
	x.sign = .Zero_or_Positive;
	y.sign = .Zero_or_Positive;

	/*
		Normalize both x and y, ensure that y >= b/2, [b == 2**_DIGIT_BITS]
		`norm` is the shift applied; the remainder is shifted back by it at the end.
	*/
	norm, _ := count_bits(y);
	norm %= _DIGIT_BITS;

	if norm < _DIGIT_BITS - 1 {
		norm = (_DIGIT_BITS - 1) - norm;
		if err = shl(x, x, norm); err != nil { return err; }
		if err = shl(y, y, norm); err != nil { return err; }
	} else {
		norm = 0;
	}

	/*
		Note: HAC does 0 based, so if used==5 then it's 0,1,2,3,4, i.e. use 4
	*/
	n := x.used - 1;
	t := y.used - 1;

	/*
		while (x >= y*b**{n-t}) do { q[n-t] += 1; x -= y*b**{n-t} }
		y = y*b**{n-t}
	*/
	if err = shl_digit(y, n - t); err != nil { return err; }

	c, _ := cmp(x, y);
	for c != -1 {
		q.digit[n - t] += 1;
		if err = sub(x, x, y); err != nil { return err; }
		c, _ = cmp(x, y);
	}

	/*
		Reset y by shifting it back down.
	*/
	shr_digit(y, n - t);

	/*
		Step 3. for i from n down to (t + 1).
	*/
	#no_bounds_check for i := n; i >= (t + 1); i -= 1 {
		if (i > x.used) { continue; }

		/*
			Index of the quotient digit being computed in this round.
		*/
		qi := (i - t) - 1;

		/*
			Step 3.1 if xi == yt then set q{i-t-1} to b-1, otherwise set q{i-t-1} to (xi*b + x{i-1})/yt
			b - 1 is `_MASK`. The estimate may only be too high, as the loop below can only lower it.
		*/
		if x.digit[i] == y.digit[t] {
			q.digit[qi] = _MASK;
		} else {
			tmp := _WORD(x.digit[i]) << _DIGIT_BITS;
			tmp |= _WORD(x.digit[i - 1]);
			tmp /= _WORD(y.digit[t]);
			if tmp > _WORD(_MASK) {
				tmp = _WORD(_MASK);
			}
			q.digit[qi] = DIGIT(tmp & _WORD(_MASK));
		}

		/*
			Step 3.2
			while (q{i-t-1} * (yt * b + y{t-1})) > xi * b**2 + xi-1 * b + xi-2
				do q{i-t-1} -= 1;

			The estimate is bumped by one first because the loop decrements before it tests.
		*/
		iter := 0;

		q.digit[qi] = (q.digit[qi] + 1) & _MASK;
		#no_bounds_check for {
			q.digit[qi] = (q.digit[qi] - 1) & _MASK;

			/*
				Find left hand: the top two digits of y, times the estimate.
			*/
			zero(t1);
			t1.digit[0] = ((t - 1) < 0) ? 0 : y.digit[t - 1];
			t1.digit[1] = y.digit[t];
			t1.used = 2;
			if err = mul(t1, t1, q.digit[qi]); err != nil { return err; }

			/*
				Find right hand: the top three digits of x.
			*/
			t2.digit[0] = ((i - 2) < 0) ? 0 : x.digit[i - 2];
			t2.digit[1] = x.digit[i - 1]; /* i >= 1 always holds */
			t2.digit[2] = x.digit[i];
			t2.used = 3;

			if t1_t2, _ := cmp_mag(t1, t2); t1_t2 != 1 {
				break;
			}
			iter += 1; if iter > 100 { return .Max_Iterations_Reached; }
		}

		/*
			Step 3.3 x = x - q{i-t-1} * y * b**{i-t-1}
		*/
		if err = int_mul_digit(t1, y, q.digit[qi]); err != nil { return err; }
		if err = shl_digit(t1, qi); err != nil { return err; }
		if err = sub(x, x, t1); err != nil { return err; }

		/*
			if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; }
		*/
		if x.sign == .Negative {
			if err = copy(t1, y); err != nil { return err; }
			if err = shl_digit(t1, qi); err != nil { return err; }
			if err = add(x, x, t1); err != nil { return err; }

			q.digit[qi] = (q.digit[qi] - 1) & _MASK;
		}
	}

	/*
		Now q is the quotient and x is the remainder, [which we have to normalize]
		Get the remainder's sign before writing to the outputs.
	*/
	z, _ := is_zero(x);
	x.sign = .Zero_or_Positive if z else numerator.sign;

	if quotient != nil {
		if err = clamp(q); err != nil { return err; }
		swap(q, quotient);
		quotient.sign = .Negative if neg else .Zero_or_Positive;
	}

	if remainder != nil {
		/*
			Undo the normalization shift.
		*/
		if err = shr(x, x, norm); err != nil { return err; }
		swap(x, remainder);
	}

	return nil;
}
/*
	Slower bit-bang division... also smaller.

	Shift-and-subtract division on magnitudes. `quotient` and `remainder` are both
	optional and may be `nil`. A non-zero quotient is negative when the operand signs
	differ; a non-zero remainder takes the sign of `numerator`.
*/
@(deprecated="Use `_private_int_div_school`, it's 3.5x faster.")
_private_int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) {
	ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{};
	defer destroy(ta, tb, tq, q);

	c: int;

	if err = one(tq); err != nil { return err; }

	num_bits, _ := count_bits(numerator);
	den_bits, _ := count_bits(denominator);
	n := num_bits - den_bits;

	/*
		ta = |numerator|, tb = |denominator| << n, tq = 1 << n
	*/
	if err = abs(ta, numerator);   err != nil { return err; }
	if err = abs(tb, denominator); err != nil { return err; }
	if err = shl(tb, tb, n);       err != nil { return err; }
	if err = shl(tq, tq, n);       err != nil { return err; }

	for n >= 0 {
		if c, _ = cmp_mag(ta, tb); c == 0 || c == 1 {
			// ta -= tb
			if err = sub(ta, ta, tb); err != nil { return err; }
			//  q += tq
			if err = add( q, q,  tq); err != nil { return err; }
		}
		if err = shr1(tb, tb); err != nil { return err; }
		if err = shr1(tq, tq); err != nil { return err; }

		n -= 1;
	}

	/*
		Now q == quotient and ta == remainder.
		Capture the signs first, in case an output aliases `numerator` and the swap overwrites it.
	*/
	neg      := numerator.sign != denominator.sign;
	num_sign := numerator.sign;

	if quotient != nil {
		swap(quotient, q);
		z, _ := is_zero(quotient);
		quotient.sign = .Negative if neg && !z else .Zero_or_Positive;
	}
	if remainder != nil {
		swap(remainder, ta);
		/*
			Test the remainder itself for zero, so an exact division never yields a negative zero.
		*/
		z, _ := is_zero(remainder);
		remainder.sign = .Zero_or_Positive if z else num_sign;
	}
	return err;
}
/*

View File

@@ -17,7 +17,7 @@ EXIT_ON_FAIL = False
# We skip randomized tests altogether if NO_RANDOM_TESTS is set.
#
NO_RANDOM_TESTS = True
#NO_RANDOM_TESTS = False
NO_RANDOM_TESTS = False
#
# If TIMED_TESTS == False and FAST_TESTS == True, we cut down the number of iterations.
@@ -96,11 +96,11 @@ class Error(Enum):
# Set up exported procedures
#
# try:
l = cdll.LoadLibrary(LIB_PATH)
# except:
# print("Couldn't find or load " + LIB_PATH + ".")
# exit(1)
try:
l = cdll.LoadLibrary(LIB_PATH)
except:
print("Couldn't find or load " + LIB_PATH + ".")
exit(1)
def load(export_name, args, res):
export_name.argtypes = args