big: Squashed shl1 bug when a larger dest was reused for a smaller result.

This commit is contained in:
Jeroen van Rijn
2021-08-06 14:57:53 +02:00
parent f8442e0524
commit 4be48973ad
7 changed files with 113 additions and 102 deletions

View File

@@ -152,45 +152,18 @@ int_divmod :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: E
/*
Early out if neither of the results is wanted.
*/
if quotient == nil && remainder == nil { return nil; }
if quotient == nil && remainder == nil { return nil; }
if err = clear_if_uninitialized(numerator, denominator); err != nil { return err; }
if err = clear_if_uninitialized(numerator); err != nil { return err; }
if err = clear_if_uninitialized(denominator); err != nil { return err; }
z: bool;
if z, err = is_zero(denominator); z { return .Division_by_Zero; }
/*
If numerator < denominator then quotient = 0, remainder = numerator.
*/
c: int;
if c, err = cmp_mag(numerator, denominator); c == -1 {
if remainder != nil {
if err = copy(remainder, numerator); err != nil { return err; }
}
if quotient != nil {
zero(quotient);
}
return nil;
}
if false && (denominator.used > 2 * _MUL_KARATSUBA_CUTOFF) && (denominator.used <= (numerator.used/3) * 2) {
// err = _int_div_recursive(quotient, remainder, numerator, denominator);
} else {
err = _int_div_school(quotient, remainder, numerator, denominator);
/*
NOTE(Jeroen): We no longer need or use `_int_div_small`.
We'll keep it around for a bit.
err = _int_div_small(quotient, remainder, numerator, denominator);
*/
}
return err;
return #force_inline internal_int_divmod(quotient, remainder, numerator, denominator);
}
divmod :: proc{ int_divmod, };
int_div :: proc(quotient, numerator, denominator: ^Int) -> (err: Error) {
return int_divmod(quotient, nil, numerator, denominator);
if quotient == nil { return .Invalid_Pointer; };
if err = clear_if_uninitialized(numerator, denominator); err != nil { return err; }
return #force_inline internal_int_divmod(quotient, nil, numerator, denominator);
}
div :: proc { int_div, };
@@ -200,11 +173,10 @@ div :: proc { int_div, };
denominator < remainder <= 0 if denominator < 0
*/
int_mod :: proc(remainder, numerator, denominator: ^Int) -> (err: Error) {
if err = divmod(nil, remainder, numerator, denominator); err != nil { return err; }
if remainder == nil { return .Invalid_Pointer; };
if err = clear_if_uninitialized(numerator, denominator); err != nil { return err; }
z: bool;
if z, err = is_zero(remainder); z || denominator.sign == remainder.sign { return nil; }
return add(remainder, remainder, numerator);
return #force_inline internal_int_mod(remainder, numerator, denominator);
}
int_mod_digit :: proc(numerator: ^Int, denominator: DIGIT) -> (remainder: DIGIT, err: Error) {
@@ -776,7 +748,6 @@ _int_div_school :: proc(quotient, remainder, numerator, denominator: ^Int) -> (e
t2.used = 3;
if t1_t2, _ := cmp_mag(t1, t2); t1_t2 != 1 {
break;
}
iter += 1; if iter > 100 { return .Max_Iterations_Reached; }

View File

@@ -1,5 +1,5 @@
@echo off
:odin run . -vet-more
:odin run . -vet
: -o:size -no-bounds-check
:odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check
:odin build . -build-mode:shared -show-timings -o:size -no-bounds-check

View File

@@ -62,30 +62,16 @@ print :: proc(name: string, a: ^Int, base := i8(10), print_name := true, newline
}
demo :: proc() {
err: Error;
as: string;
defer delete(as);
a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
defer destroy(a, b, c, d, e, f);
err = factorial(a, 1224);
count, _ := count_bits(a);
foo := "686885735734829009541949746871140768343076607029752932751182108475420900392874228486622313727012705619148037570309621219533087263900443932890792804879473795673302686046941536636874184361869252299636701671980034458333859202703255467709267777184095435235980845369829397344182319113372092844648570818726316581751114346501124871729572474923695509057166373026411194094493240101036672016770945150422252961487398124677567028263059046193391737576836378376192651849283925197438927999526058932679219572030021792914065825542626400207956134072247020690107136531852625253942429167557531123651471221455967386267137846791963149859804549891438562641323068751514370656287452006867713758971418043865298618635213551059471668293725548570452377976322899027050925842868079489675596835389444833567439058609775325447891875359487104691935576723532407937236505941186660707032433807075470656782452889754501872408562496805517394619388777930253411467941214807849472083814447498068636264021405175653742244368865090604940094889189800007448083930490871954101880815781177612910234741529950538835837693870921008635195545246771593130784786737543736434086434015200264933536294884482218945403958647118802574342840790536176272341586020230110889699633073513016344826709214";
err := atoi(a, foo, 10);
bits := 51;
be1: _WORD;
if err != nil do fmt.printf("atoi returned %v\n", err);
print("foo: ", a);
/*
Timing loop
*/
{
SCOPED_TIMING(.bitfield_extract);
for o := 0; o < count - bits; o += 1 {
be1, _ = int_bitfield_extract(a, o, bits);
}
}
SCOPED_COUNT_ADD(.bitfield_extract, count - bits - 1);
fmt.printf("be1: %v\n", be1);
}
main :: proc() {

View File

@@ -253,7 +253,7 @@ int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
swap(dest, x);
return err;
} else {
// return root_n(dest, src, 2);
return root_n(dest, src, 2);
}
}
sqrt :: proc { int_sqrt, };

View File

@@ -491,49 +491,30 @@ internal_int_shr1 :: proc(dest, src: ^Int) -> (err: Error) {
dest = src << 1
*/
internal_int_shl1 :: proc(dest, src: ^Int) -> (err: Error) {
old_used := dest.used; dest.used = src.used + 1;
if err = copy(dest, src); err != nil { return err; }
/*
Forward carry
Grow `dest` to accommodate the additional bits.
*/
digits_needed := dest.used + 1;
if err = grow(dest, digits_needed); err != nil { return err; }
dest.used = digits_needed;
mask := (DIGIT(1) << uint(1)) - DIGIT(1);
shift := DIGIT(_DIGIT_BITS - 1);
carry := DIGIT(0);
#no_bounds_check for x := 0; x < src.used; x += 1 {
/*
Get what will be the *next* carry bit from the MSB of the current digit.
*/
src_digit := src.digit[x];
fwd_carry := src_digit >> (_DIGIT_BITS - 1);
/*
Now shift up this digit, add in the carry [from the previous]
*/
dest.digit[x] = (src_digit << 1 | carry) & _MASK;
/*
Update carry
*/
#no_bounds_check for x:= 0; x < dest.used; x+= 1 {
fwd_carry := (dest.digit[x] >> shift) & mask;
dest.digit[x] = (dest.digit[x] << uint(1) | carry) & _MASK;
carry = fwd_carry;
}
/*
New leading digit?
Use final carry.
*/
if carry != 0 {
/*
Add a MSB which is always 1 at this point.
*/
dest.digit[dest.used] = 1;
dest.digit[dest.used] = carry;
dest.used += 1;
}
zero_count := old_used - dest.used;
/*
Zero remainder.
*/
if zero_count > 0 {
mem.zero_slice(dest.digit[dest.used:][:zero_count]);
}
/*
Adjust dest.used based on leading zeroes.
*/
dest.sign = src.sign;
return clamp(dest);
}
@@ -552,7 +533,7 @@ internal_int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT, allocator :=
Power of two?
*/
if multiplier == 2 {
return #force_inline shl1(dest, src);
return #force_inline internal_int_shl1(dest, src);
}
if is_power_of_two(int(multiplier)) {
ix: int;
@@ -581,7 +562,7 @@ internal_int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT, allocator :=
Compute columns.
*/
ix := 0;
#no_bounds_check for ; ix < src.used; ix += 1 {
for ; ix < src.used; ix += 1 {
/*
Compute product and carry sum for this term
*/
@@ -600,13 +581,15 @@ internal_int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT, allocator :=
Store final carry [if any] and increment used.
*/
dest.digit[ix] = DIGIT(carry);
dest.used = src.used + 1;
/*
Zero unused digits.
*/
//_zero_unused(dest);
zero_count := old_used - dest.used;
if zero_count > 0 {
mem.zero_slice(dest.digit[zero_count:]);
if zero_count > 0 {
mem.zero_slice(dest.digit[dest.used:][:zero_count]);
}
return clamp(dest);
}
@@ -675,9 +658,72 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
err = _int_mul(dest, src, multiplier, digits);
}
}
neg := src.sign != multiplier.sign;
neg := src.sign != multiplier.sign;
dest.sign = .Negative if dest.used > 0 && neg else .Zero_or_Positive;
return err;
}
internal_mul :: proc { internal_int_mul, internal_int_mul_digit, };
internal_mul :: proc { internal_int_mul, internal_int_mul_digit, };
/*
	divmod.
	Computes `quotient` and `remainder` for `numerator` / `denominator`.
	Both the quotient and remainder are optional and may be passed a nil.

	Returns `.Division_by_Zero` when `denominator.used == 0`, otherwise any error reported
	by `cmp_mag`, `copy` or the division routine.

	`denominator` is dereferenced without a nil check.
	`allocator` is only forwarded to `copy` in the early-out path.
*/
internal_int_divmod :: proc(quotient, remainder, numerator, denominator: ^Int, allocator := context.allocator) -> (err: Error) {
	if denominator.used == 0 { return .Division_by_Zero; }

	/*
		Compare magnitudes first. A failed comparison is reported to the caller
		instead of being dropped or overwritten by the division below.
	*/
	c: int;
	if c, err = #force_inline cmp_mag(numerator, denominator); err != nil { return err; }

	/*
		If numerator < denominator then quotient = 0, remainder = numerator.
	*/
	if c == -1 {
		if remainder != nil {
			if err = copy(remainder, numerator, false, allocator); err != nil { return err; }
		}
		if quotient != nil {
			zero(quotient);
		}
		return nil;
	}

	/*
		The recursive division path is switched off by the leading `false &&`;
		every call currently takes the schoolbook branch.
	*/
	if false && (denominator.used > 2 * _MUL_KARATSUBA_CUTOFF) && (denominator.used <= (numerator.used/3) * 2) {
		// err = _int_div_recursive(quotient, remainder, numerator, denominator);
	} else {
		when true {
			err = _int_div_school(quotient, remainder, numerator, denominator);
		} else {
			/*
				NOTE(Jeroen): We no longer need or use `_int_div_small`.
				We'll keep it around for a bit until we're reasonably certain div_school is bug free.
			*/
			err = _int_div_small(quotient, remainder, numerator, denominator);
		}
	}
	return;
}
internal_divmod :: proc { internal_int_divmod, };
/*
	quotient = numerator / denominator, discarding the remainder.

	Assumes quotient, numerator and denominator to have been initialized and not to be nil.
*/
internal_int_div :: proc(quotient, numerator, denominator: ^Int) -> (err: Error) {
	/*
		A nil remainder tells divmod that only the quotient is wanted.
	*/
	err = #force_inline internal_int_divmod(quotient, nil, numerator, denominator);
	return;
}
internal_div :: proc { internal_int_div, };
/*
	remainder = numerator % denominator.
	0 <= remainder < denominator if denominator > 0
	denominator < remainder <= 0 if denominator < 0

	Assumes remainder, numerator and denominator to have been initialized and not to be nil.
*/
internal_int_mod :: proc(remainder, numerator, denominator: ^Int) -> (err: Error) {
	if err = #force_inline internal_int_divmod(nil, remainder, numerator, denominator); err != nil { return err; }

	/*
		A zero remainder, or one that already carries the denominator's sign, is in range.
	*/
	if remainder.used == 0 || denominator.sign == remainder.sign { return nil; }

	/*
		Otherwise move it into range by adding the denominator.
		Example: -7 mod 3 leaves -1 after division, and -1 + 3 = 2.
	*/
	return #force_inline internal_add(remainder, remainder, denominator);
}
internal_mod :: proc{ internal_int_mod, };

