]> Git Repo - secp256k1.git/blob - src/modinv32_impl.h
Add extensive comments on the safegcd algorithm and implementation
[secp256k1.git] / src / modinv32_impl.h
1 /***********************************************************************
2  * Copyright (c) 2020 Peter Dettman                                    *
3  * Distributed under the MIT software license, see the accompanying    *
4  * file COPYING or https://www.opensource.org/licenses/mit-license.php.*
5  **********************************************************************/
6
7 #ifndef SECP256K1_MODINV32_IMPL_H
8 #define SECP256K1_MODINV32_IMPL_H
9
10 #include "modinv32.h"
11
12 #include "util.h"
13
14 #include <stdlib.h>
15
16 /* This file implements modular inversion based on the paper "Fast constant-time gcd computation and
17  * modular inversion" by Daniel J. Bernstein and Bo-Yin Yang.
18  *
19  * For an explanation of the algorithm, see doc/safegcd_implementation.md. This file contains an
20  * implementation for N=30, using 30-bit signed limbs represented as int32_t.
21  */
22
23 /* Take as input a signed30 number in range (-2*modulus,modulus), and add a multiple of the modulus
24  * to it to bring it to range [0,modulus). If sign < 0, the input will also be negated in the
25  * process. The input must have limbs in range (-2^30,2^30). The output will have limbs in range
26  * [0,2^30). */
27 static void secp256k1_modinv32_normalize_30(secp256k1_modinv32_signed30 *r, int32_t sign, const secp256k1_modinv32_modinfo *modinfo) {
28     const int32_t M30 = (int32_t)(UINT32_MAX >> 2);
29     int32_t r0 = r->v[0], r1 = r->v[1], r2 = r->v[2], r3 = r->v[3], r4 = r->v[4],
30             r5 = r->v[5], r6 = r->v[6], r7 = r->v[7], r8 = r->v[8];
31     int32_t cond_add, cond_negate;
32
33     /* In a first step, add the modulus if the input is negative, and then negate if requested.
34      * This brings r from range (-2*modulus,modulus) to range (-modulus,modulus). As all input
35      * limbs are in range (-2^30,2^30), this cannot overflow an int32_t. Note that the right
36      * shifts below are signed sign-extending shifts (see assumptions.h for tests that that is
37      * indeed the behavior of the right shift operator). */
38     cond_add = r8 >> 31;
39     r0 += modinfo->modulus.v[0] & cond_add;
40     r1 += modinfo->modulus.v[1] & cond_add;
41     r2 += modinfo->modulus.v[2] & cond_add;
42     r3 += modinfo->modulus.v[3] & cond_add;
43     r4 += modinfo->modulus.v[4] & cond_add;
44     r5 += modinfo->modulus.v[5] & cond_add;
45     r6 += modinfo->modulus.v[6] & cond_add;
46     r7 += modinfo->modulus.v[7] & cond_add;
47     r8 += modinfo->modulus.v[8] & cond_add;
48     cond_negate = sign >> 31;
49     r0 = (r0 ^ cond_negate) - cond_negate;
50     r1 = (r1 ^ cond_negate) - cond_negate;
51     r2 = (r2 ^ cond_negate) - cond_negate;
52     r3 = (r3 ^ cond_negate) - cond_negate;
53     r4 = (r4 ^ cond_negate) - cond_negate;
54     r5 = (r5 ^ cond_negate) - cond_negate;
55     r6 = (r6 ^ cond_negate) - cond_negate;
56     r7 = (r7 ^ cond_negate) - cond_negate;
57     r8 = (r8 ^ cond_negate) - cond_negate;
58     /* Propagate the top bits, to bring limbs back to range (-2^30,2^30). */
59     r1 += r0 >> 30; r0 &= M30;
60     r2 += r1 >> 30; r1 &= M30;
61     r3 += r2 >> 30; r2 &= M30;
62     r4 += r3 >> 30; r3 &= M30;
63     r5 += r4 >> 30; r4 &= M30;
64     r6 += r5 >> 30; r5 &= M30;
65     r7 += r6 >> 30; r6 &= M30;
66     r8 += r7 >> 30; r7 &= M30;
67
68     /* In a second step add the modulus again if the result is still negative, bringing r to range
69      * [0,modulus). */
70     cond_add = r8 >> 31;
71     r0 += modinfo->modulus.v[0] & cond_add;
72     r1 += modinfo->modulus.v[1] & cond_add;
73     r2 += modinfo->modulus.v[2] & cond_add;
74     r3 += modinfo->modulus.v[3] & cond_add;
75     r4 += modinfo->modulus.v[4] & cond_add;
76     r5 += modinfo->modulus.v[5] & cond_add;
77     r6 += modinfo->modulus.v[6] & cond_add;
78     r7 += modinfo->modulus.v[7] & cond_add;
79     r8 += modinfo->modulus.v[8] & cond_add;
80     /* And propagate again. */
81     r1 += r0 >> 30; r0 &= M30;
82     r2 += r1 >> 30; r1 &= M30;
83     r3 += r2 >> 30; r2 &= M30;
84     r4 += r3 >> 30; r3 &= M30;
85     r5 += r4 >> 30; r4 &= M30;
86     r6 += r5 >> 30; r5 &= M30;
87     r7 += r6 >> 30; r6 &= M30;
88     r8 += r7 >> 30; r7 &= M30;
89
90     r->v[0] = r0;
91     r->v[1] = r1;
92     r->v[2] = r2;
93     r->v[3] = r3;
94     r->v[4] = r4;
95     r->v[5] = r5;
96     r->v[6] = r6;
97     r->v[7] = r7;
98     r->v[8] = r8;
99 }
100
101 /* Data type for transition matrices (see section 3 of explanation).
102  *
103  * t = [ u  v ]
104  *     [ q  r ]
105  */
106 typedef struct {
107     int32_t u, v, q, r;
108 } secp256k1_modinv32_trans2x2;
109
110 /* Compute the transition matrix and eta for 30 divsteps.
111  *
112  * Input:  eta: initial eta
113  *         f0:  bottom limb of initial f
114  *         g0:  bottom limb of initial g
115  * Output: t: transition matrix
116  * Return: final eta
117  *
118  * Implements the divsteps_n_matrix function from the explanation.
119  */
120 static int32_t secp256k1_modinv32_divsteps_30(int32_t eta, uint32_t f0, uint32_t g0, secp256k1_modinv32_trans2x2 *t) {
121     /* u,v,q,r are the elements of the transformation matrix being built up,
122      * starting with the identity matrix. Semantically they are signed integers
123      * in range [-2^30,2^30], but here represented as unsigned mod 2^32. This
124      * permits left shifting (which is UB for negative numbers). The range
125      * being inside [-2^31,2^31) means that casting to signed works correctly.
126      */
127     uint32_t u = 1, v = 0, q = 0, r = 1;
128     uint32_t c1, c2, f = f0, g = g0, x, y, z;
129     int i;
130
131     for (i = 0; i < 30; ++i) {
132         VERIFY_CHECK((f & 1) == 1); /* f must always be odd */
133         VERIFY_CHECK((u * f0 + v * g0) == f << i);
134         VERIFY_CHECK((q * f0 + r * g0) == g << i);
135         /* Compute conditional masks for (eta < 0) and for (g & 1). */
136         c1 = eta >> 31;
137         c2 = -(g & 1);
138         /* Compute x,y,z, conditionally negated versions of f,u,v. */
139         x = (f ^ c1) - c1;
140         y = (u ^ c1) - c1;
141         z = (v ^ c1) - c1;
142         /* Conditionally add x,y,z to g,q,r. */
143         g += x & c2;
144         q += y & c2;
145         r += z & c2;
146         /* In what follows, c1 is a condition mask for (eta < 0) and (g & 1). */
147         c1 &= c2;
148         /* Conditionally negate eta, and unconditionally subtract 1. */
149         eta = (eta ^ c1) - (c1 + 1);
150         /* Conditionally add g,q,r to f,u,v. */
151         f += g & c1;
152         u += q & c1;
153         v += r & c1;
154         /* Shifts */
155         g >>= 1;
156         u <<= 1;
157         v <<= 1;
158     }
159     /* Return data in t and return value. */
160     t->u = (int32_t)u;
161     t->v = (int32_t)v;
162     t->q = (int32_t)q;
163     t->r = (int32_t)r;
164     return eta;
165 }
166
167 /* Compute the transition matrix and eta for 30 divsteps (variable time).
168  *
169  * Input:  eta: initial eta
170  *         f0:  bottom limb of initial f
171  *         g0:  bottom limb of initial g
172  * Output: t: transition matrix
173  * Return: final eta
174  *
175  * Implements the divsteps_n_matrix_var function from the explanation.
176  */
177 static int32_t secp256k1_modinv32_divsteps_30_var(int32_t eta, uint32_t f0, uint32_t g0, secp256k1_modinv32_trans2x2 *t) {
178     /* inv256[i] = -(2*i+1)^-1 (mod 256) */
179     static const uint8_t inv256[128] = {
180         0xFF, 0x55, 0x33, 0x49, 0xC7, 0x5D, 0x3B, 0x11, 0x0F, 0xE5, 0xC3, 0x59,
181         0xD7, 0xED, 0xCB, 0x21, 0x1F, 0x75, 0x53, 0x69, 0xE7, 0x7D, 0x5B, 0x31,
182         0x2F, 0x05, 0xE3, 0x79, 0xF7, 0x0D, 0xEB, 0x41, 0x3F, 0x95, 0x73, 0x89,
183         0x07, 0x9D, 0x7B, 0x51, 0x4F, 0x25, 0x03, 0x99, 0x17, 0x2D, 0x0B, 0x61,
184         0x5F, 0xB5, 0x93, 0xA9, 0x27, 0xBD, 0x9B, 0x71, 0x6F, 0x45, 0x23, 0xB9,
185         0x37, 0x4D, 0x2B, 0x81, 0x7F, 0xD5, 0xB3, 0xC9, 0x47, 0xDD, 0xBB, 0x91,
186         0x8F, 0x65, 0x43, 0xD9, 0x57, 0x6D, 0x4B, 0xA1, 0x9F, 0xF5, 0xD3, 0xE9,
187         0x67, 0xFD, 0xDB, 0xB1, 0xAF, 0x85, 0x63, 0xF9, 0x77, 0x8D, 0x6B, 0xC1,
188         0xBF, 0x15, 0xF3, 0x09, 0x87, 0x1D, 0xFB, 0xD1, 0xCF, 0xA5, 0x83, 0x19,
189         0x97, 0xAD, 0x8B, 0xE1, 0xDF, 0x35, 0x13, 0x29, 0xA7, 0x3D, 0x1B, 0xF1,
190         0xEF, 0xC5, 0xA3, 0x39, 0xB7, 0xCD, 0xAB, 0x01
191     };
192
193     /* Transformation matrix; see comments in secp256k1_modinv32_divsteps_30. */
194     uint32_t u = 1, v = 0, q = 0, r = 1;
195     uint32_t f = f0, g = g0, m;
196     uint16_t w;
197     int i = 30, limit, zeros;
198
199     for (;;) {
200         /* Use a sentinel bit to count zeros only up to i. */
201         zeros = secp256k1_ctz32_var(g | (UINT32_MAX << i));
202         /* Perform zeros divsteps at once; they all just divide g by two. */
203         g >>= zeros;
204         u <<= zeros;
205         v <<= zeros;
206         eta -= zeros;
207         i -= zeros;
208          /* We're done once we've done 30 divsteps. */
209         if (i == 0) break;
210         VERIFY_CHECK((f & 1) == 1);
211         VERIFY_CHECK((g & 1) == 1);
212         VERIFY_CHECK((u * f0 + v * g0) == f << (30 - i));
213         VERIFY_CHECK((q * f0 + r * g0) == g << (30 - i));
214         /* If eta is negative, negate it and replace f,g with g,-f. */
215         if (eta < 0) {
216             uint32_t tmp;
217             eta = -eta;
218             tmp = f; f = g; g = -tmp;
219             tmp = u; u = q; q = -tmp;
220             tmp = v; v = r; r = -tmp;
221         }
222         /* eta is now >= 0. In what follows we're going to cancel out the bottom bits of g. No more
223          * than i can be cancelled out (as we'd be done before that point), and no more than eta+1
224          * can be done as its sign will flip once that happens. */
225         limit = ((int)eta + 1) > i ? i : ((int)eta + 1);
226         /* m is a mask for the bottom min(limit, 8) bits (our table only supports 8 bits). */
227         m = (UINT32_MAX >> (32 - limit)) & 255U;
228         /* Find what multiple of f must be added to g to cancel its bottom min(limit, 8) bits. */
229         w = (g * inv256[(f >> 1) & 127]) & m;
230         /* Do so. */
231         g += f * w;
232         q += u * w;
233         r += v * w;
234         VERIFY_CHECK((g & m) == 0);
235     }
236     /* Return data in t and return value. */
237     t->u = (int32_t)u;
238     t->v = (int32_t)v;
239     t->q = (int32_t)q;
240     t->r = (int32_t)r;
241     return eta;
242 }
243
244 /* Compute (t/2^30) * [d, e] mod modulus, where t is a transition matrix for 30 divsteps.
245  *
246  * On input and output, d and e are in range (-2*modulus,modulus). All output limbs will be in range
247  * (-2^30,2^30).
248  *
249  * This implements the update_de function from the explanation.
250  */
251 static void secp256k1_modinv32_update_de_30(secp256k1_modinv32_signed30 *d, secp256k1_modinv32_signed30 *e, const secp256k1_modinv32_trans2x2 *t, const secp256k1_modinv32_modinfo* modinfo) {
252     const int32_t M30 = (int32_t)(UINT32_MAX >> 2);
253     const int32_t u = t->u, v = t->v, q = t->q, r = t->r;
254     int32_t di, ei, md, me, sd, se;
255     int64_t cd, ce;
256     int i;
257     /* [md,me] start as zero; plus [u,q] if d is negative; plus [v,r] if e is negative. */
258     sd = d->v[8] >> 31;
259     se = e->v[8] >> 31;
260     md = (u & sd) + (v & se);
261     me = (q & sd) + (r & se);
262     /* Begin computing t*[d,e]. */
263     di = d->v[0];
264     ei = e->v[0];
265     cd = (int64_t)u * di + (int64_t)v * ei;
266     ce = (int64_t)q * di + (int64_t)r * ei;
267     /* Correct md,me so that t*[d,e]+modulus*[md,me] has 30 zero bottom bits. */
268     md -= (modinfo->modulus_inv30 * (uint32_t)cd + md) & M30;
269     me -= (modinfo->modulus_inv30 * (uint32_t)ce + me) & M30;
270     /* Update the beginning of computation for t*[d,e]+modulus*[md,me] now md,me are known. */
271     cd += (int64_t)modinfo->modulus.v[0] * md;
272     ce += (int64_t)modinfo->modulus.v[0] * me;
273     /* Verify that the low 30 bits of the computation are indeed zero, and then throw them away. */
274     VERIFY_CHECK(((int32_t)cd & M30) == 0); cd >>= 30;
275     VERIFY_CHECK(((int32_t)ce & M30) == 0); ce >>= 30;
276     /* Now iteratively compute limb i=1..8 of t*[d,e]+modulus*[md,me], and store them in output
277      * limb i-1 (shifting down by 30 bits). */
278     for (i = 1; i < 9; ++i) {
279         di = d->v[i];
280         ei = e->v[i];
281         cd += (int64_t)u * di + (int64_t)v * ei;
282         ce += (int64_t)q * di + (int64_t)r * ei;
283         cd += (int64_t)modinfo->modulus.v[i] * md;
284         ce += (int64_t)modinfo->modulus.v[i] * me;
285         d->v[i - 1] = (int32_t)cd & M30; cd >>= 30;
286         e->v[i - 1] = (int32_t)ce & M30; ce >>= 30;
287     }
288     /* What remains is limb 9 of t*[d,e]+modulus*[md,me]; store it as output limb 8. */
289     d->v[8] = (int32_t)cd;
290     e->v[8] = (int32_t)ce;
291 }
292
293 /* Compute (t/2^30) * [f, g], where t is a transition matrix for 30 divsteps.
294  *
295  * This implements the update_fg function from the explanation.
296  */
297 static void secp256k1_modinv32_update_fg_30(secp256k1_modinv32_signed30 *f, secp256k1_modinv32_signed30 *g, const secp256k1_modinv32_trans2x2 *t) {
298     const int32_t M30 = (int32_t)(UINT32_MAX >> 2);
299     const int32_t u = t->u, v = t->v, q = t->q, r = t->r;
300     int32_t fi, gi;
301     int64_t cf, cg;
302     int i;
303     /* Start computing t*[f,g]. */
304     fi = f->v[0];
305     gi = g->v[0];
306     cf = (int64_t)u * fi + (int64_t)v * gi;
307     cg = (int64_t)q * fi + (int64_t)r * gi;
308     /* Verify that the bottom 30 bits of the result are zero, and then throw them away. */
309     VERIFY_CHECK(((int32_t)cf & M30) == 0); cf >>= 30;
310     VERIFY_CHECK(((int32_t)cg & M30) == 0); cg >>= 30;
311     /* Now iteratively compute limb i=1..8 of t*[f,g], and store them in output limb i-1 (shifting
312      * down by 30 bits). */
313     for (i = 1; i < 9; ++i) {
314         fi = f->v[i];
315         gi = g->v[i];
316         cf += (int64_t)u * fi + (int64_t)v * gi;
317         cg += (int64_t)q * fi + (int64_t)r * gi;
318         f->v[i - 1] = (int32_t)cf & M30; cf >>= 30;
319         g->v[i - 1] = (int32_t)cg & M30; cg >>= 30;
320     }
321     /* What remains is limb 9 of t*[f,g]; store it as output limb 8. */
322     f->v[8] = (int32_t)cf;
323     g->v[8] = (int32_t)cg;
324 }
325
326 /* Compute the inverse of x modulo modinfo->modulus, and replace x with it (constant time in x). */
327 static void secp256k1_modinv32(secp256k1_modinv32_signed30 *x, const secp256k1_modinv32_modinfo *modinfo) {
328     /* Start with d=0, e=1, f=modulus, g=x, eta=-1. */
329     secp256k1_modinv32_signed30 d = {{0}};
330     secp256k1_modinv32_signed30 e = {{1}};
331     secp256k1_modinv32_signed30 f = modinfo->modulus;
332     secp256k1_modinv32_signed30 g = *x;
333     int i;
334     int32_t eta = -1;
335
336     /* Do 25 iterations of 30 divsteps each = 750 divsteps. 724 suffices for 256-bit inputs. */
337     for (i = 0; i < 25; ++i) {
338         /* Compute transition matrix and new eta after 30 divsteps. */
339         secp256k1_modinv32_trans2x2 t;
340         eta = secp256k1_modinv32_divsteps_30(eta, f.v[0], g.v[0], &t);
341         /* Update d,e using that transition matrix. */
342         secp256k1_modinv32_update_de_30(&d, &e, &t, modinfo);
343         /* Update f,g using that transition matrix. */
344         secp256k1_modinv32_update_fg_30(&f, &g, &t);
345     }
346
347     /* At this point sufficient iterations have been performed that g must have reached 0
348      * and (if g was not originally 0) f must now equal +/- GCD of the initial f, g
349      * values i.e. +/- 1, and d now contains +/- the modular inverse. */
350     VERIFY_CHECK((g.v[0] | g.v[1] | g.v[2] | g.v[3] | g.v[4] | g.v[5] | g.v[6] | g.v[7] | g.v[8]) == 0);
351
352     /* Optionally negate d, normalize to [0,modulus), and return it. */
353     secp256k1_modinv32_normalize_30(&d, f.v[8], modinfo);
354     *x = d;
355 }
356
357 /* Compute the inverse of x modulo modinfo->modulus, and replace x with it (variable time). */
358 static void secp256k1_modinv32_var(secp256k1_modinv32_signed30 *x, const secp256k1_modinv32_modinfo *modinfo) {
359     /* Start with d=0, e=1, f=modulus, g=x, eta=-1. */
360     secp256k1_modinv32_signed30 d = {{0, 0, 0, 0, 0, 0, 0, 0, 0}};
361     secp256k1_modinv32_signed30 e = {{1, 0, 0, 0, 0, 0, 0, 0, 0}};
362     secp256k1_modinv32_signed30 f = modinfo->modulus;
363     secp256k1_modinv32_signed30 g = *x;
364     int j;
365     int32_t eta = -1;
366     int32_t cond;
367
368     /* Do iterations of 30 divsteps each until g=0. */
369     while (1) {
370         /* Compute transition matrix and new eta after 30 divsteps. */
371         secp256k1_modinv32_trans2x2 t;
372         eta = secp256k1_modinv32_divsteps_30_var(eta, f.v[0], g.v[0], &t);
373         /* Update d,e using that transition matrix. */
374         secp256k1_modinv32_update_de_30(&d, &e, &t, modinfo);
375         /* Update f,g using that transition matrix. */
376         secp256k1_modinv32_update_fg_30(&f, &g, &t);
377         /* If the bottom limb of g is 0, there is a chance g=0. */
378         if (g.v[0] == 0) {
379             cond = 0;
380             /* Check if the other limbs are also 0. */
381             for (j = 1; j < 9; ++j) {
382                 cond |= g.v[j];
383             }
384             /* If so, we're done. */
385             if (cond == 0) break;
386         }
387     }
388
389     /* At this point g is 0 and (if g was not originally 0) f must now equal +/- GCD of
390      * the initial f, g values i.e. +/- 1, and d now contains +/- the modular inverse. */
391
392     /* Optionally negate d, normalize to [0,modulus), and return it. */
393     secp256k1_modinv32_normalize_30(&d, f.v[8], modinfo);
394     *x = d;
395 }
396
397 #endif /* SECP256K1_MODINV32_IMPL_H */
This page took 0.046363 seconds and 4 git commands to generate.