big: Fix mul.

This commit is contained in:
Jeroen van Rijn
2021-07-29 16:37:16 +02:00
parent 708389a7ee
commit 13fab36639
4 changed files with 168 additions and 8 deletions

View File

@@ -950,7 +950,7 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
Limit ourselves to `digits` DIGITs of output.
*/
pb := min(b.used, digits - ix);
carry := DIGIT(0);
carry := _WORD(0);
iy := 0;
/*
Compute the column of the output and propagate the carry.
@@ -959,12 +959,12 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
/*
Compute the column as a _WORD.
*/
column := t.digit[ix + iy] + a.digit[ix] * b.digit[iy] + carry;
column := _WORD(t.digit[ix + iy]) + _WORD(a.digit[ix]) * _WORD(b.digit[iy]) + carry;
/*
The new column is the lower part of the result.
*/
t.digit[ix + iy] = column & _MASK;
t.digit[ix + iy] = DIGIT(column & _WORD(_MASK));
/*
Get the carry word from the result.
@@ -975,7 +975,7 @@ _int_mul :: proc(dest, a, b: ^Int, digits: int) -> (err: Error) {
Set carry if it is placed below digits
*/
if ix + iy < digits {
t.digit[ix + pb] = carry;
t.digit[ix + pb] = DIGIT(carry);
}
}

View File

@@ -1,2 +1,6 @@
@echo off
odin run . -vet
:odin run . -vet
odin build . -build-mode:dll
:dumpbin /EXPORTS big.dll
python test.py

View File

