diff --git a/core/math/big/basic.odin b/core/math/big/basic.odin index 00411f80a..bd1d25bb9 100644 --- a/core/math/big/basic.odin +++ b/core/math/big/basic.odin @@ -652,6 +652,24 @@ sqr :: proc(dest, src: ^Int) -> (err: Error) { return mul(dest, src, src); } +/* + divmod. + Both the quotient and remainder are optional and may be passed a nil. +*/ +int_div :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) { + /* + Early out if neither of the results is wanted. + */ + if quotient == nil && remainder == nil { return .None; } + + + if err = clear_if_uninitialized(numerator); err != .None { return err; } + if err = clear_if_uninitialized(denominator); err != .None { return err; } + + return _int_div(quotient, remainder, numerator, denominator); +} +div :: proc{ int_div, }; + /* ========================== @@ -1014,47 +1032,49 @@ _int_div_3 :: proc(quotient, numerator: ^Int) -> (remainder: int, err: Error) { */ _int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) { -// mp_int ta, tb, tq, q; -// int n; -// bool neg; -// mp_err err; + ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{}; -// /* init our temps */ -// if ((err = mp_init_multi(&ta, &tb, &tq, &q, NULL)) != MP_OKAY) { -// return err; -// } + goto_end: for { + if err = one(tq); err != .None { break goto_end; } -// mp_set(&tq, 1uL); -// n = mp_count_bits(a) - mp_count_bits(b); -// if ((err = mp_abs(a, &ta)) != MP_OKAY) goto LBL_ERR; -// if ((err = mp_abs(b, &tb)) != MP_OKAY) goto LBL_ERR; -// if ((err = mp_mul_2d(&tb, n, &tb)) != MP_OKAY) goto LBL_ERR; -// if ((err = mp_mul_2d(&tq, n, &tq)) != MP_OKAY) goto LBL_ERR; + num_bits, _ := count_bits(numerator); + den_bits, _ := count_bits(denominator); + n := num_bits - den_bits; -// while (n-- >= 0) { -// if (mp_cmp(&tb, &ta) != MP_GT) { -// if ((err = mp_sub(&ta, &tb, &ta)) != MP_OKAY) goto LBL_ERR; -// if ((err = mp_add(&q, &tq, &q)) != MP_OKAY) goto LBL_ERR; -// } -// if ((err = mp_div_2d(&tb, 1, &tb, NULL)) != MP_OKAY) goto LBL_ERR; -// if ((err = mp_div_2d(&tq, 1, &tq, NULL)) != MP_OKAY) goto LBL_ERR; -// } + if err = abs(ta, numerator); err != .None { break goto_end; } + if err = abs(tb, denominator); err != .None { break goto_end; } -// /* now q == quotient and ta == remainder */ + if err = shl(tb, tb, n); err != .None { break goto_end; } + if err = shl(tq, tq, n); err != .None { break goto_end; } -// neg = (a->sign != b->sign); -// if (c != NULL) { -// mp_exch(c, &q); -// c->sign = ((neg && !mp_iszero(c)) ? MP_NEG : MP_ZPOS); -// } -// if (d != NULL) { -// mp_exch(d, &ta); -// d->sign = (mp_iszero(d) ? MP_ZPOS : a->sign); -// } -// LBL_ERR: -// mp_clear_multi(&ta, &tb, &tq, &q, NULL); -// return err; + for ; n >= 0; n -= 1 { + c: int; + if c, err = cmp(tb, ta); err != .None { break goto_end; } + if c != 1 { + if err = sub(ta, ta, tb); err != .None { break goto_end; } + if err = add( q, tq, q); err != .None { break goto_end; } + } + if err = shr1(tb, tb); err != .None { break goto_end; } + if err = shr1(tq, tq); err != .None { break goto_end; } + } + /* + Now q == quotient and ta == remainder. + */ + neg := numerator.sign != denominator.sign; + if quotient != nil { + swap(quotient, q); + z, _ := is_zero(quotient); + quotient.sign = .Negative if neg && !z else .Zero_or_Positive; + } + if remainder != nil { + swap(remainder, ta); + z, _ := is_zero(numerator); + remainder.sign = .Zero_or_Positive if z else numerator.sign; + } - return .None; + break goto_end; + } + destroy(ta, tb, tq, q); + return err; } \ No newline at end of file diff --git a/core/math/big/example.odin b/core/math/big/example.odin index 103259b3e..ca3fc5531 100644 --- a/core/math/big/example.odin +++ b/core/math/big/example.odin @@ -68,6 +68,22 @@ demo :: proc() { print("quotient ", quotient, 10); fmt.println("remainder ", i); fmt.println("error", err); + + fmt.println(); fmt.println(); + + err = set (numerator, 15625); + err = set (denominator, 3); + err = zero(quotient); + + print("numerator ", numerator, 10); + print("denominator", denominator, 10); + + err = _int_div_small(quotient, remainder, numerator, denominator); + + print("quotient ", quotient, 10); + print("remainder ", remainder, 10); + + } main :: proc() {