big: Test root_n.

This commit is contained in:
Jeroen van Rijn
2021-07-31 18:58:46 +02:00
parent 149c7b88df
commit db0196abc7
4 changed files with 64 additions and 47 deletions

View File

@@ -35,6 +35,8 @@ _DEFAULT_SQR_KARATSUBA_CUTOFF :: 120;
_DEFAULT_MUL_TOOM_CUTOFF :: 350;
_DEFAULT_SQR_TOOM_CUTOFF :: 400;
_MAX_ITERATIONS_ROOT_N :: 500;
Sign :: enum u8 {
Zero_or_Positive = 0,
Negative = 1,

View File

@@ -233,15 +233,9 @@ int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
if count, err = count_bits(src); err != .None { return err; }
a, b := count >> 1, count & 1;
err = power_of_two(x, a+b);
if err = power_of_two(x, a+b); err != .None { return err; }
iter := 0;
for {
iter += 1;
if iter > 100 {
swap(dest, x);
return .Max_Iterations_Reached;
}
/*
y = (x + n//x)//2
*/
@@ -274,7 +268,7 @@ sqrt :: proc { int_sqrt, };
*/
int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
/* Fast path for n == 2 */
// if n == 2 { return sqrt(dest, src); }
if n == 2 { return sqrt(dest, src); }
/* Initialize dest + src if needed. */
if err = clear_if_uninitialized(dest); err != .None { return err; }
@@ -366,7 +360,7 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
}
if c, err = cmp(t1, t2); c == 0 { break; }
iterations += 1;
if iterations == 101 {
if iterations == _MAX_ITERATIONS_ROOT_N {
return .Max_Iterations_Reached;
}
}
@@ -376,12 +370,6 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
iterations = 0;
for {
if iterations == 101 {
return .Max_Iterations_Reached;
}
//fmt.printf("root_n iteration: %v\n", iterations);
iterations += 1;
if err = pow(t2, t1, n); err != .None { return err; }
c, err = cmp(t2, a);
@@ -393,16 +381,16 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
} else {
break;
}
iterations += 1;
if iterations == _MAX_ITERATIONS_ROOT_N {
return .Max_Iterations_Reached;
}
}
iterations = 0;
/* Correct overshoot from above or from recurrence. */
for {
if iterations == 101 {
return .Max_Iterations_Reached;
}
iterations += 1;
if err = pow(t2, t1, n); err != .None { return err; }
c, err = cmp(t2, a);
@@ -411,6 +399,11 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
} else {
break;
}
iterations += 1;
if iterations == _MAX_ITERATIONS_ROOT_N {
return .Max_Iterations_Reached;
}
}
/* Set the result. */

View File

@@ -162,6 +162,24 @@ PyRes :: struct {
return PyRes{res = r, err = .None};
}
/*
dest = root_n(src, power)
*/
@export test_root_n :: proc "c" (source: cstring, power: int) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
src := &Int{};
defer destroy(src);
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":root_n:atoi(src):", err=err}; }
if err = root_n(src, src, power); err != .None { return PyRes{res=":root_n:root_n(src):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
if err != .None { return PyRes{res=":root_n:itoa(res):", err=err}; }
return PyRes{res = r, err = .None};
}
/*
dest = shr_digit(src, digits)

View File

@@ -121,9 +121,10 @@ mul = load(l.test_mul, [c_char_p, c_char_p], Res)
div = load(l.test_div, [c_char_p, c_char_p], Res)
# Powers and such
int_log = load(l.test_log, [c_char_p, c_longlong], Res)
int_pow = load(l.test_pow, [c_char_p, c_longlong], Res)
int_sqrt = load(l.test_sqrt, [c_char_p], Res)
int_log = load(l.test_log, [c_char_p, c_longlong], Res)
int_pow = load(l.test_pow, [c_char_p, c_longlong], Res)
int_sqrt = load(l.test_sqrt, [c_char_p], Res)
int_root_n = load(l.test_root_n, [c_char_p, c_longlong], Res)
# Logical operations
@@ -266,6 +267,23 @@ def root_n(number, root):
u = t // root
return s
def test_root_n(number = 0, root = 0, expected_error = Error.Okay):
args = [str(number), root]
sa_c = args[0].encode('utf-8')
try:
res = int_root_n(sa_c, root)
except:
print("root_n:", number, root)
expected_result = None
if expected_error == Error.Okay:
if number < 0:
expected_result = 0
else:
expected_result = root_n(number, root)
return test("test_root_n", res, args, expected_error, expected_result)
def test_shl_digit(a = 0, digits = 0, expected_error = Error.Okay):
args = [str(a), digits]
sa_c = args[0].encode('utf-8')
@@ -384,6 +402,9 @@ TESTS = {
[ 12345678901234567890, Error.Okay, ],
[ 1298074214633706907132624082305024, Error.Okay, ],
],
test_root_n: [
[ 1298074214633706907132624082305024, 2, Error.Okay, ],
],
test_shl_digit: [
[ 3192, 1 ],
[ 1298074214633706907132624082305024, 2 ],
@@ -420,10 +441,12 @@ total_failures = 0
#
RANDOM_TESTS = [
test_add, test_sub, test_mul, test_div,
test_log, test_pow, test_sqrt,
test_log, test_pow, test_sqrt, test_root_n,
test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
]
SKIP_LARGE = [test_pow]
SKIP_LARGE = [
test_pow, test_root_n,
]
SKIP_LARGEST = []
# Untimed warmup.
@@ -431,28 +454,6 @@ for test_proc in TESTS:
for t in TESTS[test_proc]:
res = test_proc(*t)
def isqrt(x):
n = int(x)
a, b = divmod(n.bit_length(), 2)
print("isqrt({}), a: {}, b: {}". format(n, a, b))
x = 2**(a+b)
print("initial: {}".format(x))
i = 0
while True:
# y = (x + n//x)//2
t1 = n // x
t2 = x + t1
t3 = t2 // 2
y = (x + n//x)//2
i += 1
print("iter {}\n\t x: {}\n\t y: {}\n\tt1: {}\n\tt2: {}\n\tsrc: {}".format(i, x, y, t1, t2, n));
if y >= x:
return x
x = y
if __name__ == '__main__':
print("---- math/big tests ----")
print()
@@ -517,6 +518,9 @@ if __name__ == '__main__':
elif test_proc == test_sqrt:
a = randint(1, 1 << BITS)
b = Error.Okay
elif test_proc == test_root_n:
a = randint(1, 1 << BITS)
b = randint(1, 10);
elif test_proc == test_shl_digit:
b = randint(0, 10);
elif test_proc == test_shr_digit: