mirror of
https://github.com/odin-lang/Odin.git
synced 2026-02-16 16:14:06 +00:00
big: Test root_n.
This commit is contained in:
@@ -35,6 +35,8 @@ _DEFAULT_SQR_KARATSUBA_CUTOFF :: 120;
|
||||
_DEFAULT_MUL_TOOM_CUTOFF :: 350;
|
||||
_DEFAULT_SQR_TOOM_CUTOFF :: 400;
|
||||
|
||||
_MAX_ITERATIONS_ROOT_N :: 500;
|
||||
|
||||
Sign :: enum u8 {
|
||||
Zero_or_Positive = 0,
|
||||
Negative = 1,
|
||||
|
||||
@@ -233,15 +233,9 @@ int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
|
||||
if count, err = count_bits(src); err != .None { return err; }
|
||||
|
||||
a, b := count >> 1, count & 1;
|
||||
err = power_of_two(x, a+b);
|
||||
if err = power_of_two(x, a+b); err != .None { return err; }
|
||||
|
||||
iter := 0;
|
||||
for {
|
||||
iter += 1;
|
||||
if iter > 100 {
|
||||
swap(dest, x);
|
||||
return .Max_Iterations_Reached;
|
||||
}
|
||||
/*
|
||||
y = (x + n//x)//2
|
||||
*/
|
||||
@@ -274,7 +268,7 @@ sqrt :: proc { int_sqrt, };
|
||||
*/
|
||||
int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
|
||||
/* Fast path for n == 2 */
|
||||
// if n == 2 { return sqrt(dest, src); }
|
||||
if n == 2 { return sqrt(dest, src); }
|
||||
|
||||
/* Initialize dest + src if needed. */
|
||||
if err = clear_if_uninitialized(dest); err != .None { return err; }
|
||||
@@ -366,7 +360,7 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
|
||||
}
|
||||
if c, err = cmp(t1, t2); c == 0 { break; }
|
||||
iterations += 1;
|
||||
if iterations == 101 {
|
||||
if iterations == _MAX_ITERATIONS_ROOT_N {
|
||||
return .Max_Iterations_Reached;
|
||||
}
|
||||
}
|
||||
@@ -376,12 +370,6 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
|
||||
|
||||
iterations = 0;
|
||||
for {
|
||||
if iterations == 101 {
|
||||
return .Max_Iterations_Reached;
|
||||
}
|
||||
//fmt.printf("root_n iteration: %v\n", iterations);
|
||||
iterations += 1;
|
||||
|
||||
if err = pow(t2, t1, n); err != .None { return err; }
|
||||
|
||||
c, err = cmp(t2, a);
|
||||
@@ -393,16 +381,16 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
||||
iterations += 1;
|
||||
if iterations == _MAX_ITERATIONS_ROOT_N {
|
||||
return .Max_Iterations_Reached;
|
||||
}
|
||||
}
|
||||
|
||||
iterations = 0;
|
||||
/* Correct overshoot from above or from recurrence. */
|
||||
for {
|
||||
if iterations == 101 {
|
||||
return .Max_Iterations_Reached;
|
||||
}
|
||||
iterations += 1;
|
||||
|
||||
if err = pow(t2, t1, n); err != .None { return err; }
|
||||
|
||||
c, err = cmp(t2, a);
|
||||
@@ -411,6 +399,11 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
||||
iterations += 1;
|
||||
if iterations == _MAX_ITERATIONS_ROOT_N {
|
||||
return .Max_Iterations_Reached;
|
||||
}
|
||||
}
|
||||
|
||||
/* Set the result. */
|
||||
|
||||
@@ -162,6 +162,24 @@ PyRes :: struct {
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
|
||||
/*
|
||||
dest = root_n(src, power)
|
||||
*/
|
||||
@export test_root_n :: proc "c" (source: cstring, power: int) -> (res: PyRes) {
|
||||
context = runtime.default_context();
|
||||
err: Error;
|
||||
|
||||
src := &Int{};
|
||||
defer destroy(src);
|
||||
|
||||
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":root_n:atoi(src):", err=err}; }
|
||||
if err = root_n(src, src, power); err != .None { return PyRes{res=":root_n:root_n(src):", err=err}; }
|
||||
|
||||
r: cstring;
|
||||
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
|
||||
if err != .None { return PyRes{res=":root_n:itoa(res):", err=err}; }
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
|
||||
/*
|
||||
dest = shr_digit(src, digits)
|
||||
|
||||
@@ -121,9 +121,10 @@ mul = load(l.test_mul, [c_char_p, c_char_p], Res)
|
||||
div = load(l.test_div, [c_char_p, c_char_p], Res)
|
||||
|
||||
# Powers and such
|
||||
int_log = load(l.test_log, [c_char_p, c_longlong], Res)
|
||||
int_pow = load(l.test_pow, [c_char_p, c_longlong], Res)
|
||||
int_sqrt = load(l.test_sqrt, [c_char_p], Res)
|
||||
int_log = load(l.test_log, [c_char_p, c_longlong], Res)
|
||||
int_pow = load(l.test_pow, [c_char_p, c_longlong], Res)
|
||||
int_sqrt = load(l.test_sqrt, [c_char_p], Res)
|
||||
int_root_n = load(l.test_root_n, [c_char_p, c_longlong], Res)
|
||||
|
||||
# Logical operations
|
||||
|
||||
@@ -266,6 +267,23 @@ def root_n(number, root):
|
||||
u = t // root
|
||||
return s
|
||||
|
||||
def test_root_n(number = 0, root = 0, expected_error = Error.Okay):
|
||||
args = [str(number), root]
|
||||
sa_c = args[0].encode('utf-8')
|
||||
try:
|
||||
res = int_root_n(sa_c, root)
|
||||
except:
|
||||
print("root_n:", number, root)
|
||||
|
||||
expected_result = None
|
||||
if expected_error == Error.Okay:
|
||||
if number < 0:
|
||||
expected_result = 0
|
||||
else:
|
||||
expected_result = root_n(number, root)
|
||||
|
||||
return test("test_root_n", res, args, expected_error, expected_result)
|
||||
|
||||
def test_shl_digit(a = 0, digits = 0, expected_error = Error.Okay):
|
||||
args = [str(a), digits]
|
||||
sa_c = args[0].encode('utf-8')
|
||||
@@ -384,6 +402,9 @@ TESTS = {
|
||||
[ 12345678901234567890, Error.Okay, ],
|
||||
[ 1298074214633706907132624082305024, Error.Okay, ],
|
||||
],
|
||||
test_root_n: [
|
||||
[ 1298074214633706907132624082305024, 2, Error.Okay, ],
|
||||
],
|
||||
test_shl_digit: [
|
||||
[ 3192, 1 ],
|
||||
[ 1298074214633706907132624082305024, 2 ],
|
||||
@@ -420,10 +441,12 @@ total_failures = 0
|
||||
#
|
||||
RANDOM_TESTS = [
|
||||
test_add, test_sub, test_mul, test_div,
|
||||
test_log, test_pow, test_sqrt,
|
||||
test_log, test_pow, test_sqrt, test_root_n,
|
||||
test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
|
||||
]
|
||||
SKIP_LARGE = [test_pow]
|
||||
SKIP_LARGE = [
|
||||
test_pow, test_root_n,
|
||||
]
|
||||
SKIP_LARGEST = []
|
||||
|
||||
# Untimed warmup.
|
||||
@@ -431,28 +454,6 @@ for test_proc in TESTS:
|
||||
for t in TESTS[test_proc]:
|
||||
res = test_proc(*t)
|
||||
|
||||
|
||||
def isqrt(x):
|
||||
n = int(x)
|
||||
a, b = divmod(n.bit_length(), 2)
|
||||
print("isqrt({}), a: {}, b: {}". format(n, a, b))
|
||||
x = 2**(a+b)
|
||||
print("initial: {}".format(x))
|
||||
i = 0
|
||||
while True:
|
||||
# y = (x + n//x)//2
|
||||
t1 = n // x
|
||||
t2 = x + t1
|
||||
t3 = t2 // 2
|
||||
y = (x + n//x)//2
|
||||
|
||||
i += 1
|
||||
print("iter {}\n\t x: {}\n\t y: {}\n\tt1: {}\n\tt2: {}\n\tsrc: {}".format(i, x, y, t1, t2, n));
|
||||
|
||||
if y >= x:
|
||||
return x
|
||||
x = y
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("---- math/big tests ----")
|
||||
print()
|
||||
@@ -517,6 +518,9 @@ if __name__ == '__main__':
|
||||
elif test_proc == test_sqrt:
|
||||
a = randint(1, 1 << BITS)
|
||||
b = Error.Okay
|
||||
elif test_proc == test_root_n:
|
||||
a = randint(1, 1 << BITS)
|
||||
b = randint(1, 10);
|
||||
elif test_proc == test_shl_digit:
|
||||
b = randint(0, 10);
|
||||
elif test_proc == test_shr_digit:
|
||||
|
||||
Reference in New Issue
Block a user