@@ -41,4 +41,58 @@ PyRes :: struct {
r, err = int_itoa_cstring(sum, i8(radix), context.temp_allocator);
if err != .None { return PyRes{res=":add_two:itoa(sum):", err=err}; }
return PyRes{res = r, err = .None};
}
@export test_sub_two :: proc "c" (a, b: cstring, radix := int(10)) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
aa, bb, sum := &Int{}, &Int{}, &Int{};
defer destroy(aa, bb, sum);
if err = atoi(aa, string(a), i8(radix)); err != .None { return PyRes{res=":sub_two:atoi(a):", err=err}; }
if err = atoi(bb, string(b), i8(radix)); err != .None { return PyRes{res=":sub_two:atoi(b):", err=err}; }
if err = sub(sum, aa, bb); err != .None { return PyRes{res=":sub_two:sub(sum,a,b):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(sum, i8(radix), context.temp_allocator);
if err != .None { return PyRes{res=":sub_two:itoa(sum):", err=err}; }
return PyRes{res = r, err = .None};
}
@export test_mul_two :: proc "c" (a, b: cstring, radix := int(10)) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
aa, bb, product := &Int{}, &Int{}, &Int{};
defer destroy(aa, bb, product);
if err = atoi(aa, string(a), i8(radix)); err != .None { return PyRes{res=":mul_two:atoi(a):", err=err}; }
if err = atoi(bb, string(b), i8(radix)); err != .None { return PyRes{res=":mul_two:atoi(b):", err=err}; }
if err = mul(product, aa, bb); err != .None { return PyRes{res=":mul_two:mul(product,a,b):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(product, i8(radix), context.temp_allocator);
if err != .None { return PyRes{res=":mul_two:itoa(product):", err=err}; }
return PyRes{res = r, err = .None};
}
/*
NOTE(Jeroen): For simplicity, we don't return the quotient and the remainder, just the quotient.
*/
@export test_div_two :: proc "c" (a, b: cstring, radix := int(10)) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
aa, bb, quotient := &Int{}, &Int{}, &Int{};
defer destroy(aa, bb, quotient);
if err = atoi(aa, string(a), i8(radix)); err != .None { return PyRes{res=":div_two:atoi(a):", err=err}; }
if err = atoi(bb, string(b), i8(radix)); err != .None { return PyRes{res=":div_two:atoi(b):", err=err}; }
if err = div(quotient, aa, bb); err != .None { return PyRes{res=":div_two:div(quotient,a,b):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(quotient, i8(radix), context.temp_allocator);
if err != .None { return PyRes{res=":div_two:itoa(quotient):", err=err}; }
return PyRes{res = r, err = .None};
}

View File

@@ -1,5 +1,6 @@
from math import *
from ctypes import *
from random import *
import os
#
@@ -38,6 +39,9 @@ except:
print("Couldn't find or load " + LIB_PATH + ".")
exit(1)
#
# res = a + b, err
#
try:
l.test_add_two.argtypes = [c_char_p, c_char_p, c_longlong]
l.test_add_two.restype = Res
@@ -47,6 +51,44 @@ except:
add_two = l.test_add_two
#
# res = a - b, err
#
try:
l.test_sub_two.argtypes = [c_char_p, c_char_p, c_longlong]
l.test_sub_two.restype = Res
except:
print("Couldn't find exported function 'test_sub_two'")
exit(2)
sub_two = l.test_sub_two
#
# res = a * b, err
#
try:
l.test_mul_two.argtypes = [c_char_p, c_char_p, c_longlong]
l.test_mul_two.restype = Res
except:
print("Couldn't find exported function 'test_add_two'")
exit(2)
mul_two = l.test_mul_two
#
# res = a / b, err
#
try:
l.test_div_two.argtypes = [c_char_p, c_char_p, c_longlong]
l.test_div_two.restype = Res
except:
print("Couldn't find exported function 'test_div_two'")
exit(2)
div_two = l.test_div_two
try:
l.test_error_string.argtypes = [c_byte]
l.test_error_string.restype = c_char_p
@@ -91,15 +133,52 @@ def test_add_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_res
expected_result = a + b
return test("test_add_two", res, [str(a), str(b), radix], expected_error, expected_result)
def test_sub_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
res = sub_two(str(a).encode('utf-8'), str(b).encode('utf-8'), radix)
if expected_result == None:
expected_result = a - b
return test("test_sub_two", res, [str(a), str(b), radix], expected_error, expected_result)
def test_mul_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
res = mul_two(str(a).encode('utf-8'), str(b).encode('utf-8'), radix)
if expected_result == None:
expected_result = a * b
return test("test_mul_two", res, [str(a), str(b), radix], expected_error, expected_result)
def test_div_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
res = div_two(str(a).encode('utf-8'), str(b).encode('utf-8'), radix)
if expected_result == None:
expected_result = a // b if b != 0 else None
return test("test_add_two", res, [str(a), str(b), radix], expected_error, expected_result)
# TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
#
# The last two arguments in tests are the expected error and expected result.
#
# The expected error defaults to None.
# By default the Odin implementation will be tested against the Python one.
# You can override that by supplying an expected result as the last argument instead.
TESTS = {
test_add_two: [
[ 1234, 5432, 10, ],
[ 1234, 5432, 110, E_Invalid_Argument, ],
[ 1234, 5432, 10, ],
[ 1234, 5432, 110, E_Invalid_Argument, ],
],
test_sub_two: [
[ 1234, 5432, 10, ],
],
test_mul_two: [
[ 1234, 5432, 10, ],
[ 1099243943008198766717263669950239669, 137638828577110581150675834234248871, 10, ]
],
test_div_two: [
[ 54321, 12345, 10, ],
[ 55431, 0, 10, E_Division_by_Zero, ],
],
}
if __name__ == '__main__':
print()
print("---- core:math/big tests ----")
print()
@@ -112,4 +191,27 @@ if __name__ == '__main__':
else:
count_fail += 1
print("{}: {} passes, {} failures.".format(test_proc.__name__, count_pass, count_fail))
print("{}: {} passes, {} failures.".format(test_proc.__name__, count_pass, count_fail))
print()
print("---- core:math/big random tests ----")
print()
for test_proc in [test_add_two, test_sub_two, test_mul_two, test_div_two]:
count_pass = 0
count_fail = 0
a = randint(0, 1 << 120)
b = randint(0, 1 << 120)
res = None
# We've already tested division by zero above.
if b == 0 and test_proc == test_div_two:
b = b + 1
if test_proc(a, b):
count_pass += 1
else:
count_fail += 1
print("{} random: {} passes, {} failures.".format(test_proc.__name__, count_pass, count_fail))