# The safegcd implementation in libsecp256k1 explained

This document explains the modular inverse implementation in the `src/modinv*.h` files. It is based
on the paper
["Fast constant-time gcd computation and modular inversion"](https://gcd.cr.yp.to/papers.html#safegcd)
by Daniel J. Bernstein and Bo-Yin Yang. The references below are for the Date: 2019.04.13 version.

The actual implementation is in C of course, but for demonstration purposes Python3 is used here.
Most implementation aspects and optimizations are explained, except those that depend on the specific
number representation used in the C code.

## 1. Computing the Greatest Common Divisor (GCD) using divsteps

The algorithm from the paper (section 11), at a very high level, is this:

```python
def gcd(f, g):
    """Compute the GCD of an odd integer f and another integer g."""
    assert f & 1  # require f to be odd
    delta = 1     # additional state variable
    while g != 0:
        assert f & 1  # f will be odd in every iteration
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2
        else:
            delta, f, g = 1 + delta, f, (g    ) // 2
    return abs(f)
```

It computes the greatest common divisor of an odd integer *f* and any integer *g*. Its inner loop
keeps rewriting the variables *f* and *g* alongside a state variable *δ* that starts at *1*, until
*g=0* is reached. At that point, *|f|* gives the GCD. Each of the transitions in the loop is called a
"division step" (referred to as divstep in what follows).

For example, *gcd(21, 14)* would be computed as:
- Start with *δ=1 f=21 g=14*
- Take the third branch: *δ=2 f=21 g=7*
- Take the first branch: *δ=-1 f=7 g=-7*
- Take the second branch: *δ=0 f=7 g=0*
- The answer *|f| = 7*.
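
The trace above can be reproduced mechanically. The following is a minimal sketch that repeats the `gcd` loop with the same three branches, records the state after every divstep, and cross-checks the result against Python's `math.gcd`. The name `gcd_trace` is our own and does not come from libsecp256k1.

```python
import math

def gcd_trace(f, g):
    """Divstep-based GCD as above, also returning the (delta, f, g) state after each divstep."""
    assert f & 1  # f must be odd
    delta = 1
    states = [(delta, f, g)]
    while g != 0:
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2
        else:
            delta, f, g = 1 + delta, f, g // 2
        states.append((delta, f, g))
    return abs(f), states

result, states = gcd_trace(21, 14)
print(result)  # 7
print(states)  # [(1, 21, 14), (2, 21, 7), (-1, 7, -7), (0, 7, 0)]

# Cross-check against math.gcd for small odd f and arbitrary (also negative) g.
for f in range(1, 100, 2):
    for g in range(-100, 100):
        assert gcd_trace(f, g)[0] == math.gcd(f, g)
```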

Why it works:
- Divsteps can be decomposed into two steps (see paragraph 8.2 in the paper):
  - (a) If *g* is odd, replace *(f,g)* with *(g,g-f)* or *(f,g+f)*, resulting in an even *g*.
  - (b) Replace *(f,g)* with *(f,g/2)* (where *g* is guaranteed to be even).
- Neither of those two operations changes the GCD:
  - For (a), assume *gcd(f,g)=c*, then it must be the case that *f=a c* and *g=b c* for some integers *a*
    and *b*. As *(g,g-f)=(b c,(b-a)c)* and *(f,f+g)=(a c,(a+b)c)*, the result clearly still has
    common factor *c*. Reasoning in the other direction shows that no common factor can be added by
    doing so either.
  - For (b), we know that *f* is odd, so *gcd(f,g)* clearly has no factor *2*, and we can remove
    it from *g*.
- The algorithm will eventually converge to *g=0*. This is proven in the paper (see theorem G.3).
- It follows that eventually we find a final value *f'* for which *gcd(f,g) = gcd(f',0)*. As the
  gcd of *f'* and *0* is *|f'|* by definition, that is our answer.

Compared to more [traditional GCD algorithms](https://en.wikipedia.org/wiki/Euclidean_algorithm), this one has the property of only ever looking at
the low-order bits of the variables to decide the next steps, and being easy to make
constant-time (in lower-level languages than Python). The *δ* parameter is necessary to
guide the algorithm towards shrinking the numbers' magnitudes without explicitly needing to look
at high order bits.

Properties that will become important later:
- Performing more divsteps than needed is not a problem, as *f* does not change anymore after *g=0*.
- Only even numbers are divided by *2*. This means that when reasoning about it algebraically we
  do not need to worry about rounding.
- At every point during the algorithm's execution the next *N* steps only depend on the bottom *N*
  bits of *f* and *g*, and on *δ*.
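
The last property can be checked empirically. This sketch (the helper `branch_sequence` is hypothetical) records which branch each of the next *n* divsteps takes, and compares the sequence obtained from the full numbers with the one obtained from only their bottom *n* bits.

```python
import random

def branch_sequence(delta, f, g, n):
    """Return the list of divstep branches (1, 2 or 3, as in `gcd`) taken during n divsteps."""
    seq = []
    for _ in range(n):
        if delta > 0 and g & 1:
            seq.append(1)
            delta, f, g = 1 - delta, g, (g - f) // 2
        elif g & 1:
            seq.append(2)
            delta, f, g = 1 + delta, f, (g + f) // 2
        else:
            seq.append(3)
            delta, f, g = 1 + delta, f, g // 2
    return seq

random.seed(1)
n = 16
for _ in range(1000):
    delta = random.randint(-20, 20)
    f = random.getrandbits(64) | 1  # odd
    g = random.getrandbits(64)
    # Truncating f and g to their bottom n bits does not change the next n branches.
    assert branch_sequence(delta, f, g, n) == branch_sequence(delta, f % 2**n, g % 2**n, n)
```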


## 2. From GCDs to modular inverses

We want an algorithm to compute the inverse *a* of *x* modulo *M*, i.e. the number *a* such that *a x=1
mod M*. This inverse only exists if the GCD of *x* and *M* is *1*, but that is always the case if *M* is
prime and *0 < x < M*. In what follows, assume that the modular inverse exists.
It turns out this inverse can be computed as a side effect of computing the GCD by keeping track
of how the internal variables can be written as linear combinations of the inputs at every step
(see the [extended Euclidean algorithm](https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm)).
Since the GCD is *1*, such an algorithm will compute numbers *a* and *b* such that *a x + b M = 1*.
Taking that expression *mod M* gives *a x mod M = 1*, and we see that *a* is the modular inverse of *x
mod M*.
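
For comparison, here is a minimal sketch of the classic extended Euclidean algorithm mentioned above (not constant-time, and not what libsecp256k1 uses); `ext_euclid` is a hypothetical name. It maintains the invariant that every remainder equals a linear combination of *x* and *M*.

```python
def ext_euclid(x, M):
    """Return (g, a, b) with a*x + b*M == g == gcd(x, M)."""
    r0, r1 = x, M
    a0, a1 = 1, 0  # coefficients of x:  r0 == a0*x + b0*M and r1 == a1*x + b1*M
    b0, b1 = 0, 1  # coefficients of M
    while r1 != 0:
        q = r0 // r1
        r0, r1 = r1, r0 - q * r1
        a0, a1 = a1, a0 - q * a1
        b0, b1 = b1, b0 - q * b1
    return r0, a0, b0

g, a, b = ext_euclid(3, 7)
print(g, a, b)  # 1 -2 1, i.e. -2*3 + 1*7 == 1
print(a % 7)    # 5, the inverse of 3 modulo 7 (3*5 = 15 = 1 mod 7)
```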

A similar approach can be used to calculate modular inverses using the divsteps-based GCD
algorithm shown above, if the modulus *M* is odd. To do so, compute *gcd(f=M,g=x)*, while keeping
track of extra variables *d* and *e*, for which at every step *d = f/x (mod M)* and *e = g/x (mod M)*.
*f/x* here means the number which multiplied with *x* gives *f mod M*. As *f* and *g* are initialized to *M*
and *x* respectively, *d* and *e* just start off being *0* (*M/x mod M = 0/x mod M = 0*) and *1* (*x/x mod M
= 1*).

```python
def div2(M, x):
    """Helper routine to compute x/2 mod M (where M is odd)."""
    assert M & 1
    if x & 1: # If x is odd, make it even by adding M.
        x += M
    # x must be even now, so a clean division by 2 is possible.
    return x // 2

def modinv(M, x):
    """Compute the inverse of x mod M (given that it exists, and M is odd)."""
    assert M & 1
    delta, f, g, d, e = 1, M, x, 0, 1
    while g != 0:
        # Note that while division by two for f and g is only ever done on even inputs, this is
        # not true for d and e, so we need the div2 helper function.
        if delta > 0 and g & 1:
            delta, f, g, d, e = 1 - delta, g, (g - f) // 2, e, div2(M, e - d)
        elif g & 1:
            delta, f, g, d, e = 1 + delta, f, (g + f) // 2, d, div2(M, e + d)
        else:
            delta, f, g, d, e = 1 + delta, f, (g    ) // 2, d, div2(M, e    )
        # Verify that the invariants d=f/x mod M, e=g/x mod M are maintained.
        assert f % M == (d * x) % M
        assert g % M == (e * x) % M
    assert f == 1 or f == -1  # |f| is the GCD, it must be 1
    # Because of invariant d = f/x (mod M), 1/x = d/f (mod M). As |f|=1, d/f = d*f.
    return (d * f) % M
```

Also note that this approach to track *d* and *e* throughout the computation to determine the inverse
is different from the paper. There (see paragraph 12.1 in the paper) a transition matrix for the
entire computation is determined (see section 3 below) and the inverse is computed from that.
The approach here avoids the need for 2x2 matrix multiplications of various sizes, and appears to
be faster at the level of optimization we're able to do in C.
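
To make the contrast concrete, here is a rough sketch of a matrix-based variant: accumulate the transition matrix of all divsteps (scaled by *2<sup>n</sup>*, in the style of section 3) and read the inverse off it at the end. This is our own simplified illustration, not the paper's exact procedure; `modinv_via_matrix` is a hypothetical name, *n* is assumed large enough for *g* to reach *0*, and `pow(2, -n, M)` requires Python 3.8+.

```python
def modinv_via_matrix(M, x, n):
    """Run n divsteps on (f, g) = (M, x), tracking t with 2^k*f == u*M + v*x after k steps,
    then derive 1/x mod M from v. Assumes M odd, gcd(M, x) == 1, n large enough."""
    delta, f, g = 1, M, x
    u, v, q, r = 1, 0, 0, 1  # identity matrix
    for _ in range(n):
        if delta > 0 and g & 1:
            delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q - u, r - v
        elif g & 1:
            delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q + u, r + v
        else:
            delta, f, g, u, v, q, r = 1 + delta, f, g // 2, 2*u, 2*v, q, r
    assert g == 0 and f in (1, -1)
    # 2^n * f == u*M + v*x, so v*x == 2^n * f (mod M) and, as f*f == 1, 1/x == f*v/2^n (mod M).
    return (f * v * pow(2, -n, M)) % M

print(modinv_via_matrix(101, 3, 100))  # 34, since 3*34 = 102 = 1 mod 101
```

The price of this variant is visible in the code: the matrix entries grow to about *n* bits, which is what the per-step *d*, *e* tracking above avoids.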


## 3. Batching multiple divsteps

Every divstep can be expressed as a matrix multiplication, applying a transition matrix *(1/2 t)*
to both vectors *[f, g]* and *[d, e]* (see paragraph 8.1 in the paper):

```
  t = [ u,  v ]
      [ q,  r ]

  [ out_f ] = (1/2 * t) * [ in_f ]
  [ out_g ]               [ in_g ]

  [ out_d ] = (1/2 * t) * [ in_d ]  (mod M)
  [ out_e ]               [ in_e ]
```

where *(u, v, q, r)* is *(0, 2, -1, 1)*, *(2, 0, 1, 1)*, or *(2, 0, 0, 1)*, depending on which branch is
taken. As above, the resulting *f* and *g* are always integers.
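
The three matrices can be checked against the divstep branches directly. A small sketch, where `divstep` is a hypothetical helper that also returns the matrix of the branch it took:

```python
def divstep(delta, f, g):
    """One divstep as in `gcd`, also returning the branch's matrix t = (u, v, q, r)."""
    if delta > 0 and g & 1:
        return 1 - delta, g, (g - f) // 2, (0, 2, -1, 1)
    elif g & 1:
        return 1 + delta, f, (g + f) // 2, (2, 0, 1, 1)
    else:
        return 1 + delta, f, g // 2, (2, 0, 0, 1)

# For every branch, (1/2 * t) applied to [f, g] must reproduce the new f and g.
for delta in range(-3, 4):
    for f in range(-31, 32, 2):  # odd f only
        for g in range(-32, 33):
            _, f2, g2, (u, v, q, r) = divstep(delta, f, g)
            assert 2 * f2 == u * f + v * g
            assert 2 * g2 == q * f + r * g
```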

Performing multiple divsteps corresponds to a multiplication with the product of all the
individual divsteps' transition matrices. As each transition matrix consists of integers
divided by *2*, the product of these matrices will consist of integers divided by *2<sup>N</sup>* (see also
theorem 9.2 in the paper). These divisions are expensive when updating *d* and *e*, so we delay
them: we compute the integer coefficients of the combined transition matrix scaled by *2<sup>N</sup>*, and
do one division by *2<sup>N</sup>* as a final step:

```python
def divsteps_n_matrix(delta, f, g):
    """Compute delta and transition matrix t after N divsteps (multiplied by 2^N)."""
    u, v, q, r = 1, 0, 0, 1 # start with identity matrix
    for _ in range(N):
        if delta > 0 and g & 1:
            delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q-u, r-v
        elif g & 1:
            delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q+u, r+v
        else:
            delta, f, g, u, v, q, r = 1 + delta, f, (g    ) // 2, 2*u, 2*v, q  , r
    return delta, (u, v, q, r)
```

As the branches in the divsteps are completely determined by the bottom *N* bits of *f* and *g*, this
function to compute the transition matrix only needs to see those bottom bits. Furthermore all
intermediate results and outputs fit in *(N+1)*-bit numbers (unsigned for *f* and *g*; signed for *u*, *v*,
*q*, and *r*) (see also paragraph 8.3 in the paper). This means that an implementation using 64-bit
integers could set *N=62* and compute the full transition matrix for 62 steps at once without any
big integer arithmetic at all. This is the reason why this algorithm is efficient: it only needs
to update the full-size *f*, *g*, *d*, and *e* numbers once every *N* steps.
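
The claim that the matrix only depends on the bottom *N* bits can be tested numerically. This sketch assumes *N=62*, repeats `divsteps_n_matrix` from above so it runs on its own, and compares it with *N* plain divsteps on full 256-bit values (`plain_divsteps` is a hypothetical helper).

```python
import random

N = 62  # divsteps per batch, as suggested above for 64-bit integers

def divsteps_n_matrix(delta, f, g):
    """Compute delta and transition matrix t after N divsteps (multiplied by 2^N); as above."""
    u, v, q, r = 1, 0, 0, 1
    for _ in range(N):
        if delta > 0 and g & 1:
            delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q-u, r-v
        elif g & 1:
            delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q+u, r+v
        else:
            delta, f, g, u, v, q, r = 1 + delta, f, g // 2, 2*u, 2*v, q, r
    return delta, (u, v, q, r)

def plain_divsteps(delta, f, g, n):
    """n ordinary divsteps on full-size values."""
    for _ in range(n):
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2
        else:
            delta, f, g = 1 + delta, f, g // 2
    return delta, f, g

random.seed(2)
for _ in range(200):
    f, g = random.getrandbits(256) | 1, random.getrandbits(256)
    # The matrix is computed from the bottom N bits only...
    delta, (u, v, q, r) = divsteps_n_matrix(1, f % 2**N, g % 2**N)
    # ...but applying it to the full values matches N plain divsteps exactly.
    delta2, f2, g2 = plain_divsteps(1, f, g, N)
    assert delta == delta2
    assert (u*f + v*g) % 2**N == 0 and (q*f + r*g) % 2**N == 0
    assert (u*f + v*g) >> N == f2 and (q*f + r*g) >> N == g2
    # All matrix entries fit in 64-bit signed integers.
    assert all(-2**63 <= c < 2**63 for c in (u, v, q, r))
```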

We still need functions to compute:

```
  [ out_f ] = (1/2^N * [ u,  v ]) * [ in_f ]
  [ out_g ]   (        [ q,  r ])   [ in_g ]

  [ out_d ] = (1/2^N * [ u,  v ]) * [ in_d ]  (mod M)
  [ out_e ]   (        [ q,  r ])   [ in_e ]
```

Because the divsteps transformation only ever divides even numbers by two, the result of *t [f,g]* is
always even. When *t* is a composition of *N* divsteps, it follows that the resulting *f* and *g*
will be multiples of *2<sup>N</sup>*, and division by *2<sup>N</sup>* is simply shifting them down:

```python
def update_fg(f, g, t):
    """Multiply matrix t/2^N with [f, g]."""
    u, v, q, r = t
    cf, cg = u*f + v*g, q*f + r*g
    # (t / 2^N) should cleanly apply to [f,g] so the result of t*[f,g] should have N zero
    # bottom bits.
    assert cf % 2**N == 0
    assert cg % 2**N == 0
    return cf >> N, cg >> N
```

The same is not true for *d* and *e*, and we need an equivalent of the `div2` function for division by
*2<sup>N</sup> mod M*. This is easy if we have precomputed *1/M mod 2<sup>N</sup>* (which always exists for odd *M*):

```python
def div2n(M, Mi, x):
    """Compute x/2^N mod M, given Mi = 1/M mod 2^N."""
    assert (M * Mi) % 2**N == 1
    # Find a factor m such that m*M has the same bottom N bits as x. We want:
    #     (m * M) mod 2^N = x mod 2^N
    # <=> m mod 2^N = (x / M) mod 2^N
    # <=> m mod 2^N = (x * Mi) mod 2^N
    m = (Mi * x) % 2**N
    # Subtract that multiple from x, cancelling its bottom N bits.
    x -= m * M
    # Now a clean division by 2^N is possible.
    assert x % 2**N == 0
    return (x >> N) % M

def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    cd, ce = u*d + v*e, q*d + r*e
    return div2n(M, Mi, cd), div2n(M, Mi, ce)
```
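
The code above assumes *Mi = 1/M mod 2<sup>N</sup>* has been precomputed. One way to obtain it is Newton iteration, where each step doubles the number of correct low bits. The sketch below (hypothetical helper `inv_mod_2n`) cross-checks it against Python's built-in `pow(M, -1, 2**N)` (Python 3.8+); the secp256k1 field prime is used only as an example modulus.

```python
def inv_mod_2n(M, n):
    """Return 1/M mod 2^n for odd M, by Newton iteration x <- x*(2 - M*x)."""
    assert M & 1
    x, bits = 1, 1           # x = 1/M mod 2^1, since M is odd
    while bits < n:
        bits *= 2            # each step doubles the number of correct bottom bits
        x = (x * (2 - M * x)) % 2**bits
    return x % 2**n

N = 62
M = 2**256 - 2**32 - 977     # example odd modulus (the secp256k1 field prime)
Mi = inv_mod_2n(M, N)
assert (M * Mi) % 2**N == 1
assert Mi == pow(M, -1, 2**N)  # built-in modular inverse, Python 3.8+
```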

With all of those, we can write a version of `modinv` that performs *N* divsteps at once:

```python
def modinv(M, Mi, x):
    """Compute the modular inverse of x mod M, given Mi=1/M mod 2^N."""
    assert M & 1
    delta, f, g, d, e = 1, M, x, 0, 1
    while g != 0:
        # Compute the delta and transition matrix t for the next N divsteps (this only needs
        # (N+1)-bit signed integer arithmetic).
        delta, t = divsteps_n_matrix(delta, f % 2**N, g % 2**N)
        # Apply the transition matrix t to [f, g]:
        f, g = update_fg(f, g, t)
        # Apply the transition matrix t to [d, e]:
        d, e = update_de(d, e, t, M, Mi)
    return (d * f) % M
```

This means that in practice we'll always perform a multiple of *N* divsteps. This is not a problem
because once *g=0*, further divsteps do not affect *f*, *g*, *d*, or *e* anymore (only *δ* keeps
increasing). For variable time code such excess iterations will be mostly optimized away in later
sections.


## 4. Avoiding modulus operations

So far, there are two places where we compute a remainder of big numbers modulo *M*: at the end of
`div2n` in every `update_de`, and at the very end of `modinv` after potentially negating *d* due to the
sign of *f*. These are relatively expensive operations when done generically.

To deal with the modulus operation in `div2n`, we simply stop requiring *d* and *e* to be in range
*[0,M)* all the time. Let's start by inlining `div2n` into `update_de`, and dropping the modulus
operation at the end:

```python
def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e] mod M, given Mi=1/M mod 2^N."""
    u, v, q, r = t
    cd, ce = u*d + v*e, q*d + r*e
    # Cancel out bottom N bits of cd and ce.
    md = -((Mi * cd) % 2**N)
    me = -((Mi * ce) % 2**N)
    cd += md * M
    ce += me * M
    # And cleanly divide by 2**N.
    return cd >> N, ce >> N
```

Let's look at bounds on the ranges of these numbers. It can be shown that *|u|+|v|* and *|q|+|r|*
never exceed *2<sup>N</sup>* (see paragraph 8.3 in the paper), and thus a multiplication with *t* will have
outputs whose absolute values are at most *2<sup>N</sup>* times the maximum absolute input value. In case the
inputs *d* and *e* are in *(-M,M)*, which is certainly true for the initial values *d=0* and *e=1* assuming
*M > 1*, the multiplication results in numbers in range *(-2<sup>N</sup>M,2<sup>N</sup>M)*. Subtracting less than *2<sup>N</sup>*
times *M* to cancel out *N* bits extends that to *(-2<sup>N+1</sup>M,2<sup>N</sup>M)*, and
dividing by *2<sup>N</sup>* at the end takes it to *(-2M,M)*. Another application of `update_de` would take that
to *(-3M,2M)*, and so forth. This progressive expansion of the variables' ranges can be
counteracted by incrementing *d* and *e* by *M* whenever they're negative:

```python
    ...
    if d < 0:
        d += M
    if e < 0:
        e += M
    cd, ce = u*d + v*e, q*d + r*e
    # Cancel out bottom N bits of cd and ce.
    ...
```

With inputs in *(-2M,M)*, they will first be shifted into range *(-M,M)*, which means that the
output will again be in *(-2M,M)*, and this remains the case regardless of how many `update_de`
invocations there are. In what follows, we will try to make this more efficient.

Note that increasing *d* by *M* is equal to incrementing *cd* by *u M* and *ce* by *q M*. Similarly,
increasing *e* by *M* is equal to incrementing *cd* by *v M* and *ce* by *r M*. So we could instead write:

```python
    ...
    cd, ce = u*d + v*e, q*d + r*e
    # Perform the equivalent of incrementing d, e by M when they're negative.
    if d < 0:
        cd += u*M
        ce += q*M
    if e < 0:
        cd += v*M
        ce += r*M
    # Cancel out bottom N bits of cd and ce.
    md = -((Mi * cd) % 2**N)
    me = -((Mi * ce) % 2**N)
    cd += md * M
    ce += me * M
    ...
```

Now note that we have two steps of corrections to *cd* and *ce* that add multiples of *M*: this
increment, and the decrement that cancels out bottom bits. The second one depends on the first
one, but they can still be efficiently combined by only computing the bottom bits of *cd* and *ce*
at first, and using that to compute the final *md*, *me* values:

```python
def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    md, me = 0, 0
    # Compute what multiples of M to add to cd and ce.
    if d < 0:
        md += u
        me += q
    if e < 0:
        md += v
        me += r
    # Compute bottom N bits of t*[d,e] + M*[md,me].
    cd, ce = (u*d + v*e + md*M) % 2**N, (q*d + r*e + me*M) % 2**N
    # Correct md and me such that the bottom N bits of t*[d,e] + M*[md,me] are zero.
    md -= (Mi * cd) % 2**N
    me -= (Mi * ce) % 2**N
    # Do the full computation.
    cd, ce = u*d + v*e + md*M, q*d + r*e + me*M
    # And cleanly divide by 2**N.
    return cd >> N, ce >> N
```

One last optimization: we can avoid the *md M* and *me M* multiplications in the bottom bits of *cd*
and *ce* by moving them to the *md* and *me* correction:

```python
    ...
    # Compute bottom N bits of t*[d,e].
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    # Correct md and me such that the bottom N bits of t*[d,e]+M*[md,me] are zero.
    # Note that this is not the same as {md = (-Mi * cd) % 2**N} etc. That would also result in N
    # zero bottom bits, but isn't guaranteed to be a reduction of [0,2^N) compared to the
    # previous md and me values, and thus would violate our bounds analysis.
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    ...
```
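
Assembled into a single function, the steps above give the `update_de` below. The randomized check is a sketch under stated assumptions: *N=62*, the secp256k1 field prime as an example modulus, and random matrices that only satisfy the bound *|u|+|v| ≤ 2<sup>N</sup>*, *|q|+|r| ≤ 2<sup>N</sup>* rather than coming from actual divsteps (the range argument only relies on that bound and on *Mi*).

```python
import random

N = 62

def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e] mod M; inputs and outputs in (-2M, M)."""
    u, v, q, r = t
    md, me = 0, 0
    # Equivalent of adding M to d (resp. e) when it is negative.
    if d < 0:
        md += u
        me += q
    if e < 0:
        md += v
        me += r
    # Bottom N bits of t*[d,e]; correct md, me so t*[d,e] + M*[md,me] has N zero bottom bits.
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    # Full computation, then an exact division by 2^N.
    cd, ce = u*d + v*e + md*M, q*d + r*e + me*M
    return cd >> N, ce >> N

random.seed(3)
M = 2**256 - 2**32 - 977   # example odd modulus (the secp256k1 field prime)
Mi = pow(M, -1, 2**N)      # 1/M mod 2^N (Python 3.8+)
for _ in range(1000):
    # Random matrix satisfying |u|+|v| <= 2^N and |q|+|r| <= 2^N.
    u = random.randint(-2**N, 2**N)
    v = random.choice([-1, 1]) * random.randint(0, 2**N - abs(u))
    q = random.randint(-2**N, 2**N)
    r = random.choice([-1, 1]) * random.randint(0, 2**N - abs(q))
    d = random.randint(-2*M + 1, M - 1)
    e = random.randint(-2*M + 1, M - 1)
    d2, e2 = update_de(d, e, (u, v, q, r), M, Mi)
    # Outputs stay in (-2M, M)...
    assert -2*M < d2 < M and -2*M < e2 < M
    # ...and are congruent to t*[d,e] / 2^N modulo M.
    assert (d2 * 2**N - (u*d + v*e)) % M == 0
    assert (e2 * 2**N - (q*d + r*e)) % M == 0
```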

The resulting function takes *d* and *e* in range *(-2M,M)* as inputs, and outputs values in the same
range. That also means that the *d* value at the end of `modinv` will be in that range, while we want
a result in *[0,M)*. To do that, we need a normalization function. It's easy to integrate the
conditional negation of *d* (based on the sign of *f*) into it as well:

```python
def normalize(sign, v, M):
    """Compute sign*v mod M, where v is in range (-2*M,M); output in [0,M)."""
    assert sign == 1 or sign == -1
    # v in (-2*M,M)
    if v < 0:
        v += M
    # v in (-M,M). Now multiply v with sign (which can only be 1 or -1).
    if sign == -1:
        v = -v
    # v in (-M,M)
    if v < 0:
        v += M
    # v in [0,M)
    return v
```

And calling it in `modinv` is simply:

```python
    ...
    return normalize(f, d, M)
```


## 5. Constant-time operation

The primary selling point of the algorithm is fast constant-time operation. What code flow still
depends on the input data so far?

- the number of iterations of the while *g ≠ 0* loop in `modinv`
- the branches inside `divsteps_n_matrix`
- the sign checks in `update_de`
- the sign checks in `normalize`

To make the while loop in `modinv` constant time it can be replaced with a constant number of
iterations. The paper proves (Theorem 11.2) that *741* divsteps are sufficient for any *256*-bit
inputs, and [safegcd-bounds](https://github.com/sipa/safegcd-bounds) shows that the slightly better bound *724*
suffices as well. Given that every loop iteration performs *N* divsteps, it will run a total of
*⌈724/N⌉* times.
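
The iteration count is plain ceiling division. For illustration, assuming *N=62* for 64-bit integers as in section 3, and *N=30* as a hypothetical choice for 32-bit integers:

```python
def ceil_div(a, b):
    """Ceiling of a/b for positive integers, using only integer arithmetic."""
    return -(-a // b)

for N in (62, 30):
    iterations = ceil_div(724, N)
    print(N, iterations, iterations * N)  # N, loop iterations, divsteps actually performed
# prints "62 12 744" and then "30 25 750"
```

The last column shows that a few more divsteps than the bound are executed, which is harmless because extra divsteps after *g=0* change nothing but *δ*.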

To deal with the branches in `divsteps_n_matrix` we will replace them with constant-time bitwise
operations (and hope the C compiler isn't smart enough to turn them back into branches; see
`valgrind_ctime_test.c` for automated tests that this isn't the case). To do so, observe that a
divstep can instead be written as follows (compare to the inner loop of `gcd` in section 1):

```python
    x = -f if delta > 0 else f  # set x equal to (input) -f or f
    if g & 1:
        g += x                  # set g to (input) g-f or g+f
        if delta > 0:
            delta = -delta
            f += g              # set f to (input) g (note that g was set to g-f before)
    delta += 1
    g >>= 1
```

To convert the above to bitwise operations, we rely on a trick to negate conditionally: per the
definition of negative numbers in two's complement, (*-v == ~v + 1*) holds for every number *v*. As
*-1* in two's complement is all *1* bits, bitflipping can be expressed as xor with *-1*. It follows
that *-v == (v ^ -1) - (-1)*. Thus, if we have a variable *c* that takes on values *0* or *-1*, then
*(v ^ c) - c* is *v* if *c=0* and *-v* if *c=-1*.

Using this we can write:

```python
    x = -f if delta > 0 else f
```

in constant-time form as:

```python
    c1 = (-delta) >> 63
    # Conditionally negate f based on c1:
    x = (f ^ c1) - c1
```

To use that trick, we need a helper mask variable *c1* that resolves the condition *δ>0* to *-1*
(if true) or *0* (if false). We compute *c1* using right shifting, which is equivalent to dividing by
the specified power of *2* and rounding down (in Python, and also in C under the assumption of a typical
two's complement system; see `assumptions.h` for tests that this is the case). Right shifting by *63*
thus maps all numbers in range *[-2<sup>63</sup>,0)* to *-1*, and numbers in range *[0,2<sup>63</sup>)* to *0*.
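
Both tricks can be tried out in isolation. In this sketch `mask_positive` and `cond_negate` are hypothetical helper names. Python integers behave like infinitely wide two's complement numbers, so `>> 63` only acts as a 64-bit sign mask while *|δ| < 2<sup>63</sup>*.

```python
def mask_positive(delta):
    """-1 if delta > 0, else 0 (valid while |delta| < 2^63, like a 64-bit arithmetic shift)."""
    return (-delta) >> 63

def cond_negate(v, c):
    """Return v if c == 0 and -v if c == -1, using (v ^ c) - c."""
    return (v ^ c) - c

for delta in (-5, 0, 1, 5):
    c1 = mask_positive(delta)
    print(delta, c1, cond_negate(21, c1))
# prints: "-5 0 21", "0 0 21", "1 -1 -21", "5 -1 -21"
```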

Using the facts that *x&0=0* and *x&(-1)=x* (on two's complement systems again), we can write:

```python
    if g & 1:
        g += x
```

as:

```python
    # Compute c2=0 if g is even and c2=-1 if g is odd.
    c2 = -(g & 1)
    # This masks out x if g is even, and leaves x unchanged if g is odd.
    g += x & c2
```

Using the conditional negation trick again we can write:

```python
    if g & 1:
        if delta > 0:
            delta = -delta
```

as:

```python
    # Compute c3=-1 if g is odd and delta>0, and 0 otherwise.
    c3 = c1 & c2
    # Conditionally negate delta based on c3:
    delta = (delta ^ c3) - c3
```

Finally:

```python
    if g & 1:
        if delta > 0:
            f += g
```

becomes:

```python
    f += g & c3
```

It turns out that this can be implemented more efficiently by applying the substitution
*η=-δ*. In this representation, negating *δ* corresponds to negating *η*, and incrementing
*δ* corresponds to decrementing *η*. This allows us to remove the negation in the *c1*
computation:

```python
    # Compute a mask c1 for eta < 0, and compute the conditional negation x of f:
    c1 = eta >> 63
    x = (f ^ c1) - c1
    # Compute a mask c2 for odd g, and conditionally add x to g:
    c2 = -(g & 1)
    g += x & c2
    # Compute a mask c3 for (eta < 0) and odd (input) g, and use it to conditionally negate eta,
    # and add g to f:
    c3 = c1 & c2
    eta = (eta ^ c3) - c3
    f += g & c3
    # Incrementing delta corresponds to decrementing eta.
    eta -= 1
    g >>= 1
```
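
To gain confidence that the masked version matches the branching one, here is a sketch that compares a single divstep of each on random inputs. The helper names are hypothetical, and *η* is assumed to stay well within 64-bit range.

```python
import random

def divstep_branching(delta, f, g):
    """Reference divstep from `gcd` in section 1."""
    if delta > 0 and g & 1:
        return 1 - delta, g, (g - f) // 2
    elif g & 1:
        return 1 + delta, f, (g + f) // 2
    return 1 + delta, f, g // 2

def divstep_masked(eta, f, g):
    """Branch-free divstep on eta = -delta, following the masked code above."""
    c1 = eta >> 63       # -1 if eta < 0 (i.e. delta > 0), else 0
    x = (f ^ c1) - c1    # -f or f
    c2 = -(g & 1)        # -1 if g is odd, else 0
    g += x & c2
    c3 = c1 & c2
    eta = (eta ^ c3) - c3
    f += g & c3
    eta -= 1
    g >>= 1
    return eta, f, g

random.seed(4)
for _ in range(10000):
    delta = random.randint(-1000, 1000)
    f = (random.getrandbits(64) - 2**63) | 1  # odd, possibly negative
    g = random.getrandbits(64) - 2**63
    d2, f2, g2 = divstep_branching(delta, f, g)
    assert divstep_masked(-delta, f, g) == (-d2, f2, g2)
```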

A variant of divsteps with better worst-case performance can be used instead: starting *δ* at
*1/2* instead of *1*. This reduces the worst case number of iterations to *590* for *256*-bit inputs
(which can be shown using convex hull analysis). In this case, the substitution *ζ=-(δ+1/2)*
is used instead to keep the variable integral. Incrementing *δ* by *1* still translates to
decrementing *ζ* by *1*, but negating *δ* now corresponds to going from *ζ* to *-(ζ+1)*, or
*~ζ*. Doing that conditionally based on *c3* is simply:

```python
    ...
    c3 = c1 & c2
    zeta ^= c3
    ...
```
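
The *ζ* encoding can be checked against divsteps that use the half-integer *δ* directly, with `fractions.Fraction` for exact arithmetic. A sketch with hypothetical helper names; note that *ζ < 0* is equivalent to *δ > 0* here because *δ* is always an odd multiple of *1/2*.

```python
import random
from fractions import Fraction

def divstep_delta(delta, f, g):
    """Divstep with a possibly non-integer delta (here a Fraction)."""
    if delta > 0 and g & 1:
        return 1 - delta, g, (g - f) // 2
    elif g & 1:
        return 1 + delta, f, (g + f) // 2
    return 1 + delta, f, g // 2

def divstep_zeta(zeta, f, g):
    """Branch-free divstep on zeta = -(delta + 1/2); negating delta becomes zeta ^= c3."""
    c1 = zeta >> 63      # -1 if zeta < 0 (i.e. delta > 0), else 0
    x = (f ^ c1) - c1
    c2 = -(g & 1)
    g += x & c2
    c3 = c1 & c2
    zeta ^= c3           # zeta -> ~zeta == -(zeta + 1) when c3 == -1
    f += g & c3
    zeta -= 1
    g >>= 1
    return zeta, f, g

random.seed(5)
for _ in range(200):
    f = (random.getrandbits(64) - 2**63) | 1
    g = random.getrandbits(64) - 2**63
    delta, zeta = Fraction(1, 2), -1  # zeta = -(1/2 + 1/2)
    for _ in range(100):
        delta, f1, g1 = divstep_delta(delta, f, g)
        zeta, f, g = divstep_zeta(zeta, f, g)
        assert (f1, g1) == (f, g) and zeta == -(delta + Fraction(1, 2))
```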

By replacing the loop in `divsteps_n_matrix` with a variant of the divstep code above (extended to
also apply all *f* operations to *u*, *v* and all *g* operations to *q*, *r*), a constant-time version of
`divsteps_n_matrix` is obtained. The full code will be in section 7.

These bit fiddling tricks can also be used to make the conditional negations and additions in
`update_de` and `normalize` constant-time.
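
As an illustration of that last remark, here is a sketch of a branch-free `normalize` and of the sign-dependent *md*, *me* initialization from `update_de`. `BITS` is an assumed bound on the size of the values (it plays the role of *63* in the 64-bit masks above), and the function names are hypothetical. Python integer operations are not actually constant-time, so this only demonstrates the masking pattern.

```python
import random

BITS = 320  # assumed bound: all values handled below satisfy |value| < 2**BITS

def normalize_masked(sign, v, M):
    """Branch-free sketch of normalize: sign*v mod M for v in (-2M, M) and sign in {1, -1}."""
    c = v >> BITS       # -1 if v < 0, else 0
    v += M & c          # v is now in (-M, M)
    s = sign >> 1       # -1 if sign == -1, 0 if sign == 1
    v = (v ^ s) - s     # conditional negation; v stays in (-M, M)
    c = v >> BITS
    v += M & c          # v is now in [0, M)
    return v

def correction_masked(d, e, t):
    """Branch-free sketch of the initial md, me computation in update_de."""
    u, v, q, r = t
    cd, ce = d >> BITS, e >> BITS  # -1 for negative inputs, else 0
    return (u & cd) + (v & ce), (q & cd) + (r & ce)

random.seed(6)
M = 2**256 - 2**32 - 977  # example odd modulus (the secp256k1 field prime)
for _ in range(1000):
    sign = random.choice([1, -1])
    v = random.randint(-2*M + 1, M - 1)
    assert normalize_masked(sign, v, M) == (sign * v) % M
```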


## 6. Variable-time optimizations

In section 5, we modified the `divsteps_n_matrix` function (and a few others) to be constant time.
Constant time operations are only necessary when computing modular inverses of secret data. In
other cases, it slows down calculations unnecessarily. In this section, we will construct a
faster non-constant time `divsteps_n_matrix` function.

To do so, first consider yet another way of writing the inner loop of divstep operations in
`gcd` from section 1. This decomposition is also explained in the paper in section 8.2. We use
the original version with initial *δ=1* and *η=-δ* here.

```python
for _ in range(N):
    if g & 1 and eta < 0:
        eta, f, g = -eta, g, -f
    if g & 1:
        g += f
    eta -= 1
    g >>= 1
```

Whenever *g* is even, the loop only shifts *g* down and decreases *η*. When *g* ends in multiple zero
bits, these iterations can be consolidated into one step. This requires counting the bottom zero
bits efficiently, which is possible on most platforms; it is abstracted here as the function
`count_trailing_zeros`.

```python
def count_trailing_zeros(v):
    """For a non-zero value v, find z such that v=(d<<z) for some odd d."""
    return (v & -v).bit_length() - 1

i = N # divsteps left to do
while True:
    # Get rid of all bottom zeros at once. In the first iteration, g may be odd and the following
    # lines have no effect (until "if eta < 0"). If g is 0, the remaining i divsteps only shift.
    zeros = min(i, count_trailing_zeros(g)) if g != 0 else i
    eta -= zeros
    g >>= zeros
    i -= zeros
    if i == 0:
        break
    # We know g is odd now
    if eta < 0:
        eta, f, g = -eta, g, -f
    g += f
    # g is even now, and the eta decrement and g shift will happen in the next loop.
```
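
The consolidated loop should produce exactly the same state as the one-divstep-per-iteration loop. Here is a sketch that compares the two on random inputs. Both are wrapped in hypothetical helper functions, and the zero-skipping version treats *g=0* separately because `count_trailing_zeros` is only defined for non-zero values.

```python
import random

def count_trailing_zeros(v):
    """For a non-zero value v, find z such that v=(d<<z) for some odd d."""
    return (v & -v).bit_length() - 1

def simple_divsteps(eta, f, g, n):
    """The one-divstep-per-iteration loop above."""
    for _ in range(n):
        if g & 1 and eta < 0:
            eta, f, g = -eta, g, -f
        if g & 1:
            g += f
        eta -= 1
        g >>= 1
    return eta, f, g

def skipping_divsteps(eta, f, g, n):
    """The zero-skipping loop above."""
    i = n
    while True:
        zeros = min(i, count_trailing_zeros(g)) if g != 0 else i
        eta -= zeros
        g >>= zeros
        i -= zeros
        if i == 0:
            break
        if eta < 0:
            eta, f, g = -eta, g, -f
        g += f
    return eta, f, g

random.seed(7)
for _ in range(5000):
    eta = random.randint(-50, 50)
    f = (random.getrandbits(64) - 2**63) | 1
    g = random.getrandbits(64) - 2**63
    for n in (1, 10, 62):
        assert skipping_divsteps(eta, f, g, n) == simple_divsteps(eta, f, g, n)
# Values of g with many (or all) zero bottom bits.
assert skipping_divsteps(3, 5, 0, 62) == simple_divsteps(3, 5, 0, 62)
assert skipping_divsteps(3, 5, 1 << 70, 62) == simple_divsteps(3, 5, 1 << 70, 62)
```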

We can now remove multiple bottom *0* bits from *g* at once, but still need a full iteration whenever
there is a bottom *1* bit. In what follows, we will get rid of multiple *1* bits simultaneously as
well.

Observe that as long as *η ≥ 0*, the loop does not modify *f*. Instead, it cancels out bottom
bits of *g* and shifts them out, and decreases *η* and *i* accordingly, interrupting only when *η*
becomes negative, or when *i* reaches *0*. Combined, this is equivalent to adding a multiple of *f* to
*g* to cancel out multiple bottom bits, and then shifting them out.

It is easy to find what that multiple is: we want a number *w* such that *g+w f* has a few bottom
zero bits. If that number of bits is *L*, we want *g+w f mod 2<sup>L</sup> = 0*, or *w = -g/f mod 2<sup>L</sup>*. Since *f*
is odd, such a *w* exists for any *L*. *L* cannot be more than *i* steps (as we'd finish the loop before
doing more) or more than *η+1* steps (as we'd run `eta, f, g = -eta, g, -f` at that point), but
apart from that, we're only limited by the complexity of computing *w*.
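
A small sketch of computing *w* for a given *L*, using Python's built-in modular inverse (`pow(..., -1, 2**L)`, Python 3.8+) instead of a table or formula; `cancel_bits` is a hypothetical name.

```python
import random

def cancel_bits(f, g, L):
    """Return w = -g/f mod 2^L (f odd) together with g + w*f, which is divisible by 2^L."""
    assert f & 1
    w = (-g * pow(f % 2**L, -1, 2**L)) % 2**L  # pow(x, -1, m) needs Python 3.8+
    return w, g + w * f

print(cancel_bits(7, 5, 4))  # (13, 96): 96 = 6 * 16 has 4 zero bottom bits

random.seed(8)
for _ in range(1000):
    f = (random.getrandbits(64) - 2**63) | 1
    g = random.getrandbits(64) - 2**63
    L = random.randint(1, 8)
    w, g2 = cancel_bits(f, g, L)
    assert 0 <= w < 2**L and g2 % 2**L == 0
```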

This code demonstrates how to cancel up to 4 bits per step:

```python
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1] # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n
i = N
while True:
    zeros = min(i, count_trailing_zeros(g)) if g != 0 else i
    eta -= zeros
    g >>= zeros
    i -= zeros
    if i == 0:
        break
    # We know g is odd now
    if eta < 0:
        eta, f, g = -eta, g, -f
    # Compute limit on number of bits to cancel
    limit = min(min(eta + 1, i), 4)
    # Compute w = -g/f mod 2**limit, using the table value for -1/f mod 2**4. Note that f is
    # always odd, so its inverse modulo a power of two always exists.
    w = (g * NEGINV16[(f & 15) // 2]) % (2**limit)
    # As w = -g/f mod (2**limit), g+w*f mod 2**limit = 0 mod 2**limit.
    g += w * f
    assert g % (2**limit) == 0
    # The next iteration will now shift out at least limit bottom zero bits from g.
```

By using a bigger table more bits can be cancelled at once. The table can also be implemented
as a formula. Several formulas are known for computing modular inverses modulo powers of two;
some can be found in *Hacker's Delight* (second edition) by Henry S. Warren, Jr., pages 245-247.
Here we need the negated modular inverse, which is a simple transformation of those:

- Instead of a 3-bit table:
  - *-f* or *f ^ 6*
- Instead of a 4-bit table:
  - *1 - f(f + 1)*
  - *-(f + (((f + 1) & 4) << 1))*
- For larger tables the following technique can be used: if *w=-1/f mod 2<sup>L</sup>*, then *w(w f+2)* is
  *-1/f mod 2<sup>2L</sup>*. This allows extending the previous formulas (or tables). In particular we
  have this 6-bit function (based on the 3-bit function above):
  - *f(f<sup>2</sup> - 2)*
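
These formulas are easy to check exhaustively against a reference table. A sketch (Python 3.8+ for `pow(f, -1, m)`; helper names are hypothetical) that also confirms the `NEGINV16` table from the code above:

```python
def neginv_table(bits):
    """Reference: -1/f mod 2^bits for each odd f in [0, 2^bits)."""
    m = 2**bits
    return {f: (-pow(f, -1, m)) % m for f in range(1, m, 2)}

def extend(w, f, L):
    """If w = -1/f mod 2^L, return -1/f mod 2^(2L) using w*(w*f + 2)."""
    return (w * (w * f + 2)) % 2**(2 * L)

t3, t4, t6 = neginv_table(3), neginv_table(4), neginv_table(6)
for f in range(1, 64, 2):
    assert (-f) % 8 == t3[f % 8] and (f ^ 6) % 8 == t3[f % 8]
    assert (1 - f * (f + 1)) % 16 == t4[f % 16]
    assert (-(f + (((f + 1) & 4) << 1))) % 16 == t4[f % 16]
    assert (f * (f * f - 2)) % 64 == t6[f % 64]
    assert extend((-f) % 8, f, 3) == t6[f % 64]

# The NEGINV16 table matches the 4-bit reference as well.
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1]
assert all(NEGINV16[n // 2] == t4[n] for n in range(1, 16, 2))
```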
647 | ||
648 | This loop, again extended to also handle *u*, *v*, *q*, and *r* alongside *f* and *g*, placed in | |
649 | `divsteps_n_matrix`, gives a significantly faster, but non-constant time version. | |
650 | ||
651 | ||
## 7. Final Python version

Altogether, we need the following functions:

- A way to compute the transition matrix in constant time, using the `divsteps_n_matrix` function
  from section 2, but with its loop replaced by a variant of the constant-time divstep from
  section 5, extended to handle *u*, *v*, *q*, *r*:

```python
def divsteps_n_matrix(zeta, f, g):
    """Compute zeta and transition matrix t after N divsteps (multiplied by 2^N)."""
    u, v, q, r = 1, 0, 0, 1 # start with identity matrix
    for _ in range(N):
        c1 = zeta >> 63
        # Compute x, y, z as conditionally-negated versions of f, u, v.
        x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
        c2 = -(g & 1)
        # Conditionally add x, y, z to g, q, r.
        g, q, r = g + (x & c2), q + (y & c2), r + (z & c2)
        c1 &= c2                     # reusing c1 here for the earlier c3 variable
        zeta = (zeta ^ c1) - 1       # inlining the unconditional zeta decrement here
        # Conditionally add g, q, r to f, u, v.
        f, u, v = f + (g & c1), u + (q & c1), v + (r & c1)
        # When shifting g down, don't shift q, r, as we construct a transition matrix multiplied
        # by 2^N. Instead, shift f's coefficients u and v up.
        g, u, v = g >> 1, u << 1, v << 1
    return zeta, (u, v, q, r)
```

- The functions to update *f* and *g*, and *d* and *e*, from section 2 and section 4, with the constant-time
  changes to `update_de` from section 5:

```python
def update_fg(f, g, t):
    """Multiply matrix t/2^N with [f, g]."""
    u, v, q, r = t
    cf, cg = u*f + v*g, q*f + r*g
    return cf >> N, cg >> N

def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    d_sign, e_sign = d >> 257, e >> 257
    md, me = (u & d_sign) + (v & e_sign), (q & d_sign) + (r & e_sign)
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    cd, ce = u*d + v*e + M*md, q*d + r*e + M*me
    return cd >> N, ce >> N
```

- The `normalize` function from section 4, made constant time as well:

```python
def normalize(sign, v, M):
    """Compute sign*v mod M, where v in (-2*M,M); output in [0,M)."""
    v_sign = v >> 257
    # Conditionally add M to v.
    v += M & v_sign
    c = (sign - 1) >> 1
    # Conditionally negate v.
    v = (v ^ c) - c
    v_sign = v >> 257
    # Conditionally add M to v again.
    v += M & v_sign
    return v
```

- And finally the `modinv` function too, adapted to use *ζ* instead of *δ*, and using the fixed
  iteration count from section 5:

```python
def modinv(M, Mi, x):
    """Compute the modular inverse of x mod M, given Mi=1/M mod 2^N."""
    zeta, f, g, d, e = -1, M, x, 0, 1
    for _ in range((590 + N - 1) // N):
        zeta, t = divsteps_n_matrix(zeta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return normalize(f, d, M)
```

- To get a variable-time version, replace the `divsteps_n_matrix` function with one that uses the
  divsteps loop from section 6, and use a `modinv` version that calls it without the fixed iteration
  count:

```python
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1] # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n
def divsteps_n_matrix_var(eta, f, g):
    """Compute eta and transition matrix t after N divsteps (multiplied by 2^N)."""
    u, v, q, r = 1, 0, 0, 1
    i = N
    while True:
        zeros = min(i, count_trailing_zeros(g))
        eta, i = eta - zeros, i - zeros
        g, u, v = g >> zeros, u << zeros, v << zeros
        if i == 0:
            break
        if eta < 0:
            eta, f, u, v, g, q, r = -eta, g, q, r, -f, -u, -v
        limit = min(min(eta + 1, i), 4)
        w = (g * NEGINV16[(f & 15) // 2]) % (2**limit)
        g, q, r = g + w*f, q + w*u, r + w*v
    return eta, (u, v, q, r)

def modinv_var(M, Mi, x):
    """Compute the modular inverse of x mod M, given Mi = 1/M mod 2^N."""
    eta, f, g, d, e = -1, M, x, 0, 1
    while g != 0:
        eta, t = divsteps_n_matrix_var(eta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return normalize(f, d, M)
```
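
Putting the pieces together, the following sketch assembles the functions of this section into one runnable file and checks both versions against Python's built-in `pow(x, -1, M)` (Python 3.8 or later). Several choices here are assumptions of the sketch rather than part of the text: *N = 62* divsteps per batch, a hypothetical `count_trailing_zeros` helper that returns *N* for zero (safe, since its caller clamps the result with `min(i, ...)`), and the secp256k1 field prime as the example modulus. As `normalize`'s signature requires, both versions pass the modulus `M` to it.

```python
# Sketch assumptions: N = 62 and the hypothetical count_trailing_zeros helper.
N = 62
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1]  # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n

def count_trailing_zeros(x):
    """Hypothetical helper: trailing zero bits of x, or N when x == 0."""
    # Returning N for 0 is safe: the caller clamps the result with min(i, ...).
    return (x & -x).bit_length() - 1 if x else N

def divsteps_n_matrix(zeta, f, g):
    """Constant time: zeta and transition matrix after N divsteps (times 2^N)."""
    u, v, q, r = 1, 0, 0, 1
    for _ in range(N):
        c1 = zeta >> 63
        x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
        c2 = -(g & 1)
        g, q, r = g + (x & c2), q + (y & c2), r + (z & c2)
        c1 &= c2
        zeta = (zeta ^ c1) - 1
        f, u, v = f + (g & c1), u + (q & c1), v + (r & c1)
        g, u, v = g >> 1, u << 1, v << 1
    return zeta, (u, v, q, r)

def divsteps_n_matrix_var(eta, f, g):
    """Variable time: eta and transition matrix after N divsteps (times 2^N)."""
    u, v, q, r = 1, 0, 0, 1
    i = N
    while True:
        zeros = min(i, count_trailing_zeros(g))
        eta, i = eta - zeros, i - zeros
        g, u, v = g >> zeros, u << zeros, v << zeros
        if i == 0:
            break
        if eta < 0:
            eta, f, u, v, g, q, r = -eta, g, q, r, -f, -u, -v
        limit = min(min(eta + 1, i), 4)
        w = (g * NEGINV16[(f & 15) // 2]) % (2**limit)
        g, q, r = g + w*f, q + w*u, r + w*v
    return eta, (u, v, q, r)

def update_fg(f, g, t):
    """Multiply matrix t/2^N with [f, g]."""
    u, v, q, r = t
    return (u*f + v*g) >> N, (q*f + r*g) >> N

def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    d_sign, e_sign = d >> 257, e >> 257
    md, me = (u & d_sign) + (v & e_sign), (q & d_sign) + (r & e_sign)
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    cd, ce = u*d + v*e + M*md, q*d + r*e + M*me
    return cd >> N, ce >> N

def normalize(sign, v, M):
    """Compute sign*v mod M, where v in (-2*M,M); output in [0,M)."""
    v += M & (v >> 257)           # conditionally add M
    c = (sign - 1) >> 1
    v = (v ^ c) - c               # conditionally negate
    return v + (M & (v >> 257))   # conditionally add M again

def modinv(M, Mi, x):
    """Constant-time modular inverse of x mod M, given Mi = 1/M mod 2^N."""
    zeta, f, g, d, e = -1, M, x, 0, 1
    for _ in range((590 + N - 1) // N):
        zeta, t = divsteps_n_matrix(zeta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return normalize(f, d, M)

def modinv_var(M, Mi, x):
    """Variable-time modular inverse of x mod M, given Mi = 1/M mod 2^N."""
    eta, f, g, d, e = -1, M, x, 0, 1
    while g != 0:
        eta, t = divsteps_n_matrix_var(eta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return normalize(f, d, M)

P = 2**256 - 2**32 - 977   # secp256k1 field prime, used as an example modulus
P_MI = pow(P, -1, 2**N)    # Mi = 1/P mod 2^N (three-argument pow needs Python 3.8+)
x = 0x1234567890ABCDEF
assert modinv(P, P_MI, x) == modinv_var(P, P_MI, x) == pow(x, -1, P)
```

Python's arbitrary-precision integers hide the limb representation, so this only checks the arithmetic, not the constant-time behaviour of the C code.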