diff --git a/core/math/bigint/basic.odin b/core/math/bigint/basic.odin index 866124327..8bab5b50d 100644 --- a/core/math/bigint/basic.odin +++ b/core/math/bigint/basic.odin @@ -89,20 +89,20 @@ add_digit :: proc(dest, a: ^Int, digit: DIGIT) -> (err: Error) { If `a` is negative and `|a|` >= `digit`, call `dest = |a| - digit` */ if is_neg(a) && (a.used > 1 || a.digit[0] >= digit) { + fmt.print("a = neg, %v\n", print_int(a)); /* Temporarily fix `a`'s sign. */ - a.sign = .Zero_or_Positive; + t := a; + t.sign = .Zero_or_Positive; /* dest = |a| - digit */ - err = sub(dest, a, digit); + err = sub(dest, t, digit); /* Restore sign and set `dest` sign. */ dest.sign = .Negative; - a.sign = .Negative; - clamp(dest); return err; @@ -208,11 +208,92 @@ sub_two_ints :: proc(dest, number, decrease: ^Int) -> (err: Error) { Adds the unsigned `DIGIT` immediate to an `Int`, such that the `DIGIT` doesn't have to be turned into an `Int` first. - dest = number - decrease; + dest = a - digit; */ -sub_digit :: proc(dest, number: ^Int, decrease: DIGIT) -> (err: Error) { +sub_digit :: proc(dest, a: ^Int, digit: DIGIT) -> (err: Error) { + dest := dest; x := a; digit := digit; + assert_initialized(dest); assert_initialized(a); - return .Unimplemented; + /* + Fast paths for destination and input Int being the same. + */ + if dest == a { + /* + Fast path for `dest` is negative and unsigned addition doesn't overflow the lowest digit. + */ + if is_neg(dest) && (dest.digit[0] + digit < _DIGIT_MAX) { + dest.digit[0] += digit; + return .OK; + } + /* + Can be subtracted from dest.digit[0] without underflow. + */ + if is_pos(a) && (dest.digit[0] > digit) { + dest.digit[0] -= digit; + return .OK; + } + } + + /* + Grow destination as required. + */ + err = grow(dest, a.used + 1); + if err != .OK { + return err; + } + + /* + If `a` is negative, just do an unsigned addition (with fudged signs). + */ + if is_neg(a) { + t := a; + t.sign = .Zero_or_Positive; + + err = add(dest, t, digit); + dest.sign = .Negative; + + clamp(dest); + return err; + } + + old_used := dest.used; + + /* + if `a`<= digit, simply fix the single digit. + */ + if a.used == 1 && (a.digit[0] <= digit || is_zero(a)) { + dest.digit[0] = digit - a.digit[0] if a.used == 1 else digit; + dest.sign = .Negative; + dest.used = 1; + } else { + dest.sign = .Zero_or_Positive; + dest.used = a.used; + + /* + Subtract with carry. + */ + carry := digit; + + for i := 0; i < a.used; i += 1 { + dest.digit[i] = a.digit[i] - carry; + carry := dest.digit[i] >> ((size_of(DIGIT) * 8) - 1); + dest.digit[i] &= _MASK; + } + } + + zero_count := old_used - dest.used; + /* + Zero remainder. + */ + if zero_count > 0 { + mem.zero_slice(dest.digit[dest.used:][:zero_count]); + } + /* + Adjust dest.used based on leading zeroes. + */ + clamp(dest); + + return .OK; } sub :: proc{sub_two_ints, sub_digit}; @@ -333,7 +414,7 @@ _sub :: proc(dest, number, decrease: ^Int) -> (err: Error) { it will propagate all the way to the MSB. As a result a single shift is enough to get the carry. */ - borrow = dest.digit[i] >> (_DIGIT_BITS - 1); + borrow = dest.digit[i] >> ((size_of(DIGIT) * 8) - 1); /* Clear borrow from dest[i]. */ @@ -351,7 +432,7 @@ _sub :: proc(dest, number, decrease: ^Int) -> (err: Error) { it will propagate all the way to the MSB. As a result a single shift is enough to get the carry. */ - borrow = dest.digit[i] >> (_DIGIT_BITS - 1); + borrow = dest.digit[i] >> ((size_of(DIGIT) * 8) - 1); /* Clear borrow from dest[i]. */ diff --git a/core/math/bigint/build.bat b/core/math/bigint/build.bat index df9cd1e85..ccc547948 100644 --- a/core/math/bigint/build.bat +++ b/core/math/bigint/build.bat @@ -1,2 +1,3 @@ @echo off -odin run . -vet \ No newline at end of file +odin run . +rem -vet \ No newline at end of file diff --git a/core/math/bigint/example.odin b/core/math/bigint/example.odin index 0c88a6e1a..69dab47dd 100644 --- a/core/math/bigint/example.odin +++ b/core/math/bigint/example.odin @@ -30,11 +30,11 @@ demo :: proc() { a, b, c: ^Int; err: Error; - a, err = init(21); + a, err = init(512); defer destroy(a); fmt.printf("a: %v, err: %v\n\n", print_int(a), err); - b, err = init(21); + b, err = init(42); defer destroy(b); fmt.printf("b: %v, err: %v\n\n", print_int(b), err); @@ -44,7 +44,7 @@ demo :: proc() { fmt.printf("c: %v\n", print_int(c, true)); fmt.println("=== Add ==="); - err = add(c, a, DIGIT(42)); + err = sub(a, a, DIGIT(42)); // err = add(c, a, b); fmt.printf("Error: %v\n", err); fmt.printf("a: %v\n", print_int(a));