diff --git a/core/math/big/example.odin b/core/math/big/example.odin index 3d51ac400..1d9a4cb3c 100644 --- a/core/math/big/example.odin +++ b/core/math/big/example.odin @@ -80,16 +80,27 @@ demo :: proc() { err: Error; bs: string; - if err = factorial(a, 500); err != nil { fmt.printf("factorial err: %v\n", err); return; } + // if err = factorial(a, 850); err != nil { fmt.printf("factorial err: %v\n", err); return; } + + foo := "615037959146039477924633848896619112832171971562900618409305032006863881436080"; + if err = atoi(a, foo, 10); err != nil { return; } + print("a: ", a, 10, true, true, true); + fmt.println(); + { SCOPED_TIMING(.sqr); - if err = sqr(b, a); err != nil { fmt.printf("sqr err: %v\n", err); return; } + if err = sqr(b, a); err != nil { fmt.printf("sqr err: %v\n", err); return; } } + fmt.println(); + print("b _sqr_karatsuba: ", b); + fmt.println(); - bs, err = itoa(b, 10); + bs, err = itoa(b, 16); defer delete(bs); - assert(bs[:50] == "14887338741396604108836218987068397819515734169330"); + if bs[:50] != "1C367982F3050A8A3C62A8A7906D165438B54B287AF3F15D36" { + fmt.println("sqr failed"); + } } main :: proc() { diff --git a/core/math/big/internal.odin b/core/math/big/internal.odin index 9b267176d..00b1ed7bf 100644 --- a/core/math/big/internal.odin +++ b/core/math/big/internal.odin @@ -36,7 +36,7 @@ import "core:mem" import "core:intrinsics" import rnd "core:math/rand" -//import "core:fmt" +// import "core:fmt" /* Low-level addition, unsigned. Handbook of Applied Cryptography, algorithm 14.7. @@ -627,20 +627,22 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc Do we need to square? */ if src.used >= SQR_TOOM_CUTOFF { - /* Use Toom-Cook? */ - // err = s_mp_sqr_toom(a, c); + /* + Use Toom-Cook? + */ // fmt.printf("_private_int_sqr_toom: %v\n", src.used); - err = #force_inline _private_int_sqr(dest, src); + err = #force_inline _private_int_sqr_karatsuba(dest, src); } else if src.used >= SQR_KARATSUBA_CUTOFF { - /* Karatsuba? */ - // err = s_mp_sqr_karatsuba(a, c); - // fmt.printf("_private_int_sqr_karatsuba: %v\n", src.used); - err = #force_inline _private_int_sqr(dest, src); + /* + Karatsuba? + */ + err = #force_inline _private_int_sqr_karatsuba(dest, src); } else if ((src.used * 2) + 1) < _WARRAY && src.used < (_MAX_COMBA / 2) { /* Fast comba? */ err = #force_inline _private_int_sqr_comba(dest, src); + //err = #force_inline _private_int_sqr(dest, src); } else { err = #force_inline _private_int_sqr(dest, src); } diff --git a/core/math/big/private.odin b/core/math/big/private.odin index f49cbc51a..3d8497c72 100644 --- a/core/math/big/private.odin +++ b/core/math/big/private.odin @@ -354,6 +354,72 @@ _private_int_sqr_comba :: proc(dest, src: ^Int, allocator := context.allocator) return internal_clamp(dest); } +/* + Karatsuba squaring, computes `dest` = `src` * `src` using three half-size squarings. + + See comments of `_private_int_mul_karatsuba` for details. + It is essentially the same algorithm but merely tuned to perform recursive squarings. +*/ +_private_int_sqr_karatsuba :: proc(dest, src: ^Int, allocator := context.allocator) -> (err: Error) { + context.allocator = allocator; + + x0, x1, t1, t2, x0x0, x1x1 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}; + defer internal_destroy(x0, x1, t1, t2, x0x0, x1x1); + + /* + Min # of digits, divided by two. + */ + B := src.used >> 1; + + /* + Init temps. + */ + if err = internal_grow(x0, B); err != nil { return err; } + if err = internal_grow(x1, src.used - B); err != nil { return err; } + if err = internal_grow(t1, src.used * 2); err != nil { return err; } + if err = internal_grow(t2, src.used * 2); err != nil { return err; } + if err = internal_grow(x0x0, B * 2 ); err != nil { return err; } + if err = internal_grow(x1x1, (src.used - B) * 2); err != nil { return err; } + + /* + Now shift the digits. + */ + x0.used = B; + x1.used = src.used - B; + + internal_copy_digits(x0, src, x0.used); + #force_inline mem.copy_non_overlapping(&x1.digit[0], &src.digit[B], size_of(DIGIT) * x1.used); + internal_clamp(x0); + + /* + Now calc the products x0*x0 and x1*x1. + */ + if err = internal_sqr(x0x0, x0); err != nil { return err; } + if err = internal_sqr(x1x1, x1); err != nil { return err; } + + /* + Now calc (x1+x0)^2 + */ + if err = internal_add(t1, x0, x1); err != nil { return err; } + if err = internal_sqr(t1, t1); err != nil { return err; } + + /* + Add x0y0 + */ + if err = internal_add(t2, x0x0, x1x1); err != nil { return err; } + if err = internal_sub(t1, t1, t2); err != nil { return err; } + + /* + Shift by B. + */ + if err = internal_shl_digit(t1, B); err != nil { return err; } + if err = internal_shl_digit(x1x1, B * 2); err != nil { return err; } + if err = internal_add(t1, t1, x0x0); err != nil { return err; } + if err = internal_add(dest, t1, x1x1); err != nil { return err; } + + return internal_clamp(dest); +} + /* Divide by three (based on routine from MPI and the GMP manual). */