View File

@@ -24,9 +24,9 @@ PyRes :: struct {
err: Error,
}
@export test_initialize_constants :: proc "c" () -> (res: int) {
@export test_initialize_constants :: proc "c" () -> (res: u64) {
context = runtime.default_context();
return initialize_constants();
return u64(initialize_constants());
}
@export test_error_string :: proc "c" (err: Error) -> (res: cstring) {

View File

@@ -254,7 +254,13 @@ def test_pow(base = 0, power = 0, expected_error = Error.Okay):
def test_sqrt(number = 0, expected_error = Error.Okay):
args = [arg_to_odin(number)]
res = int_sqrt(*args)
try:
res = int_sqrt(*args)
except OSError as e:
print("{} while trying to sqrt {}.".format(e, number))
if EXIT_ON_FAIL: exit(3)
return False
expected_result = None
if expected_error == Error.Okay:
if number < 0:
@@ -384,6 +390,7 @@ TESTS = {
[ 54321, 12345],
[ 55431, 0, Error.Division_by_Zero],
[ 12980742146337069150589594264770969721, 4611686018427387904 ],
[ 831956404029821402159719858789932422, 243087903122332132 ],
],
test_log: [
[ 3192, 1, Error.Invalid_Argument],
@@ -405,6 +412,7 @@ TESTS = {
[ 42, Error.Okay, ],
[ 12345678901234567890, Error.Okay, ],
[ 1298074214633706907132624082305024, Error.Okay, ],
[ 686885735734829009541949746871140768343076607029752932751182108475420900392874228486622313727012705619148037570309621219533087263900443932890792804879473795673302686046941536636874184361869252299636701671980034458333859202703255467709267777184095435235980845369829397344182319113372092844648570818726316581751114346501124871729572474923695509057166373026411194094493240101036672016770945150422252961487398124677567028263059046193391737576836378376192651849283925197438927999526058932679219572030021792914065825542626400207956134072247020690107136531852625253942429167557531123651471221455967386267137846791963149859804549891438562641323068751514370656287452006867713758971418043865298618635213551059471668293725548570452377976322899027050925842868079489675596835389444833567439058609775325447891875359487104691935576723532407937236505941186660707032433807075470656782452889754501872408562496805517394619388777930253411467941214807849472083814447498068636264021405175653742244368865090604940094889189800007448083930490871954101880815781177612910234741529950538835837693870921008635195545246771593130784786737543736434086434015200264933536294884482218945403958647118802574342840790536176272341586020230110889699633073513016344826709214, Error.Okay, ],
],
test_root_n: [
[ 1298074214633706907132624082305024, 2, Error.Okay, ],