# The safegcd implementation in libsecp256k1 explained

This document explains the modular inverse implementation in the `src/modinv*.h` files. It is based
on the paper
["Fast constant-time gcd computation and modular inversion"](https://gcd.cr.yp.to/papers.html#safegcd)
by Daniel J. Bernstein and Bo-Yin Yang. The references below are for the Date: 2019.04.13 version.

The actual implementation is in C of course, but for demonstration purposes Python3 is used here.
Most implementation aspects and optimizations are explained, except those that depend on the specific
number representation used in the C code.

## 1. Computing the Greatest Common Divisor (GCD) using divsteps

The algorithm from the paper (section 11), at a very high level, is this:

```python
def gcd(f, g):
    """Compute the GCD of an odd integer f and another integer g."""
    assert f & 1  # require f to be odd
    delta = 1     # additional state variable
    while g != 0:
        assert f & 1  # f will be odd in every iteration
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2
        else:
            delta, f, g = 1 + delta, f, (g    ) // 2
    return abs(f)
```

It computes the greatest common divisor of an odd integer *f* and any integer *g*. Its inner loop
keeps rewriting the variables *f* and *g* alongside a state variable *δ* that starts at *1*, until
*g=0* is reached. At that point, *|f|* gives the GCD. Each of the transitions in the loop is called a
"division step" (referred to as divstep in what follows).

For example, *gcd(21, 14)* would be computed as:
- Start with *δ=1 f=21 g=14*
- Take the third branch: *δ=2 f=21 g=7*
- Take the first branch: *δ=-1 f=7 g=-7*
- Take the second branch: *δ=0 f=7 g=0*
- The answer *|f| = 7*.

Why it works:
- Divsteps can be decomposed into two steps (see paragraph 8.2 in the paper):
  - (a) If *g* is odd, replace *(f,g)* with *(g,g-f)* or *(f,g+f)*, resulting in an even *g*.
  - (b) Replace *(f,g)* with *(f,g/2)* (where *g* is guaranteed to be even).
- Neither of those two operations changes the GCD:
  - For (a), assume *gcd(f,g)=c*, then it must be the case that *f=a c* and *g=b c* for some integers *a*
    and *b*. As *(g,g-f)=(b c,(b-a)c)* and *(f,f+g)=(a c,(a+b)c)*, the result clearly still has
    common factor *c*. Reasoning in the other direction shows that no common factor can be added by
    doing so either.
  - For (b), we know that *f* is odd, so *gcd(f,g)* clearly has no factor *2*, and we can remove
    it from *g*.
- The algorithm will eventually converge to *g=0*. This is proven in the paper (see theorem G.3).
- It follows that eventually we find a final value *f'* for which *gcd(f,g) = gcd(f',0)*. As the
  gcd of *f'* and *0* is *|f'|* by definition, that is our answer.

Compared to more [traditional GCD algorithms](https://en.wikipedia.org/wiki/Euclidean_algorithm), this one has the property of only ever looking at
the low-order bits of the variables to decide the next steps, and being easy to make
constant-time (in more low-level languages than Python). The *δ* parameter is necessary to
guide the algorithm towards shrinking the numbers' magnitudes without explicitly needing to look
at high order bits.

Properties that will become important later:
- Performing more divsteps than needed is not a problem, as *f* does not change anymore after *g=0*.
- Only even numbers are divided by *2*. This means that when reasoning about it algebraically we
  do not need to worry about rounding.
- At every point during the algorithm's execution the next *N* steps only depend on the bottom *N*
  bits of *f* and *g*, and on *δ*.


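The first and last of these properties can be spot-checked with a small sketch. The helper `divsteps` below is hypothetical (it is not part of the library), and the inputs and the choice *N=8* are arbitrary values picked for illustration:

```python
def divsteps(n, delta, f, g):
    """Apply n divsteps to (delta, f, g), recording which branch each step takes."""
    branches = []
    for _ in range(n):
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2
            branches.append(1)
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2
            branches.append(2)
        else:
            delta, f, g = 1 + delta, f, g // 2
            branches.append(3)
    return delta, f, g, branches

N = 8                          # illustrative batch size
f, g = 0x1234567, 0x89abcde    # arbitrary inputs; f is odd

# The next N branch decisions only need the bottom N bits of f and g (and delta).
full = divsteps(N, 1, f, g)
low = divsteps(N, 1, f % 2**N, g % 2**N)
assert full[3] == low[3] and full[0] == low[0]

# Once g has reached 0, further divsteps leave f untouched.
done = divsteps(40, 1, 21, 14)
more = divsteps(80, 1, 21, 14)
assert done[2] == 0 and more[1] == done[1]
```

Only the branch sequence and *δ* are compared for the truncated run; the truncated *f* and *g* themselves differ from the full-size values, which is why the later sections apply the resulting matrix to the full numbers separately.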
## 2. From GCDs to modular inverses

We want an algorithm to compute the inverse *a* of *x* modulo *M*, i.e. the number *a* such that *a x=1
mod M*. This inverse only exists if the GCD of *x* and *M* is *1*, but that is always the case if *M* is
prime and *0 < x < M*. In what follows, assume that the modular inverse exists.
It turns out this inverse can be computed as a side effect of computing the GCD by keeping track
of how the internal variables can be written as linear combinations of the inputs at every step
(see the [extended Euclidean algorithm](https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm)).
Since the GCD is *1*, such an algorithm will compute numbers *a* and *b* such that *a&thinsp;x + b&thinsp;M = 1*.
Taking that expression *mod M* gives *a&thinsp;x mod M = 1*, and we see that *a* is the modular inverse of *x
mod M*.

A similar approach can be used to calculate modular inverses using the divsteps-based GCD
algorithm shown above, if the modulus *M* is odd. To do so, compute *gcd(f=M,g=x)*, while keeping
track of extra variables *d* and *e*, for which at every step *d = f/x (mod M)* and *e = g/x (mod M)*.
*f/x* here means the number which multiplied with *x* gives *f mod M*. As *f* and *g* are initialized to *M*
and *x* respectively, *d* and *e* just start off being *0* (*M/x mod M = 0/x mod M = 0*) and *1* (*x/x mod M
= 1*).

```python
def div2(M, x):
    """Helper routine to compute x/2 mod M (where M is odd)."""
    assert M & 1
    if x & 1: # If x is odd, make it even by adding M.
        x += M
    # x must be even now, so a clean division by 2 is possible.
    return x // 2

def modinv(M, x):
    """Compute the inverse of x mod M (given that it exists, and M is odd)."""
    assert M & 1
    delta, f, g, d, e = 1, M, x, 0, 1
    while g != 0:
        # Note that while division by two for f and g is only ever done on even inputs, this is
        # not true for d and e, so we need the div2 helper function.
        if delta > 0 and g & 1:
            delta, f, g, d, e = 1 - delta, g, (g - f) // 2, e, div2(M, e - d)
        elif g & 1:
            delta, f, g, d, e = 1 + delta, f, (g + f) // 2, d, div2(M, e + d)
        else:
            delta, f, g, d, e = 1 + delta, f, (g    ) // 2, d, div2(M, e    )
        # Verify that the invariants d=f/x mod M, e=g/x mod M are maintained.
        assert f % M == (d * x) % M
        assert g % M == (e * x) % M
    assert f == 1 or f == -1  # |f| is the GCD, it must be 1
    # Because of invariant d = f/x (mod M), 1/x = d/f (mod M). As |f|=1, d/f = d*f.
    return (d * f) % M
```

Also note that this approach to track *d* and *e* throughout the computation to determine the inverse
is different from the paper. There (see paragraph 12.1 in the paper) a transition matrix for the
entire computation is determined (see section 3 below) and the inverse is computed from that.
The approach here avoids the need for 2x2 matrix multiplications of various sizes, and appears to
be faster at the level of optimization we're able to do in C.


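As a sanity check, the routine from this section can be compared against Python's built-in three-argument `pow`, which accepts an exponent of `-1` from Python 3.8 onwards. The sketch below repeats `div2` and `modinv` in condensed form so that it runs on its own; the modulus *2<sup>255</sup>-19* is an assumption chosen only because it is a well-known odd prime:

```python
def div2(M, x):
    """Compute x/2 mod M for odd M."""
    if x & 1:
        x += M
    return x // 2

def modinv(M, x):
    """Inverse of x mod M via divsteps, tracking d = f/x and e = g/x (mod M)."""
    delta, f, g, d, e = 1, M, x, 0, 1
    while g != 0:
        if delta > 0 and g & 1:
            delta, f, g, d, e = 1 - delta, g, (g - f) // 2, e, div2(M, e - d)
        elif g & 1:
            delta, f, g, d, e = 1 + delta, f, (g + f) // 2, d, div2(M, e + d)
        else:
            delta, f, g, d, e = 1 + delta, f, g // 2, d, div2(M, e)
    assert f in (1, -1)
    return (d * f) % M

# Tiny case that is easy to follow by hand: 3 * 5 = 15 = 1 (mod 7).
assert modinv(7, 3) == 5

M = 2**255 - 19  # illustrative odd prime, not a value taken from the library
for x in (1, 2, 3, 12345, M - 1):
    assert modinv(M, x) == pow(x, -1, M)
```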
## 3. Batching multiple divsteps

Every divstep can be expressed as a matrix multiplication, applying a transition matrix *(1/2 t)*
to both vectors *[f, g]* and *[d, e]* (see paragraph 8.1 in the paper):

```
  t = [ u,  v ]
      [ q,  r ]

  [ out_f ] = (1/2 * t) * [ in_f ]
  [ out_g ]               [ in_g ]

  [ out_d ] = (1/2 * t) * [ in_d ]  (mod M)
  [ out_e ]               [ in_e ]
```

where *(u, v, q, r)* is *(0, 2, -1, 1)*, *(2, 0, 1, 1)*, or *(2, 0, 0, 1)*, depending on which branch is
taken. As above, the resulting *f* and *g* are always integers.

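The correspondence between the three branches and the three matrices can be checked mechanically. In this sketch the helpers `divstep` and `apply_half` are hypothetical names introduced for illustration, and exact rationals are used so that the division by *2* is not hidden by integer rounding:

```python
from fractions import Fraction

def divstep(delta, f, g):
    """One divstep; returns the new state and the (u, v, q, r) matrix of the branch taken."""
    if delta > 0 and g & 1:
        return 1 - delta, g, (g - f) // 2, (0, 2, -1, 1)
    elif g & 1:
        return 1 + delta, f, (g + f) // 2, (2, 0, 1, 1)
    else:
        return 1 + delta, f, g // 2, (2, 0, 0, 1)

def apply_half(t, f, g):
    """Compute (1/2 * t) * [f, g] exactly."""
    u, v, q, r = t
    return Fraction(u*f + v*g, 2), Fraction(q*f + r*g, 2)

delta, f, g = 1, 21, 14
for _ in range(6):
    delta, nf, ng, t = divstep(delta, f, g)
    # The matrix form must reproduce the branch result, with no fractional part left over.
    assert apply_half(t, f, g) == (nf, ng)
    f, g = nf, ng
```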
Performing multiple divsteps corresponds to a multiplication with the product of all the
individual divsteps' transition matrices. As each transition matrix consists of integers
divided by *2*, the product of these matrices will consist of integers divided by *2<sup>N</sup>* (see also
theorem 9.2 in the paper). These divisions are expensive when updating *d* and *e*, so we delay
them: we compute the integer coefficients of the combined transition matrix scaled by *2<sup>N</sup>*, and
do one division by *2<sup>N</sup>* as a final step:

```python
def divsteps_n_matrix(delta, f, g):
    """Compute delta and transition matrix t after N divsteps (multiplied by 2^N)."""
    u, v, q, r = 1, 0, 0, 1 # start with identity matrix
    for _ in range(N):
        if delta > 0 and g & 1:
            delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q-u, r-v
        elif g & 1:
            delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q+u, r+v
        else:
            delta, f, g, u, v, q, r = 1 + delta, f, (g    ) // 2, 2*u, 2*v, q  , r
    return delta, (u, v, q, r)
```

As the branches in the divsteps are completely determined by the bottom *N* bits of *f* and *g*, this
function to compute the transition matrix only needs to see those bottom bits. Furthermore all
intermediate results and outputs fit in *(N+1)*-bit numbers (unsigned for *f* and *g*; signed for *u*, *v*,
*q*, and *r*) (see also paragraph 8.3 in the paper). This means that an implementation using 64-bit
integers could set *N=62* and compute the full transition matrix for 62 steps at once without any
big integer arithmetic at all. This is the reason why this algorithm is efficient: it only needs
to update the full-size *f*, *g*, *d*, and *e* numbers once every *N* steps.

We still need functions to compute:

```
  [ out_f ] = (1/2^N * [ u,  v ]) * [ in_f ]
  [ out_g ]   (        [ q,  r ])   [ in_g ]

  [ out_d ] = (1/2^N * [ u,  v ]) * [ in_d ]  (mod M)
  [ out_e ]   (        [ q,  r ])   [ in_e ]
```

Because the divsteps transformation only ever divides even numbers by two, the result of *t&thinsp;[f,g]* is
always even. When *t* is a composition of *N* divsteps, it follows that the resulting *f*
and *g* will be multiples of *2<sup>N</sup>*, and division by *2<sup>N</sup>* is simply shifting them down:

```python
def update_fg(f, g, t):
    """Multiply matrix t/2^N with [f, g]."""
    u, v, q, r = t
    cf, cg = u*f + v*g, q*f + r*g
    # (t / 2^N) should cleanly apply to [f,g] so the result of t*[f,g] should have N zero
    # bottom bits.
    assert cf % 2**N == 0
    assert cg % 2**N == 0
    return cf >> N, cg >> N
```

The same is not true for *d* and *e*, and we need an equivalent of the `div2` function for division by *2<sup>N</sup> mod M*.
This is easy if we have precomputed *1/M mod 2<sup>N</sup>* (which always exists for odd *M*):

```python
def div2n(M, Mi, x):
    """Compute x/2^N mod M, given Mi = 1/M mod 2^N."""
    assert (M * Mi) % 2**N == 1
    # Find a factor m such that m*M has the same bottom N bits as x. We want:
    #     (m * M) mod 2^N = x mod 2^N
    # <=> m mod 2^N = (x / M) mod 2^N
    # <=> m mod 2^N = (x * Mi) mod 2^N
    m = (Mi * x) % 2**N
    # Subtract that multiple from x, cancelling its bottom N bits.
    x -= m * M
    # Now a clean division by 2^N is possible.
    assert x % 2**N == 0
    return (x >> N) % M

def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    cd, ce = u*d + v*e, q*d + r*e
    return div2n(M, Mi, cd), div2n(M, Mi, ce)
```

With all of those, we can write a version of `modinv` that performs *N* divsteps at once:

```python
def modinv(M, Mi, x):
    """Compute the modular inverse of x mod M, given Mi=1/M mod 2^N."""
    assert M & 1
    delta, f, g, d, e = 1, M, x, 0, 1
    while g != 0:
        # Compute the delta and transition matrix t for the next N divsteps (this only needs
        # (N+1)-bit signed integer arithmetic).
        delta, t = divsteps_n_matrix(delta, f % 2**N, g % 2**N)
        # Apply the transition matrix t to [f, g]:
        f, g = update_fg(f, g, t)
        # Apply the transition matrix t to [d, e]:
        d, e = update_de(d, e, t, M, Mi)
    return (d * f) % M
```

This means that in practice we'll always perform a multiple of *N* divsteps. This is not a problem
because once *g=0*, further divsteps do not affect *f*, *g*, *d*, or *e* anymore (only *&delta;* keeps
increasing). For variable time code such excess iterations will be mostly optimized away in later
sections.


## 4. Avoiding modulus operations

So far, there are two places where we compute a remainder of big numbers modulo *M*: at the end of
`div2n` in every `update_de`, and at the very end of `modinv` after potentially negating *d* due to the
sign of *f*. These are relatively expensive operations when done generically.

To deal with the modulus operation in `div2n`, we simply stop requiring *d* and *e* to be in range
*[0,M)* all the time. Let's start by inlining `div2n` into `update_de`, and dropping the modulus
operation at the end:

```python
def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e] mod M, given Mi=1/M mod 2^N."""
    u, v, q, r = t
    cd, ce = u*d + v*e, q*d + r*e
    # Cancel out bottom N bits of cd and ce.
    md = -((Mi * cd) % 2**N)
    me = -((Mi * ce) % 2**N)
    cd += md * M
    ce += me * M
    # And cleanly divide by 2**N.
    return cd >> N, ce >> N
```

Let's look at bounds on the ranges of these numbers. It can be shown that *|u|+|v|* and *|q|+|r|*
never exceed *2<sup>N</sup>* (see paragraph 8.3 in the paper), and thus a multiplication with *t* will have
outputs whose absolute values are at most *2<sup>N</sup>* times the maximum absolute input value. In case the
inputs *d* and *e* are in *(-M,M)*, which is certainly true for the initial values *d=0* and *e=1* assuming
*M > 1*, the multiplication results in numbers in range *(-2<sup>N</sup>M,2<sup>N</sup>M)*. Subtracting less than *2<sup>N</sup>*
times *M* to cancel out *N* bits brings that up to *(-2<sup>N+1</sup>M,2<sup>N</sup>M)*, and
dividing by *2<sup>N</sup>* at the end takes it to *(-2M,M)*. Another application of `update_de` would take that
to *(-3M,2M)*, and so forth. This progressive expansion of the variables' ranges can be
counteracted by incrementing *d* and *e* by *M* whenever they're negative:

```python
    ...
    if d < 0:
        d += M
    if e < 0:
        e += M
    cd, ce = u*d + v*e, q*d + r*e
    # Cancel out bottom N bits of cd and ce.
    ...
```

With inputs in *(-2M,M)*, they will first be shifted into range *(-M,M)*, which means that the
output will again be in *(-2M,M)*, and this remains the case regardless of how many `update_de`
invocations there are. In what follows, we will try to make this more efficient.

Note that increasing *d* by *M* is equal to incrementing *cd* by *u&thinsp;M* and *ce* by *q&thinsp;M*. Similarly,
increasing *e* by *M* is equal to incrementing *cd* by *v&thinsp;M* and *ce* by *r&thinsp;M*. So we could instead write:

```python
    ...
    cd, ce = u*d + v*e, q*d + r*e
    # Perform the equivalent of incrementing d, e by M when they're negative.
    if d < 0:
        cd += u*M
        ce += q*M
    if e < 0:
        cd += v*M
        ce += r*M
    # Cancel out bottom N bits of cd and ce.
    md = -((Mi * cd) % 2**N)
    me = -((Mi * ce) % 2**N)
    cd += md * M
    ce += me * M
    ...
```

Now note that we have two steps of corrections to *cd* and *ce* that add multiples of *M*: this
increment, and the decrement that cancels out bottom bits. The second one depends on the first
one, but they can still be efficiently combined by only computing the bottom bits of *cd* and *ce*
at first, and using that to compute the final *md*, *me* values:

```python
def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    md, me = 0, 0
    # Compute what multiples of M to add to cd and ce.
    if d < 0:
        md += u
        me += q
    if e < 0:
        md += v
        me += r
    # Compute bottom N bits of t*[d,e] + M*[md,me].
    cd, ce = (u*d + v*e + md*M) % 2**N, (q*d + r*e + me*M) % 2**N
    # Correct md and me such that the bottom N bits of t*[d,e] + M*[md,me] are zero.
    md -= (Mi * cd) % 2**N
    me -= (Mi * ce) % 2**N
    # Do the full computation.
    cd, ce = u*d + v*e + md*M, q*d + r*e + me*M
    # And cleanly divide by 2**N.
    return cd >> N, ce >> N
```

One last optimization: we can avoid the *md&thinsp;M* and *me&thinsp;M* multiplications in the bottom bits of *cd*
and *ce* by moving them to the *md* and *me* correction:

```python
    ...
    # Compute bottom N bits of t*[d,e].
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    # Correct md and me such that the bottom N bits of t*[d,e]+M*[md,me] are zero.
    # Note that this is not the same as {md = (-Mi * cd) % 2**N} etc. That would also result in N
    # zero bottom bits, but isn't guaranteed to be a reduction of [0,2^N) compared to the
    # previous md and me values, and thus would violate our bounds analysis.
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    ...
```

The resulting function takes *d* and *e* in range *(-2M,M)* as inputs, and outputs values in the same
range. That also means that the *d* value at the end of `modinv` will be in that range, while we want
a result in *[0,M)*. To do that, we need a normalization function. It's easy to integrate the
conditional negation of *d* (based on the sign of *f*) into it as well:

```python
def normalize(sign, v, M):
    """Compute sign*v mod M, where v is in range (-2*M,M); output in [0,M)."""
    assert sign == 1 or sign == -1
    # v in (-2*M,M)
    if v < 0:
        v += M
    # v in (-M,M). Now multiply v with sign (which can only be 1 or -1).
    if sign == -1:
        v = -v
    # v in (-M,M)
    if v < 0:
        v += M
    # v in [0,M)
    return v
```

And calling it in `modinv` is simply:

```python
   ...
   return normalize(f, d, M)
```


## 5. Constant-time operation

The primary selling point of the algorithm is fast constant-time operation. What code flow still
depends on the input data so far?

- the number of iterations of the while *g &ne; 0* loop in `modinv`
- the branches inside `divsteps_n_matrix`
- the sign checks in `update_de`
- the sign checks in `normalize`

To make the while loop in `modinv` constant time it can be replaced with a constant number of
iterations. The paper proves (Theorem 11.2) that *741* divsteps are sufficient for any *256*-bit
inputs, and [safegcd-bounds](https://github.com/sipa/safegcd-bounds) shows that the slightly better bound *724* is
sufficient even. Given that every loop iteration performs *N* divsteps, it will run a total of
*&lceil;724/N&rceil;* times.

To deal with the branches in `divsteps_n_matrix` we will replace them with constant-time bitwise
operations (and hope the C compiler isn't smart enough to turn them back into branches; see
`valgrind_ctime_test.c` for automated tests that this isn't the case). To do so, observe that a
divstep can be written instead as (compare to the inner loop of `gcd` in section 1):

```python
    x = -f if delta > 0 else f         # set x equal to (input) -f or f
    if g & 1:
        g += x                         # set g to (input) g-f or g+f
        if delta > 0:
            delta = -delta
            f += g                     # set f to (input) g (note that g was set to g-f before)
    delta += 1
    g >>= 1
```

To convert the above to bitwise operations, we rely on a trick to negate conditionally: per the
definition of negative numbers in two's complement, (*-v == ~v + 1*) holds for every number *v*. As
*-1* in two's complement is all *1* bits, bitflipping can be expressed as xor with *-1*. It follows
that *-v == (v ^ -1) - (-1)*. Thus, if we have a variable *c* that takes on values *0* or *-1*, then
*(v ^ c) - c* is *v* if *c=0* and *-v* if *c=-1*.

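The identity is easy to confirm in Python, whose integers behave like two's complement numbers of unlimited width under bitwise operators. The helper name `cneg` below is hypothetical and only used for this illustration:

```python
def cneg(v, c):
    """Return v if c == 0 and -v if c == -1, without branching on c."""
    assert c in (0, -1)
    return (v ^ c) - c

for v in (-5, 0, 7, 2**62):
    assert cneg(v, 0) == v    # xor with 0 and subtracting 0 changes nothing
    assert cneg(v, -1) == -v  # xor with -1 flips all bits, subtracting -1 adds 1
```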
Using this we can write:

```python
    x = -f if delta > 0 else f
```

in constant-time form as:

```python
    c1 = (-delta) >> 63
    # Conditionally negate f based on c1:
    x = (f ^ c1) - c1
```

To use that trick, we need a helper mask variable *c1* that resolves the condition *&delta;>0* to *-1*
(if true) or *0* (if false). We compute *c1* using right shifting, which is equivalent to dividing by
the specified power of *2* and rounding down (in Python, and also in C under the assumption of a typical two's complement system; see
`assumptions.h` for tests that this is the case). Right shifting by *63* thus maps all
numbers in range *[-2<sup>63</sup>,0)* to *-1*, and numbers in range *[0,2<sup>63</sup>)* to *0*.

Using the facts that *x&0=0* and *x&(-1)=x* (on two's complement systems again), we can write:

```python
    if g & 1:
        g += x
```

as:

```python
    # Compute c2=0 if g is even and c2=-1 if g is odd.
    c2 = -(g & 1)
    # This masks out x if g is even, and leaves x be if g is odd.
    g += x & c2
```

Using the conditional negation trick again we can write:

```python
    if g & 1:
        if delta > 0:
            delta = -delta
```

as:

```python
    # Compute c3=-1 if g is odd and delta>0, and 0 otherwise.
    c3 = c1 & c2
    # Conditionally negate delta based on c3:
    delta = (delta ^ c3) - c3
```

Finally:

```python
    if g & 1:
        if delta > 0:
            f += g
```

becomes:

```python
    f += g & c3
```

It turns out that this can be implemented more efficiently by applying the substitution
*&eta;=-&delta;*. In this representation, negating *&delta;* corresponds to negating *&eta;*, and incrementing
*&delta;* corresponds to decrementing *&eta;*. This allows us to remove the negation in the *c1*
computation:

```python
    # Compute a mask c1 for eta < 0, and compute the conditional negation x of f:
    c1 = eta >> 63
    x = (f ^ c1) - c1
    # Compute a mask c2 for odd g, and conditionally add x to g:
    c2 = -(g & 1)
    g += x & c2
    # Compute a mask c3 for (eta < 0) and odd (input) g, and use it to conditionally negate eta,
    # and add g to f:
    c3 = c1 & c2
    eta = (eta ^ c3) - c3
    f += g & c3
    # Incrementing delta corresponds to decrementing eta.
    eta -= 1
    g >>= 1
```

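To gain some confidence that the branch-free form matches the original divstep, the two can be run side by side. This is a minimal sketch: the function names are hypothetical, the inputs are arbitrary, and Python's unbounded integers stand in for the 64-bit values used in C (so the shift by *63* only behaves as a sign mask while *|&eta;| < 2<sup>63</sup>*):

```python
def divstep_delta(delta, f, g):
    """Branching divstep from section 1."""
    if delta > 0 and g & 1:
        return 1 - delta, g, (g - f) // 2
    elif g & 1:
        return 1 + delta, f, (g + f) // 2
    else:
        return 1 + delta, f, g // 2

def divstep_eta(eta, f, g):
    """Branch-free divstep using eta = -delta."""
    c1 = eta >> 63          # -1 if eta < 0, else 0
    x = (f ^ c1) - c1       # -f if eta < 0, else f
    c2 = -(g & 1)           # -1 if g is odd, else 0
    g += x & c2
    c3 = c1 & c2
    eta = (eta ^ c3) - c3   # conditionally negate eta
    f += g & c3
    eta -= 1
    g >>= 1
    return eta, f, g

delta, f, g = 1, 0x1234567, 0x89abcde   # arbitrary inputs; f is odd
eta, f2, g2 = -delta, f, g
for _ in range(200):
    delta, f, g = divstep_delta(delta, f, g)
    eta, f2, g2 = divstep_eta(eta, f2, g2)
    assert (eta, f2, g2) == (-delta, f, g)
```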
A variant of divsteps with better worst-case performance can be used instead: starting *&delta;* at
*1/2* instead of *1*. This reduces the worst case number of iterations to *590* for *256*-bit inputs
(which can be shown using convex hull analysis). In this case, the substitution *&zeta;=-(&delta;+1/2)*
is used instead to keep the variable integral. Incrementing *&delta;* by *1* still translates to
decrementing *&zeta;* by *1*, but negating *&delta;* now corresponds to going from *&zeta;* to *-(&zeta;+1)*, or
*~&zeta;*. Doing that conditionally based on *c3* is simply:

```python
    ...
    c3 = c1 & c2
    zeta ^= c3
    ...
```

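The relation between *&zeta;* and a half-integer *&delta;* can be checked with exact rationals. In this sketch the function names are hypothetical and the inputs arbitrary; the reference divstep is the one from section 1 with *&delta;* started at *1/2*:

```python
from fractions import Fraction

def divstep_half_delta(delta, f, g):
    """Reference divstep; delta is a half-integer Fraction here."""
    if delta > 0 and g & 1:
        return 1 - delta, g, (g - f) // 2
    elif g & 1:
        return 1 + delta, f, (g + f) // 2
    else:
        return 1 + delta, f, g // 2

def divstep_zeta(zeta, f, g):
    """Branch-free divstep using zeta = -(delta + 1/2)."""
    c1 = zeta >> 63         # -1 if zeta < 0 (i.e. delta > 0), else 0
    x = (f ^ c1) - c1
    c2 = -(g & 1)
    g += x & c2
    c3 = c1 & c2
    zeta ^= c3              # conditional bit flip replaces the conditional negation
    f += g & c3
    zeta -= 1
    g >>= 1
    return zeta, f, g

delta, f, g = Fraction(1, 2), 0x1234567, 0x89abcde   # arbitrary inputs; f is odd
zeta, f2, g2 = -1, f, g
for _ in range(200):
    delta, f, g = divstep_half_delta(delta, f, g)
    zeta, f2, g2 = divstep_zeta(zeta, f2, g2)
    assert zeta == -(delta + Fraction(1, 2)) and (f2, g2) == (f, g)
```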
By replacing the loop in `divsteps_n_matrix` with a variant of the divstep code above (extended to
also apply all *f* operations to *u*, *v* and all *g* operations to *q*, *r*), a constant-time version of
`divsteps_n_matrix` is obtained. The full code will be in section 7.

These bit fiddling tricks can also be used to make the conditional negations and additions in
`update_de` and `normalize` constant-time.


## 6. Variable-time optimizations

In section 5, we modified the `divsteps_n_matrix` function (and a few others) to be constant time.
Constant time operations are only necessary when computing modular inverses of secret data. In
other cases, it slows down calculations unnecessarily. In this section, we will construct a
faster non-constant time `divsteps_n_matrix` function.

To do so, first consider yet another way of writing the inner loop of divstep operations in
`gcd` from section 1. This decomposition is also explained in the paper in section 8.2. We use
the original version with initial *&delta;=1* and *&eta;=-&delta;* here.

```python
for _ in range(N):
    if g & 1 and eta < 0:
        eta, f, g = -eta, g, -f
    if g & 1:
        g += f
    eta -= 1
    g >>= 1
```

Whenever *g* is even, the loop only shifts *g* down and decreases *&eta;*. When *g* ends in multiple zero
bits, these iterations can be consolidated into one step. This requires counting the bottom zero
bits efficiently, which is possible on most platforms; it is abstracted here as the function
`count_trailing_zeros`.

```python
def count_trailing_zeros(v):
    """For a non-zero value v, find z such that v=(d<<z) for some odd d."""
    return (v & -v).bit_length() - 1

i = N # divsteps left to do
while True:
    # Get rid of all bottom zeros at once. In the first iteration, g may be odd and the following
    # lines have no effect (until "if eta < 0").
    zeros = min(i, count_trailing_zeros(g))
    eta -= zeros
    g >>= zeros
    i -= zeros
    if i == 0:
        break
    # We know g is odd now
    if eta < 0:
        eta, f, g = -eta, g, -f
    g += f
    # g is even now, and the eta decrement and g shift will happen in the next loop.
```

We can now remove multiple bottom *0* bits from *g* at once, but still need a full iteration whenever
there is a bottom *1* bit. In what follows, we will get rid of multiple *1* bits simultaneously as
well.

Observe that as long as *&eta; &geq; 0*, the loop does not modify *f*. Instead, it cancels out bottom
bits of *g* and shifts them out, and decreases *&eta;* and *i* accordingly - interrupting only when *&eta;*
becomes negative, or when *i* reaches *0*. Combined, this is equivalent to adding a multiple of *f* to
*g* to cancel out multiple bottom bits, and then shifting them out.

It is easy to find what that multiple is: we want a number *w* such that *g+w&thinsp;f* has a few bottom
zero bits. If that number of bits is *L*, we want *g+w&thinsp;f mod 2<sup>L</sup> = 0*, or *w = -g/f mod 2<sup>L</sup>*. Since *f*
is odd, such a *w* exists for any *L*. *L* cannot be more than *i* steps (as we'd finish the loop before
doing more) or more than *&eta;+1* steps (as we'd run `eta, f, g = -eta, g, -f` at that point), but
apart from that, we're only limited by the complexity of computing *w*.

This code demonstrates how to cancel up to 4 bits per step:

```python
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1] # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n
i = N
while True:
    zeros = min(i, count_trailing_zeros(g))
    eta -= zeros
    g >>= zeros
    i -= zeros
    if i == 0:
        break
    # We know g is odd now
    if eta < 0:
        eta, f, g = -eta, g, -f
    # Compute limit on number of bits to cancel
    limit = min(min(eta + 1, i), 4)
    # Compute w = -g/f mod 2**limit, using the table value for -1/f mod 2**4. Note that f is
    # always odd, so its inverse modulo a power of two always exists.
    w = (g * NEGINV16[(f & 15) // 2]) % (2**limit)
    # As w = -g/f mod (2**limit), g+w*f mod 2**limit = 0 mod 2**limit.
    g += w * f
    assert g % (2**limit) == 0
    # The next iteration will now shift out at least limit bottom zero bits from g.
```

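The claim that cancelling several bits at once is equivalent to running the plain divstep loop can be tested directly. The sketch below wraps both loops in hypothetical functions `divsteps_plain` and `divsteps_fast`. One deviation is an assumption of this example rather than something stated above: `count_trailing_zeros` is only defined for non-zero values, so the fast loop treats *g=0* as having *i* trailing zeros.

```python
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1]  # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n

def count_trailing_zeros(v):
    """For a non-zero value v, find z such that v=(d<<z) for some odd d."""
    return (v & -v).bit_length() - 1

def divsteps_plain(n, eta, f, g):
    """n divsteps, one at a time."""
    for _ in range(n):
        if g & 1 and eta < 0:
            eta, f, g = -eta, g, -f
        if g & 1:
            g += f
        eta -= 1
        g >>= 1
    return eta, f, g

def divsteps_fast(n, eta, f, g):
    """n divsteps, cancelling up to 4 bottom bits of g per iteration."""
    i = n
    while True:
        # Guard added for this sketch: with g == 0 all remaining steps are plain shifts.
        zeros = i if g == 0 else min(i, count_trailing_zeros(g))
        eta -= zeros
        g >>= zeros
        i -= zeros
        if i == 0:
            break
        if eta < 0:
            eta, f, g = -eta, g, -f
        limit = min(min(eta + 1, i), 4)
        w = (g * NEGINV16[(f & 15) // 2]) % (2**limit)
        g += w * f
    return eta, f, g

# Arbitrary test inputs with odd f; eta starts at -1 (delta = 1).
for f, g in [(21, 14), (0x1234567, 0x89abcde), (2**61 - 1, 12345)]:
    for n in (1, 5, 62):
        assert divsteps_fast(n, -1, f, g) == divsteps_plain(n, -1, f, g)
```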
By using a bigger table more bits can be cancelled at once. The table can also be implemented
as a formula. Several formulas are known for computing modular inverses modulo powers of two;
some can be found in Hacker's Delight second edition by Henry S. Warren, Jr. pages 245-247.
Here we need the negated modular inverse, which is a simple transformation of those:

- Instead of a 3-bit table:
  - *-f* or *f ^ 6*
- Instead of a 4-bit table:
  - *1 - f(f + 1)*
  - *-(f + (((f + 1) & 4) << 1))*
- For larger tables the following technique can be used: if *w=-1/f mod 2<sup>L</sup>*, then *w(w&thinsp;f+2)* is
  *-1/f mod 2<sup>2L</sup>*. This allows extending the previous formulas (or tables). In particular we
  have this 6-bit function (based on the 3-bit function above):
  - *f(f<sup>2</sup> - 2)*

This loop, again extended to also handle *u*, *v*, *q*, and *r* alongside *f* and *g*, placed in
`divsteps_n_matrix`, gives a significantly faster, but non-constant time version.


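The table-free formulas listed in this section can be verified exhaustively over the odd residues. In the sketch below the function names are hypothetical; each candidate *w* is accepted when *w&thinsp;f = -1* modulo the relevant power of two:

```python
def neginv8_a(f):  return -f
def neginv8_b(f):  return f ^ 6
def neginv16_a(f): return 1 - f * (f + 1)
def neginv16_b(f): return -(f + (((f + 1) & 4) << 1))
def neginv64(f):   return f * (f * f - 2)

def extend(w, f):
    """Given w = -1/f mod 2^L, return a value equal to -1/f mod 2^(2L)."""
    return w * (w * f + 2)

for f in range(1, 256, 2):
    assert (f * neginv8_a(f)) % 8 == 7
    assert (f * neginv8_b(f)) % 8 == 7
    assert (f * neginv16_a(f)) % 16 == 15
    assert (f * neginv16_b(f)) % 16 == 15
    assert (f * neginv64(f)) % 64 == 63
    # Doubling the precision of the 4-bit formula yields an 8-bit negated inverse.
    assert (f * extend(neginv16_a(f), f)) % 256 == 255
```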
## 7. Final Python version

All together we need the following functions:

- A way to compute the transition matrix in constant time, using the `divsteps_n_matrix` function
  from section 3, but with its loop replaced by a variant of the constant-time divstep from
  section 5, extended to handle *u*, *v*, *q*, *r*:

```python
def divsteps_n_matrix(zeta, f, g):
    """Compute zeta and transition matrix t after N divsteps (multiplied by 2^N)."""
    u, v, q, r = 1, 0, 0, 1 # start with identity matrix
    for _ in range(N):
        c1 = zeta >> 63
        # Compute x, y, z as conditionally-negated versions of f, u, v.
        x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
        c2 = -(g & 1)
        # Conditionally add x, y, z to g, q, r.
        g, q, r = g + (x & c2), q + (y & c2), r + (z & c2)
        c1 &= c2                     # reusing c1 here for the earlier c3 variable
        zeta = (zeta ^ c1) - 1       # inlining the unconditional zeta decrement here
        # Conditionally add g, q, r to f, u, v.
        f, u, v = f + (g & c1), u + (q & c1), v + (r & c1)
        # When shifting g down, don't shift q, r, as we construct a transition matrix multiplied
        # by 2^N. Instead, shift f's coefficients u and v up.
        g, u, v = g >> 1, u << 1, v << 1
    return zeta, (u, v, q, r)
```

- The functions to update *f* and *g*, and *d* and *e*, from section 3 and section 4, with the constant-time
  changes to `update_de` from section 5:

```python
def update_fg(f, g, t):
    """Multiply matrix t/2^N with [f, g]."""
    u, v, q, r = t
    cf, cg = u*f + v*g, q*f + r*g
    return cf >> N, cg >> N

def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    d_sign, e_sign = d >> 257, e >> 257
    md, me = (u & d_sign) + (v & e_sign), (q & d_sign) + (r & e_sign)
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    cd, ce = u*d + v*e + M*md, q*d + r*e + M*me
    return cd >> N, ce >> N
```

- The `normalize` function from section 4, made constant time as well:

```python
def normalize(sign, v, M):
    """Compute sign*v mod M, where v in (-2*M,M); output in [0,M)."""
    v_sign = v >> 257
    # Conditionally add M to v.
    v += M & v_sign
    c = (sign - 1) >> 1
    # Conditionally negate v.
    v = (v ^ c) - c
    v_sign = v >> 257
    # Conditionally add M to v again.
    v += M & v_sign
    return v
```

- And finally the `modinv` function too, adapted to use *&zeta;* instead of *&delta;*, and using the fixed
  iteration count from section 5:

```python
def modinv(M, Mi, x):
    """Compute the modular inverse of x mod M, given Mi=1/M mod 2^N."""
    zeta, f, g, d, e = -1, M, x, 0, 1
    for _ in range((590 + N - 1) // N):
        zeta, t = divsteps_n_matrix(zeta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return normalize(f, d, M)
```

- To get a variable time version, replace the `divsteps_n_matrix` function with one that uses the
  divsteps loop from section 6, and a `modinv` version that calls it without the fixed iteration
  count:

```python
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1] # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n
def divsteps_n_matrix_var(eta, f, g):
    """Compute eta and transition matrix t after N divsteps (multiplied by 2^N)."""
    u, v, q, r = 1, 0, 0, 1
    i = N
    while True:
        zeros = min(i, count_trailing_zeros(g))
        eta, i = eta - zeros, i - zeros
        g, u, v = g >> zeros, u << zeros, v << zeros
        if i == 0:
            break
        if eta < 0:
            eta, f, u, v, g, q, r = -eta, g, q, r, -f, -u, -v
        limit = min(min(eta + 1, i), 4)
        w = (g * NEGINV16[(f & 15) // 2]) % (2**limit)
        g, q, r = g + w*f, q + w*u, r + w*v
    return eta, (u, v, q, r)

def modinv_var(M, Mi, x):
    """Compute the modular inverse of x mod M, given Mi = 1/M mod 2^N."""
    eta, f, g, d, e = -1, M, x, 0, 1
    while g != 0:
        eta, t = divsteps_n_matrix_var(eta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return normalize(f, d, M)
```
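As a usage illustration, the constant-time functions can be assembled into one script and compared against Python's built-in `pow(x, -1, M)` (available from Python 3.8). This is a sketch under stated assumptions: *N=62* follows the 64-bit example from section 3, the modulus is the secp256k1 field prime *2<sup>256</sup> - 2<sup>32</sup> - 977*, and `Mi` is obtained with `pow` rather than precomputed as a constant. The function bodies are condensed copies of the ones above.

```python
N = 62  # divsteps per batch, as in the 64-bit example of section 3

def divsteps_n_matrix(zeta, f, g):
    u, v, q, r = 1, 0, 0, 1
    for _ in range(N):
        c1 = zeta >> 63
        x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
        c2 = -(g & 1)
        g, q, r = g + (x & c2), q + (y & c2), r + (z & c2)
        c1 &= c2
        zeta = (zeta ^ c1) - 1
        f, u, v = f + (g & c1), u + (q & c1), v + (r & c1)
        g, u, v = g >> 1, u << 1, v << 1
    return zeta, (u, v, q, r)

def update_fg(f, g, t):
    u, v, q, r = t
    return (u*f + v*g) >> N, (q*f + r*g) >> N

def update_de(d, e, t, M, Mi):
    u, v, q, r = t
    d_sign, e_sign = d >> 257, e >> 257
    md, me = (u & d_sign) + (v & e_sign), (q & d_sign) + (r & e_sign)
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    cd, ce = u*d + v*e + M*md, q*d + r*e + M*me
    return cd >> N, ce >> N

def normalize(sign, v, M):
    v += M & (v >> 257)
    c = (sign - 1) >> 1
    v = (v ^ c) - c
    v += M & (v >> 257)
    return v

def modinv(M, Mi, x):
    zeta, f, g, d, e = -1, M, x, 0, 1
    for _ in range((590 + N - 1) // N):
        zeta, t = divsteps_n_matrix(zeta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return normalize(f, d, M)

P = 2**256 - 2**32 - 977     # secp256k1 field prime, used here as the modulus M
Mi = pow(P, -1, 2**N)        # 1/M mod 2^N; exists because P is odd

for x in (1, 2, 3, 0xdeadbeef, P - 1, P // 3):
    inv = modinv(P, Mi, x)
    assert 0 <= inv < P
    assert inv == pow(x, -1, P)
```

The loop always runs *&lceil;590/62&rceil; = 10* batches regardless of *x*, which is the fixed iteration count that the constant-time variant relies on.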