diff --git a/core/math/big/common.odin b/core/math/big/common.odin index f030f111a..32d7b938f 100644 --- a/core/math/big/common.odin +++ b/core/math/big/common.odin @@ -35,6 +35,8 @@ _DEFAULT_SQR_KARATSUBA_CUTOFF :: 120; _DEFAULT_MUL_TOOM_CUTOFF :: 350; _DEFAULT_SQR_TOOM_CUTOFF :: 400; +_MAX_ITERATIONS_ROOT_N :: 500; + Sign :: enum u8 { Zero_or_Positive = 0, Negative = 1, diff --git a/core/math/big/exp_log.odin b/core/math/big/exp_log.odin index 840aec8df..746470956 100644 --- a/core/math/big/exp_log.odin +++ b/core/math/big/exp_log.odin @@ -233,15 +233,9 @@ int_sqrt :: proc(dest, src: ^Int) -> (err: Error) { if count, err = count_bits(src); err != .None { return err; } a, b := count >> 1, count & 1; - err = power_of_two(x, a+b); + if err = power_of_two(x, a+b); err != .None { return err; } - iter := 0; for { - iter += 1; - if iter > 100 { - swap(dest, x); - return .Max_Iterations_Reached; - } /* y = (x + n//x)//2 */ @@ -274,7 +268,7 @@ sqrt :: proc { int_sqrt, }; */ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) { /* Fast path for n == 2 */ - // if n == 2 { return sqrt(dest, src); } + if n == 2 { return sqrt(dest, src); } /* Initialize dest + src if needed. */ if err = clear_if_uninitialized(dest); err != .None { return err; } @@ -366,7 +360,7 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) { } if c, err = cmp(t1, t2); c == 0 { break; } iterations += 1; - if iterations == 101 { + if iterations == _MAX_ITERATIONS_ROOT_N { return .Max_Iterations_Reached; } } @@ -376,12 +370,6 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) { iterations = 0; for { - if iterations == 101 { - return .Max_Iterations_Reached; - } - //fmt.printf("root_n iteration: %v\n", iterations); - iterations += 1; - if err = pow(t2, t1, n); err != .None { return err; } c, err = cmp(t2, a); @@ -393,16 +381,16 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) { } else { break; } + + iterations += 1; + if iterations == _MAX_ITERATIONS_ROOT_N { + return .Max_Iterations_Reached; + } } iterations = 0; /* Correct overshoot from above or from recurrence. */ for { - if iterations == 101 { - return .Max_Iterations_Reached; - } - iterations += 1; - if err = pow(t2, t1, n); err != .None { return err; } c, err = cmp(t2, a); @@ -411,6 +399,11 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) { } else { break; } + + iterations += 1; + if iterations == _MAX_ITERATIONS_ROOT_N { + return .Max_Iterations_Reached; + } } /* Set the result. */ diff --git a/core/math/big/test.odin b/core/math/big/test.odin index 04635295d..ed95669f5 100644 --- a/core/math/big/test.odin +++ b/core/math/big/test.odin @@ -162,6 +162,24 @@ PyRes :: struct { return PyRes{res = r, err = .None}; } +/* + dest = root_n(src, power) +*/ +@export test_root_n :: proc "c" (source: cstring, power: int) -> (res: PyRes) { + context = runtime.default_context(); + err: Error; + + src := &Int{}; + defer destroy(src); + + if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":root_n:atoi(src):", err=err}; } + if err = root_n(src, src, power); err != .None { return PyRes{res=":root_n:root_n(src):", err=err}; } + + r: cstring; + r, err = int_itoa_cstring(src, 10, context.temp_allocator); + if err != .None { return PyRes{res=":root_n:itoa(res):", err=err}; } + return PyRes{res = r, err = .None}; +} /* dest = shr_digit(src, digits) diff --git a/core/math/big/test.py b/core/math/big/test.py index 956d75e1c..645383b0f 100644 --- a/core/math/big/test.py +++ b/core/math/big/test.py @@ -121,9 +121,10 @@ mul = load(l.test_mul, [c_char_p, c_char_p], Res) div = load(l.test_div, [c_char_p, c_char_p], Res) # Powers and such -int_log = load(l.test_log, [c_char_p, c_longlong], Res) -int_pow = load(l.test_pow, [c_char_p, c_longlong], Res) -int_sqrt = load(l.test_sqrt, [c_char_p], Res) +int_log = load(l.test_log, [c_char_p, c_longlong], Res) +int_pow = load(l.test_pow, [c_char_p, c_longlong], Res) +int_sqrt = load(l.test_sqrt, [c_char_p], Res) +int_root_n = load(l.test_root_n, [c_char_p, c_longlong], Res) # Logical operations @@ -266,6 +267,23 @@ def root_n(number, root): u = t // root return s +def test_root_n(number = 0, root = 0, expected_error = Error.Okay): + args = [str(number), root] + sa_c = args[0].encode('utf-8') + try: + res = int_root_n(sa_c, root) + except: + print("root_n:", number, root) + + expected_result = None + if expected_error == Error.Okay: + if number < 0: + expected_result = 0 + else: + expected_result = root_n(number, root) + + return test("test_root_n", res, args, expected_error, expected_result) + def test_shl_digit(a = 0, digits = 0, expected_error = Error.Okay): args = [str(a), digits] sa_c = args[0].encode('utf-8') @@ -384,6 +402,9 @@ TESTS = { [ 12345678901234567890, Error.Okay, ], [ 1298074214633706907132624082305024, Error.Okay, ], ], + test_root_n: [ + [ 1298074214633706907132624082305024, 2, Error.Okay, ], + ], test_shl_digit: [ [ 3192, 1 ], [ 1298074214633706907132624082305024, 2 ], @@ -420,10 +441,12 @@ total_failures = 0 # RANDOM_TESTS = [ test_add, test_sub, test_mul, test_div, - test_log, test_pow, test_sqrt, + test_log, test_pow, test_sqrt, test_root_n, test_shl_digit, test_shr_digit, test_shl, test_shr_signed, ] -SKIP_LARGE = [test_pow] +SKIP_LARGE = [ + test_pow, test_root_n, +] SKIP_LARGEST = [] # Untimed warmup. @@ -431,28 +454,6 @@ for test_proc in TESTS: for t in TESTS[test_proc]: res = test_proc(*t) - -def isqrt(x): - n = int(x) - a, b = divmod(n.bit_length(), 2) - print("isqrt({}), a: {}, b: {}". format(n, a, b)) - x = 2**(a+b) - print("initial: {}".format(x)) - i = 0 - while True: - # y = (x + n//x)//2 - t1 = n // x - t2 = x + t1 - t3 = t2 // 2 - y = (x + n//x)//2 - - i += 1 - print("iter {}\n\t x: {}\n\t y: {}\n\tt1: {}\n\tt2: {}\n\tsrc: {}".format(i, x, y, t1, t2, n)); - - if y >= x: - return x - x = y - if __name__ == '__main__': print("---- math/big tests ----") print() @@ -517,6 +518,9 @@ if __name__ == '__main__': elif test_proc == test_sqrt: a = randint(1, 1 << BITS) b = Error.Okay + elif test_proc == test_root_n: + a = randint(1, 1 << BITS) + b = randint(1, 10); elif test_proc == test_shl_digit: b = randint(0, 10); elif test_proc == test_shr_digit: