diff --git a/core/math/big/private.odin b/core/math/big/private.odin index b3d4d80e5..14a27f600 100644 --- a/core/math/big/private.odin +++ b/core/math/big/private.odin @@ -1642,8 +1642,8 @@ _private_int_gcd_lcm :: proc(res_gcd, res_lcm, a, b: ^Int, allocator := context. /* Store quotient in `t2` such that `t2 * a` is the LCM. */ - internal_div(res_lcm, a, temp_gcd_res) or_return - err = internal_mul(res_lcm, res_lcm, b) + internal_div(res_lcm, b, temp_gcd_res) or_return + err = internal_mul(res_lcm, res_lcm, a) } if res_gcd != nil { diff --git a/core/math/big/tests/test.py b/core/math/big/tests/test.py index 79580b620..44659a638 100644 --- a/core/math/big/tests/test.py +++ b/core/math/big/tests/test.py @@ -236,24 +236,40 @@ def arg_to_odin(a): return s.encode('utf-8') -def integer_sqrt(src): - # The Python version on Github's CI doesn't offer math.isqrt. - # We implement our own - count = src.bit_length() - a, b = count >> 1, count & 1 +def big_integer_sqrt(src): + # The Python version on Github's CI doesn't offer math.isqrt. + # We implement our own + count = src.bit_length() + a, b = count >> 1, count & 1 - x = 1 << (a + b) + x = 1 << (a + b) - while True: - # y = (x + n // x) // 2 - t1 = src // x - t2 = t1 + x - y = t2 >> 1 + while True: + # y = (x + n // x) // 2 + t1 = src // x + t2 = t1 + x + y = t2 >> 1 - if y >= x: - return x + if y >= x: + return x - x, y = y, x + x, y = y, x + +def big_integer_lcm(a, b): + # Computes least common multiple as `|a*b|/gcd(a,b)` + # Divide the smallest by the GCD. + + if a == 0 or b == 0: + return 0 + + if abs(a) < abs(b): + # Store quotient in `t2` such that `t2 * b` is the LCM. + lcm = a // math.gcd(a, b) + return abs(b * lcm) + else: + # Store quotient in `t2` such that `t2 * a` is the LCM. + lcm = b // math.gcd(a, b) + return abs(a * lcm) def test_add(a = 0, b = 0, expected_error = Error.Okay): args = [arg_to_odin(a), arg_to_odin(b)] @@ -358,7 +374,7 @@ def test_sqrt(number = 0, expected_error = Error.Okay): if number < 0: expected_result = 0 else: - expected_result = integer_sqrt(number) + expected_result = big_integer_sqrt(number) return test("test_sqrt", res, [number], expected_error, expected_result) def root_n(number, root): @@ -461,7 +477,7 @@ def test_lcm(a = 0, b = 0, expected_error = Error.Okay): res = int_lcm(*args) expected_result = None if expected_error == Error.Okay: - expected_result = math.lcm(a, b) + expected_result = big_integer_lcm(a, b) return test("test_lcm", res, [a, b], expected_error, expected_result) @@ -470,7 +486,7 @@ def test_is_square(a = 0, b = 0, expected_error = Error.Okay): res = is_square(*args) expected_result = None if expected_error == Error.Okay: - expected_result = str(integer_sqrt(a) ** 2 == a) if a > 0 else "False" + expected_result = str(big_integer_sqrt(a) ** 2 == a) if a > 0 else "False" return test("test_is_square", res, [a], expected_error, expected_result) @@ -703,6 +719,15 @@ if __name__ == '__main__': b = randint(0, min(BITS, 120)) elif test_proc == test_is_square: a = randint(0, 1 << BITS) + elif test_proc == test_lcm: + smallest = min(a, b) + biggest = max(a, b) + + # Randomly swap biggest and smallest + if randint(1, 11) % 2 == 0: + smallest, biggest = biggest, smallest + + a, b = smallest, biggest else: b = randint(0, 1 << BITS)