big: Add Montgomery reduction.

This commit is contained in:
Jeroen van Rijn
2021-08-21 14:00:31 +02:00
parent b88e945268
commit 893cc013b5
2 changed files with 193 additions and 0 deletions

View File

@@ -33,6 +33,80 @@ int_prime_is_divisible :: proc(a: ^Int, allocator := context.allocator) -> (res:
return false, nil;
}
/*
Shifts with subtractions when the result is greater than b.
The method is slightly modified to shift B unconditionally up to just under
the leading bit of b. This saves a lot of multiple precision shifting.
*/
/*
internal_int_montgomery_calc_normalization :: proc(a, b: ^Int) -> (err: Error) {
int x, bits;
mp_err err;
/* how many bits of last digit does b use */
bits = mp_count_bits(b) % MP_DIGIT_BIT;
if (b->used > 1) {
if ((err = mp_2expt(a, ((b->used - 1) * MP_DIGIT_BIT) + bits - 1)) != MP_OKAY) {
return err;
}
} else {
mp_set(a, 1uL);
bits = 1;
}
/* now compute C = A * B mod b */
for (x = bits - 1; x < (int)MP_DIGIT_BIT; x++) {
if ((err = mp_mul_2(a, a)) != MP_OKAY) {
return err;
}
if (mp_cmp_mag(a, b) != MP_LT) {
if ((err = s_mp_sub(a, b, a)) != MP_OKAY) {
return err;
}
}
}
return nil;
}
*/
/*
	Sets up the Montgomery reduction constant for the modulus `n`.

	Returns `rho = -1/n mod b`, where `b = 2**_DIGIT_BITS`, computed from the least significant digit of `n`.
	Returns `.Invalid_Argument` if `n` has no digits in use or if its least significant digit is even,
	because an even number has no inverse modulo a power of two.
*/
internal_int_montgomery_setup :: proc(n: ^Int) -> (rho: DIGIT, err: Error) {
	/*
		A modulus with no digits in use is zero, which is even and thus invalid.
		Checking this first also avoids reading `digit[0]` when no digit is in use.
	*/
	if n.used == 0 { return 0, .Invalid_Argument; }

	b := n.digit[0];
	if b & 1 == 0 { return 0, .Invalid_Argument; }

	/*
		Fast inversion mod 2**k
		Based on the fact that:
		XA = 1 (mod 2**n) => (X(2-XA)) A = 1 (mod 2**2n)
		                  => 2*X*A - X*X*A*A = 1
		                  => 2*(1) - (1)     = 1

		Each `x *= 2 - (b * x)` step doubles the number of correct low bits of the inverse.
	*/
	x := (((b + 2) & 4) << 1) + b; /* here x*b==1 mod 2**4  */
	x *= 2 - (b * x);              /* here x*b==1 mod 2**8  */
	x *= 2 - (b * x);              /* here x*b==1 mod 2**16 */

	/*
		Keep doubling until the inverse is valid for at least `_DIGIT_BITS` bits.
		The number of steps is derived from the digit width itself rather than from the width of `_WORD`,
		so every digit size gets enough iterations. Extra iterations would be harmless:
		once `b * x == 1` in the digit type, `x *= 2 - 1` leaves `x` unchanged.
	*/
	for precision := 16; precision < int(_DIGIT_BITS); precision *= 2 {
		x *= 2 - (b * x);          /* here x*b==1 mod 2**(2 * precision) */
	}

	/*
		rho = -1/n mod b
		The subtraction is done in `_WORD` and masked, so any bits of `x` above `_DIGIT_BITS` are discarded.
	*/
	rho = DIGIT(((_WORD(1) << _WORD(_DIGIT_BITS)) - _WORD(x)) & _WORD(_MASK));
	return rho, nil;
}
/*
Returns the number of Rabin-Miller trials needed for a given bit size.
*/
number_of_rabin_miller_trials :: proc(bit_size: int) -> (number_of_trials: int) {
switch {
case bit_size <= 80:

View File

@@ -1542,6 +1542,125 @@ _private_int_log :: proc(a: ^Int, base: DIGIT, allocator := context.allocator) -
/*
	Computes xR**-1 == x (mod N) via Montgomery Reduction.
	This is an optimized implementation of `internal_montgomery_reduce`
	which uses the comba method to quickly calculate the columns of the reduction.
	Based on Algorithm 14.32 on pp.601 of HAC.

	`x`   is reduced in place.
	`n`   is the modulus.
	`rho` is the constant `-1/n mod b` as returned by the Montgomery setup procedure.

	Returns `.Invalid_Argument` if either `x` or the working range `0..2*n.used` does not fit
	in the fixed-size column array of `_WARRAY` words. Errors from growing `x` and from the
	final subtraction are passed through.
*/
_private_montgomery_reduce_comba :: proc(x, n: ^Int, rho: DIGIT) -> (err: Error) {
	W: [_WARRAY]_WORD = ---;

	if x.used > _WARRAY { return .Invalid_Argument; }

	/*
		The loops below touch W[0..2 * n.used] inclusive, so that range has to fit in W as well.
		Without this check an oversized modulus would index past the end of the array.
	*/
	if (n.used * 2) + 1 > _WARRAY { return .Invalid_Argument; }

	/*
		Get old used count.
	*/
	old_used := x.used;

	/*
		Grow `x` as required.
	*/
	internal_grow(x, n.used + 1) or_return;

	/*
		First we have to get the digits of the input into an array of double precision words W[...]
		Copy the digits of `x` into W[0..`x.used` - 1]
	*/
	ix: int;
	for ix = 0; ix < x.used; ix += 1 {
		W[ix] = _WORD(x.digit[ix]);
	}

	/*
		Zero the high words of W[`x.used`..`n.used` * 2].
		W was declared uninitialized, so every word that is read later must be written here.
	*/
	zero_upper := (n.used * 2) + 1;
	if ix < zero_upper {
		for ix = x.used; ix < zero_upper; ix += 1 {
			W[ix] = {};
		}
	}

	/*
		Now we proceed to zero successive digits from the least significant upwards.
	*/
	for ix = 0; ix < n.used; ix += 1 {
		/*
			`mu = ai * m' mod b`

			We avoid a double precision multiplication (which isn't required)
			by casting the value down to a DIGIT. Note this requires
			that W[ix-1] have the carry cleared (see after the inner loop)
		*/
		mu := ((W[ix] & _WORD(_MASK)) * _WORD(rho)) & _WORD(_MASK);

		/*
			`a = a + mu * m * b**i`

			This is computed in place and on the fly. The multiplication
			by b**i is handled by offsetting which columns the results
			are added to.

			Note the comba method normally doesn't handle carries in the
			inner loop. In this case we fix the carry from the previous
			column since the Montgomery reduction requires digits of the
			result (so far) [see above] to work.

			This is handled by fixing up one carry after the inner loop.
			The carry fixups are done in order so after these loops the
			first `n.used` words of W[] have the carries fixed.
		*/
		for iy := 0; iy < n.used; iy += 1 {
			W[ix + iy] += mu * _WORD(n.digit[iy]);
		}

		/*
			Now fix carry for next digit, W[ix+1].
		*/
		W[ix + 1] += (W[ix] >> _DIGIT_BITS);
	}

	/*
		Now we have to propagate the carries and shift the words downward
		[all those least significant digits we zeroed].
		`ix` continues from `n.used`, where the previous loop left it.
	*/
	for ; ix < n.used * 2; ix += 1 {
		W[ix + 1] += (W[ix] >> _DIGIT_BITS);
	}

	/*
		Copy out, A = A/b**n

		The result is A/b**n, but instead of converting from an array of `_WORD`
		to `DIGIT` and then shifting right by `n.used` digits, we just copy
		the words out in the right order.
	*/
	for ix = 0; ix < (n.used + 1); ix += 1 {
		x.digit[ix] = DIGIT(W[n.used + ix] & _WORD(_MASK));
	}

	/*
		Set the max used.
	*/
	x.used = n.used + 1;

	/*
		Zero old_used digits, if the input `x` was larger than `n.used` + 1 we'll have to clear the digits.
	*/
	internal_zero_unused(x, old_used);
	internal_clamp(x);

	/*
		if A >= m then A = A - m
	*/
	if internal_cmp_mag(x, n) != -1 {
		return internal_sub(x, x, n);
	}
	return nil;
}
/*
hac 14.61, pp608
*/