]> rtime.felk.cvut.cz Git - hercules2020/nv-tegra/linux-4.4.git/blob - crypto/ecc.c
arm64: config: Enable CDC-ACM driver in kernel
[hercules2020/nv-tegra/linux-4.4.git] / crypto / ecc.c
1 /*
2  * Copyright (c) 2013, Kenneth MacKay
3  * All rights reserved.
4  * Copyright (c) 2017, NVIDIA Corporation. All Rights Reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are
8  * met:
9  *  * Redistributions of source code must retain the above copyright
10  *   notice, this list of conditions and the following disclaimer.
11  *  * Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
19  * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
21  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
22  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
23  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
27
28 #include <linux/random.h>
29 #include <linux/slab.h>
30 #include <linux/swab.h>
31 #include <linux/fips.h>
32
33 #include "ecc.h"
34
35 typedef struct {
36         u64 m_low;
37         u64 m_high;
38 } uint128_t;
39
40 /* NIST P-192 */
41 static u64 nist_p192_g_x[] = { 0xF4FF0AFD82FF1012ull, 0x7CBF20EB43A18800ull,
42                                 0x188DA80EB03090F6ull };
43 static u64 nist_p192_g_y[] = { 0x73F977A11E794811ull, 0x631011ED6B24CDD5ull,
44                                 0x07192B95FFC8DA78ull };
45 static u64 nist_p192_p[] = { 0xFFFFFFFFFFFFFFFFull, 0xFFFFFFFFFFFFFFFEull,
46                                 0xFFFFFFFFFFFFFFFFull };
47 static u64 nist_p192_n[] = { 0x146BC9B1B4D22831ull, 0xFFFFFFFF99DEF836ull,
48                                 0xFFFFFFFFFFFFFFFFull };
49 static struct ecc_curve nist_p192 = {
50         .name = "nist_192",
51         .g = {
52                 .x = nist_p192_g_x,
53                 .y = nist_p192_g_y,
54                 .ndigits = 3,
55         },
56         .p = nist_p192_p,
57         .n = nist_p192_n
58 };
59
60 /* NIST P-256 */
61 static u64 nist_p256_g_x[] = { 0xF4A13945D898C296ull, 0x77037D812DEB33A0ull,
62                                 0xF8BCE6E563A440F2ull, 0x6B17D1F2E12C4247ull };
63 static u64 nist_p256_g_y[] = { 0xCBB6406837BF51F5ull, 0x2BCE33576B315ECEull,
64                                 0x8EE7EB4A7C0F9E16ull, 0x4FE342E2FE1A7F9Bull };
65 static u64 nist_p256_p[] = { 0xFFFFFFFFFFFFFFFFull, 0x00000000FFFFFFFFull,
66                                 0x0000000000000000ull, 0xFFFFFFFF00000001ull };
67 static u64 nist_p256_n[] = { 0xF3B9CAC2FC632551ull, 0xBCE6FAADA7179E84ull,
68                                 0xFFFFFFFFFFFFFFFFull, 0xFFFFFFFF00000000ull };
69 static struct ecc_curve nist_p256 = {
70         .name = "nist_256",
71         .g = {
72                 .x = nist_p256_g_x,
73                 .y = nist_p256_g_y,
74                 .ndigits = 4,
75         },
76         .p = nist_p256_p,
77         .n = nist_p256_n
78 };
79
80 const struct ecc_curve *ecc_get_curve(unsigned int curve_id)
81 {
82         switch (curve_id) {
83         /* In FIPS mode only allow P256 and higher */
84         case ECC_CURVE_NIST_P192:
85                 return fips_enabled ? NULL : &nist_p192;
86         case ECC_CURVE_NIST_P256:
87                 return &nist_p256;
88         default:
89                 return NULL;
90         }
91 }
92 EXPORT_SYMBOL_GPL(ecc_get_curve);
93
94 static u64 *ecc_alloc_digits_space(unsigned int ndigits)
95 {
96         size_t len = ndigits * sizeof(u64);
97
98         if (!len)
99                 return NULL;
100
101         return kmalloc(len, GFP_KERNEL);
102 }
103
104 static void ecc_free_digits_space(u64 *space)
105 {
106         kzfree(space);
107 }
108
109 struct ecc_point *ecc_alloc_point(unsigned int ndigits)
110 {
111         struct ecc_point *p = kmalloc(sizeof(*p), GFP_KERNEL);
112
113         if (!p)
114                 return NULL;
115
116         p->x = ecc_alloc_digits_space(ndigits);
117         if (!p->x)
118                 goto err_alloc_x;
119
120         p->y = ecc_alloc_digits_space(ndigits);
121         if (!p->y)
122                 goto err_alloc_y;
123
124         p->ndigits = ndigits;
125
126         return p;
127
128 err_alloc_y:
129         ecc_free_digits_space(p->x);
130 err_alloc_x:
131         kfree(p);
132         return NULL;
133 }
134 EXPORT_SYMBOL_GPL(ecc_alloc_point);
135
136 void ecc_free_point(struct ecc_point *p)
137 {
138         if (!p)
139                 return;
140
141         kzfree(p->x);
142         kzfree(p->y);
143         kzfree(p);
144 }
145 EXPORT_SYMBOL_GPL(ecc_free_point);
146
147 void vli_clear(u64 *vli, unsigned int ndigits)
148 {
149         int i;
150
151         for (i = 0; i < ndigits; i++)
152                 vli[i] = 0;
153 }
154 EXPORT_SYMBOL_GPL(vli_clear);
155
156 /* Returns true if vli == 0, false otherwise. */
157 bool vli_is_zero(const u64 *vli, unsigned int ndigits)
158 {
159         int i;
160
161         for (i = 0; i < ndigits; i++) {
162                 if (vli[i])
163                         return false;
164         }
165
166         return true;
167 }
168 EXPORT_SYMBOL_GPL(vli_is_zero);
169
170 /* Returns nonzero if bit bit of vli is set. */
171 u64 vli_test_bit(const u64 *vli, unsigned int bit)
172 {
173         return (vli[bit / 64] & ((u64)1 << (bit % 64)));
174 }
175 EXPORT_SYMBOL_GPL(vli_test_bit);
176
177 /* Counts the number of 64-bit "digits" in vli. */
178 unsigned int vli_num_digits(const u64 *vli, unsigned int ndigits)
179 {
180         int i;
181
182         /* Search from the end until we find a non-zero digit.
183          * We do it in reverse because we expect that most digits will
184          * be nonzero.
185          */
186         for (i = ndigits - 1; i >= 0 && vli[i] == 0; i--);
187
188         return (i + 1);
189 }
190 EXPORT_SYMBOL_GPL(vli_num_digits);
191
192 /* Counts the number of bits required for vli. */
193 unsigned int vli_num_bits(const u64 *vli, unsigned int ndigits)
194 {
195         unsigned int i, num_digits;
196         u64 digit;
197
198         num_digits = vli_num_digits(vli, ndigits);
199         if (num_digits == 0)
200                 return 0;
201
202         digit = vli[num_digits - 1];
203         for (i = 0; digit; i++)
204                 digit >>= 1;
205
206         return ((num_digits - 1) * 64 + i);
207 }
208 EXPORT_SYMBOL_GPL(vli_num_bits);
209
210 /* Sets dest = src. */
211 void vli_set(u64 *dest, const u64 *src, unsigned int ndigits)
212 {
213         int i;
214
215         for (i = 0; i < ndigits; i++)
216                 dest[i] = src[i];
217 }
218 EXPORT_SYMBOL_GPL(vli_set);
219
220 /* Copy from vli to buf.
221  * For buffers smaller than vli: copy only LSB nbytes from vli.
222  * For buffers larger than vli : fill up remaining buf with zeroes.
223  */
224 void vli_copy_to_buf(u8 *dst_buf, unsigned int buf_len,
225                      const u64 *src_vli, unsigned int ndigits)
226 {
227         unsigned int nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT;
228         u8 *vli = (u8 *)src_vli;
229         int i;
230
231         for (i = 0; i < buf_len && i < nbytes; i++)
232                 dst_buf[i] = vli[i];
233
234         for (; i < buf_len; i++)
235                 dst_buf[i] = 0;
236 }
237 EXPORT_SYMBOL_GPL(vli_copy_to_buf);
238
239 /* Copy from buffer to vli.
240  * For buffers smaller than vli: fill up remaining vli with zeroes.
241  * For buffers larger than vli : copy only LSB nbytes to vli.
242  */
243 void vli_copy_from_buf(u64 *dst_vli, unsigned int ndigits,
244                        const u8 *src_buf, unsigned int buf_len)
245 {
246         unsigned int nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT;
247         u8 *vli = (u8 *)dst_vli;
248         int i;
249
250         for (i = 0; i < buf_len && i < nbytes; i++)
251                 vli[i] = src_buf[i];
252
253         for (; i < nbytes; i++)
254                 vli[i] = 0;
255 }
256 EXPORT_SYMBOL_GPL(vli_copy_from_buf);
257
258 /* Returns sign of left - right. */
259 int vli_cmp(const u64 *left, const u64 *right, unsigned int ndigits)
260 {
261         int i;
262
263         for (i = ndigits - 1; i >= 0; i--) {
264                 if (left[i] > right[i])
265                         return 1;
266                 else if (left[i] < right[i])
267                         return -1;
268         }
269
270         return 0;
271 }
272 EXPORT_SYMBOL_GPL(vli_cmp);
273
274 /* Computes result = in << c, returning carry. Can modify in place
275  * (if result == in). 0 < shift < 64.
276  */
277 u64 vli_lshift(u64 *result, const u64 *in, unsigned int shift,
278                unsigned int ndigits)
279 {
280         u64 carry = 0;
281         int i;
282
283         for (i = 0; i < ndigits; i++) {
284                 u64 temp = in[i];
285
286                 result[i] = (temp << shift) | carry;
287                 carry = temp >> (64 - shift);
288         }
289
290         return carry;
291 }
292 EXPORT_SYMBOL_GPL(vli_lshift);
293
294 /* Computes vli = vli >> 1. */
295 void vli_rshift1(u64 *vli, unsigned int ndigits)
296 {
297         u64 *end = vli;
298         u64 carry = 0;
299
300         vli += ndigits;
301
302         while (vli-- > end) {
303                 u64 temp = *vli;
304                 *vli = (temp >> 1) | carry;
305                 carry = temp << 63;
306         }
307 }
308 EXPORT_SYMBOL_GPL(vli_rshift1);
309
310 /* Computes result = left + right, returning carry. Can modify in place. */
311 u64 vli_add(u64 *result, const u64 *left, const u64 *right,
312             unsigned int ndigits)
313 {
314         u64 carry = 0;
315         int i;
316
317         for (i = 0; i < ndigits; i++) {
318                 u64 sum;
319
320                 sum = left[i] + right[i] + carry;
321                 if (sum != left[i])
322                         carry = (sum < left[i]);
323
324                 result[i] = sum;
325         }
326
327         return carry;
328 }
329 EXPORT_SYMBOL_GPL(vli_add);
330
331 /* Computes result = left - right, returning borrow. Can modify in place. */
332 u64 vli_sub(u64 *result, const u64 *left, const u64 *right,
333             unsigned int ndigits)
334 {
335         u64 borrow = 0;
336         int i;
337
338         for (i = 0; i < ndigits; i++) {
339                 u64 diff;
340
341                 diff = left[i] - right[i] - borrow;
342                 if (diff != left[i])
343                         borrow = (diff > left[i]);
344
345                 result[i] = diff;
346         }
347
348         return borrow;
349 }
350 EXPORT_SYMBOL_GPL(vli_sub);
351
352 static uint128_t mul_64_64(u64 left, u64 right)
353 {
354         u64 a0 = left & 0xffffffffull;
355         u64 a1 = left >> 32;
356         u64 b0 = right & 0xffffffffull;
357         u64 b1 = right >> 32;
358         u64 m0 = a0 * b0;
359         u64 m1 = a0 * b1;
360         u64 m2 = a1 * b0;
361         u64 m3 = a1 * b1;
362         uint128_t result;
363
364         m2 += (m0 >> 32);
365         m2 += m1;
366
367         /* Overflow */
368         if (m2 < m1)
369                 m3 += 0x100000000ull;
370
371         result.m_low = (m0 & 0xffffffffull) | (m2 << 32);
372         result.m_high = m3 + (m2 >> 32);
373
374         return result;
375 }
376
377 static uint128_t add_128_128(uint128_t a, uint128_t b)
378 {
379         uint128_t result;
380
381         result.m_low = a.m_low + b.m_low;
382         result.m_high = a.m_high + b.m_high + (result.m_low < a.m_low);
383
384         return result;
385 }
386
387 void vli_mult(u64 *result, const u64 *left, const u64 *right,
388               unsigned int ndigits)
389 {
390         uint128_t r01 = { 0, 0 };
391         u64 r2 = 0;
392         unsigned int i, k;
393
394         /* Compute each digit of result in sequence, maintaining the
395          * carries.
396          */
397         for (k = 0; k < ndigits * 2 - 1; k++) {
398                 unsigned int min;
399
400                 if (k < ndigits)
401                         min = 0;
402                 else
403                         min = (k + 1) - ndigits;
404
405                 for (i = min; i <= k && i < ndigits; i++) {
406                         uint128_t product;
407
408                         product = mul_64_64(left[i], right[k - i]);
409
410                         r01 = add_128_128(r01, product);
411                         r2 += (r01.m_high < product.m_high);
412                 }
413
414                 result[k] = r01.m_low;
415                 r01.m_low = r01.m_high;
416                 r01.m_high = r2;
417                 r2 = 0;
418         }
419
420         result[ndigits * 2 - 1] = r01.m_low;
421 }
422 EXPORT_SYMBOL_GPL(vli_mult);
423
424 void vli_square(u64 *result, const u64 *left, unsigned int ndigits)
425 {
426         uint128_t r01 = { 0, 0 };
427         u64 r2 = 0;
428         int i, k;
429
430         for (k = 0; k < ndigits * 2 - 1; k++) {
431                 unsigned int min;
432
433                 if (k < ndigits)
434                         min = 0;
435                 else
436                         min = (k + 1) - ndigits;
437
438                 for (i = min; i <= k && i <= k - i; i++) {
439                         uint128_t product;
440
441                         product = mul_64_64(left[i], left[k - i]);
442
443                         if (i < k - i) {
444                                 r2 += product.m_high >> 63;
445                                 product.m_high = (product.m_high << 1) |
446                                                  (product.m_low >> 63);
447                                 product.m_low <<= 1;
448                         }
449
450                         r01 = add_128_128(r01, product);
451                         r2 += (r01.m_high < product.m_high);
452                 }
453
454                 result[k] = r01.m_low;
455                 r01.m_low = r01.m_high;
456                 r01.m_high = r2;
457                 r2 = 0;
458         }
459
460         result[ndigits * 2 - 1] = r01.m_low;
461 }
462 EXPORT_SYMBOL_GPL(vli_square);
463
464 /* Computes result = (left + right) % mod.
465  * Assumes that left < mod and right < mod, result != mod.
466  */
467 void vli_mod_add(u64 *result, const u64 *left, const u64 *right,
468                  const u64 *mod, unsigned int ndigits)
469 {
470         u64 carry;
471
472         carry = vli_add(result, left, right, ndigits);
473
474         /* result > mod (result = mod + remainder), so subtract mod to
475          * get remainder.
476          */
477         if (carry || vli_cmp(result, mod, ndigits) >= 0)
478                 vli_sub(result, result, mod, ndigits);
479 }
480 EXPORT_SYMBOL_GPL(vli_mod_add);
481
482 /* Computes result = (left - right) % mod.
483  * Assumes that left < mod and right < mod, result != mod.
484  */
485 void vli_mod_sub(u64 *result, const u64 *left, const u64 *right,
486                  const u64 *mod, unsigned int ndigits)
487 {
488         u64 borrow = vli_sub(result, left, right, ndigits);
489
490         /* In this case, p_result == -diff == (max int) - diff.
491          * Since -x % d == d - x, we can get the correct result from
492          * result + mod (with overflow).
493          */
494         if (borrow)
495                 vli_add(result, result, mod, ndigits);
496 }
497 EXPORT_SYMBOL_GPL(vli_mod_sub);
498
499 /* Computes result = input % mod.
500  * Assumes that input < mod, result != mod.
501  */
502 void vli_mod(u64 *result, const u64 *input, const u64 *mod,
503              unsigned int ndigits)
504 {
505         if (vli_cmp(input, mod, ndigits) >= 0)
506                 vli_sub(result, input, mod, ndigits);
507         else
508                 vli_set(result, input, ndigits);
509 }
510 EXPORT_SYMBOL_GPL(vli_mod);
511
512 /* Print vli in big-endian format.
513  * The bytes are printed in hex.
514  */
515 void vli_print(char *vli_name, const u64 *vli, unsigned int ndigits)
516 {
517         int nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT;
518         int buf_size = 2 * ECC_MAX_DIGIT_BYTES + 1;
519         unsigned char *c, buf[buf_size];
520         int i, j;
521
522         c = (unsigned char *)vli;
523
524         for (i = nbytes - 1, j = 0; i >= 0 && j < buf_size; i--, j += 2)
525                 snprintf(&buf[j], 3, "%02x", *(c + i));
526
527         buf[j] = '\0';
528
529         pr_info("%20s(BigEnd)=%s\n", vli_name, buf);
530 }
531 EXPORT_SYMBOL_GPL(vli_print);
532
533 /* Computes result = (left * right) % mod.
534  * Assumes that left < mod and right < mod, result != mod.
535  * Uses:
536  *      (a * b) % m = ((a % m) * (b % m)) % m
537  *      (a * b) % m = (a + a + ... + a) % m = b modular additions of (a % m)
538  */
539 void vli_mod_mult(u64 *result, const u64 *left, const u64 *right,
540                   const u64 *mod, unsigned int ndigits)
541 {
542         u64 t1[ndigits], mm[ndigits];
543         u64 aa[ndigits], bb[ndigits];
544
545         vli_clear(result, ndigits);
546         vli_set(aa, left, ndigits);
547         vli_set(bb, right, ndigits);
548         vli_set(mm, mod, ndigits);
549
550         /* aa = aa % mm */
551         vli_mod(aa, aa, mm, ndigits);
552
553         /* bb = bb % mm */
554         vli_mod(bb, bb, mm, ndigits);
555
556         while (!vli_is_zero(bb, ndigits)) {
557
558                 /* if bb is odd i.e. 0th bit set then add
559                  * aa i.e. result = (result + aa) % mm
560                  */
561                 if (vli_test_bit(bb, 0))
562                         vli_mod_add(result, result, aa, mm, ndigits);
563
564                 /* bb = bb / 2 = bb >> 1 */
565                 vli_rshift1(bb, ndigits);
566
567                 /* aa = (aa * 2) % mm */
568                 vli_sub(t1, mm, aa, ndigits);
569                 if (vli_cmp(aa, t1, ndigits) == -1)
570                         /* if aa < t1 then aa = aa * 2 = aa << 1*/
571                         vli_lshift(aa, aa, 1, ndigits);
572                 else
573                         /* if aa >= t1 then aa = aa - t1 */
574                         vli_sub(aa, aa, t1, ndigits);
575         }
576 }
577 EXPORT_SYMBOL_GPL(vli_mod_mult);
578
579 /* Computes p_result = p_product % curve_p.
580  * See algorithm 5 and 6 from
581  * http://www.isys.uni-klu.ac.at/PDF/2001-0126-MT.pdf
582  */
583 static void vli_mmod_fast_192(u64 *result, const u64 *product,
584                               const u64 *curve_prime, u64 *tmp)
585 {
586         const unsigned int ndigits = 3;
587         int carry;
588
589         vli_set(result, product, ndigits);
590
591         vli_set(tmp, &product[3], ndigits);
592         carry = vli_add(result, result, tmp, ndigits);
593
594         tmp[0] = 0;
595         tmp[1] = product[3];
596         tmp[2] = product[4];
597         carry += vli_add(result, result, tmp, ndigits);
598
599         tmp[0] = tmp[1] = product[5];
600         tmp[2] = 0;
601         carry += vli_add(result, result, tmp, ndigits);
602
603         while (carry || vli_cmp(curve_prime, result, ndigits) != 1)
604                 carry -= vli_sub(result, result, curve_prime, ndigits);
605 }
606
607 /* Computes result = product % curve_prime
608  * from http://www.nsa.gov/ia/_files/nist-routines.pdf
609  */
610 static void vli_mmod_fast_256(u64 *result, const u64 *product,
611                               const u64 *curve_prime, u64 *tmp)
612 {
613         int carry;
614         const unsigned int ndigits = 4;
615
616         /* t */
617         vli_set(result, product, ndigits);
618
619         /* s1 */
620         tmp[0] = 0;
621         tmp[1] = product[5] & 0xffffffff00000000ull;
622         tmp[2] = product[6];
623         tmp[3] = product[7];
624         carry = vli_lshift(tmp, tmp, 1, ndigits);
625         carry += vli_add(result, result, tmp, ndigits);
626
627         /* s2 */
628         tmp[1] = product[6] << 32;
629         tmp[2] = (product[6] >> 32) | (product[7] << 32);
630         tmp[3] = product[7] >> 32;
631         carry += vli_lshift(tmp, tmp, 1, ndigits);
632         carry += vli_add(result, result, tmp, ndigits);
633
634         /* s3 */
635         tmp[0] = product[4];
636         tmp[1] = product[5] & 0xffffffff;
637         tmp[2] = 0;
638         tmp[3] = product[7];
639         carry += vli_add(result, result, tmp, ndigits);
640
641         /* s4 */
642         tmp[0] = (product[4] >> 32) | (product[5] << 32);
643         tmp[1] = (product[5] >> 32) | (product[6] & 0xffffffff00000000ull);
644         tmp[2] = product[7];
645         tmp[3] = (product[6] >> 32) | (product[4] << 32);
646         carry += vli_add(result, result, tmp, ndigits);
647
648         /* d1 */
649         tmp[0] = (product[5] >> 32) | (product[6] << 32);
650         tmp[1] = (product[6] >> 32);
651         tmp[2] = 0;
652         tmp[3] = (product[4] & 0xffffffff) | (product[5] << 32);
653         carry -= vli_sub(result, result, tmp, ndigits);
654
655         /* d2 */
656         tmp[0] = product[6];
657         tmp[1] = product[7];
658         tmp[2] = 0;
659         tmp[3] = (product[4] >> 32) | (product[5] & 0xffffffff00000000ull);
660         carry -= vli_sub(result, result, tmp, ndigits);
661
662         /* d3 */
663         tmp[0] = (product[6] >> 32) | (product[7] << 32);
664         tmp[1] = (product[7] >> 32) | (product[4] << 32);
665         tmp[2] = (product[4] >> 32) | (product[5] << 32);
666         tmp[3] = (product[6] << 32);
667         carry -= vli_sub(result, result, tmp, ndigits);
668
669         /* d4 */
670         tmp[0] = product[7];
671         tmp[1] = product[4] & 0xffffffff00000000ull;
672         tmp[2] = product[5];
673         tmp[3] = product[6] & 0xffffffff00000000ull;
674         carry -= vli_sub(result, result, tmp, ndigits);
675
676         if (carry < 0) {
677                 do {
678                         carry += vli_add(result, result, curve_prime, ndigits);
679                 } while (carry < 0);
680         } else {
681                 while (carry || vli_cmp(curve_prime, result, ndigits) != 1)
682                         carry -= vli_sub(result, result, curve_prime, ndigits);
683         }
684 }
685
686 /* Computes result = product % curve_prime
687  *  from http://www.nsa.gov/ia/_files/nist-routines.pdf
688 */
689 bool vli_mmod_fast(u64 *result, u64 *product,
690                    const u64 *curve_prime, unsigned int ndigits)
691 {
692         u64 tmp[2 * ndigits];
693
694         switch (ndigits) {
695         case 3:
696                 vli_mmod_fast_192(result, product, curve_prime, tmp);
697                 break;
698         case 4:
699                 vli_mmod_fast_256(result, product, curve_prime, tmp);
700                 break;
701         default:
702                 pr_err("unsupports digits size!\n");
703                 return false;
704         }
705
706         return true;
707 }
708 EXPORT_SYMBOL_GPL(vli_mmod_fast);
709
710 /* Computes result = (left * right) % curve_prime. */
711 void vli_mod_mult_fast(u64 *result, const u64 *left, const u64 *right,
712                        const u64 *curve_prime, unsigned int ndigits)
713 {
714         u64 product[2 * ndigits];
715
716         vli_mult(product, left, right, ndigits);
717         vli_mmod_fast(result, product, curve_prime, ndigits);
718 }
719 EXPORT_SYMBOL_GPL(vli_mod_mult_fast);
720
721 /* Computes result = left^2 % curve_prime. */
722 void vli_mod_square_fast(u64 *result, const u64 *left,
723                          const u64 *curve_prime, unsigned int ndigits)
724 {
725         u64 product[2 * ndigits];
726
727         vli_square(product, left, ndigits);
728         vli_mmod_fast(result, product, curve_prime, ndigits);
729 }
730 EXPORT_SYMBOL_GPL(vli_mod_square_fast);
731
732 #define EVEN(vli) (!(vli[0] & 1))
733 /* Computes result = (1 / p_input) % mod. All VLIs are the same size.
734  * See "From Euclid's GCD to Montgomery Multiplication to the Great Divide"
735  * https://labs.oracle.com/techrep/2001/smli_tr-2001-95.pdf
736  */
737 void vli_mod_inv(u64 *result, const u64 *input, const u64 *mod,
738                  unsigned int ndigits)
739 {
740         u64 a[ndigits], b[ndigits];
741         u64 u[ndigits], v[ndigits];
742         u64 carry;
743         int cmp_result;
744
745         if (vli_is_zero(input, ndigits)) {
746                 vli_clear(result, ndigits);
747                 return;
748         }
749
750         vli_set(a, input, ndigits);
751         vli_set(b, mod, ndigits);
752         vli_clear(u, ndigits);
753         u[0] = 1;
754         vli_clear(v, ndigits);
755
756         while ((cmp_result = vli_cmp(a, b, ndigits)) != 0) {
757                 carry = 0;
758
759                 if (EVEN(a)) {
760                         vli_rshift1(a, ndigits);
761
762                         if (!EVEN(u))
763                                 carry = vli_add(u, u, mod, ndigits);
764
765                         vli_rshift1(u, ndigits);
766                         if (carry)
767                                 u[ndigits - 1] |= 0x8000000000000000ull;
768                 } else if (EVEN(b)) {
769                         vli_rshift1(b, ndigits);
770
771                         if (!EVEN(v))
772                                 carry = vli_add(v, v, mod, ndigits);
773
774                         vli_rshift1(v, ndigits);
775                         if (carry)
776                                 v[ndigits - 1] |= 0x8000000000000000ull;
777                 } else if (cmp_result > 0) {
778                         vli_sub(a, a, b, ndigits);
779                         vli_rshift1(a, ndigits);
780
781                         if (vli_cmp(u, v, ndigits) < 0)
782                                 vli_add(u, u, mod, ndigits);
783
784                         vli_sub(u, u, v, ndigits);
785                         if (!EVEN(u))
786                                 carry = vli_add(u, u, mod, ndigits);
787
788                         vli_rshift1(u, ndigits);
789                         if (carry)
790                                 u[ndigits - 1] |= 0x8000000000000000ull;
791                 } else {
792                         vli_sub(b, b, a, ndigits);
793                         vli_rshift1(b, ndigits);
794
795                         if (vli_cmp(v, u, ndigits) < 0)
796                                 vli_add(v, v, mod, ndigits);
797
798                         vli_sub(v, v, u, ndigits);
799                         if (!EVEN(v))
800                                 carry = vli_add(v, v, mod, ndigits);
801
802                         vli_rshift1(v, ndigits);
803                         if (carry)
804                                 v[ndigits - 1] |= 0x8000000000000000ull;
805                 }
806         }
807
808         vli_set(result, u, ndigits);
809 }
810 EXPORT_SYMBOL_GPL(vli_mod_inv);
811
812 /* ------ Point operations ------ */
813
814 /* Returns true if p_point is the point at infinity, false otherwise. */
815 bool ecc_point_is_zero(const struct ecc_point *point)
816 {
817         return (vli_is_zero(point->x, point->ndigits) &&
818                 vli_is_zero(point->y, point->ndigits));
819 }
820 EXPORT_SYMBOL_GPL(ecc_point_is_zero);
821
822 /* Point multiplication algorithm using Montgomery's ladder with co-Z
823  * coordinates. From http://eprint.iacr.org/2011/338.pdf
824  */
825
826 /* Double in place */
827 void ecc_point_double_jacobian(u64 *x1, u64 *y1, u64 *z1,
828                                u64 *curve_prime, unsigned int ndigits)
829 {
830         /* t1 = x, t2 = y, t3 = z */
831         u64 t4[ndigits];
832         u64 t5[ndigits];
833
834         if (vli_is_zero(z1, ndigits))
835                 return;
836
837         /* t4 = y1^2 */
838         vli_mod_square_fast(t4, y1, curve_prime, ndigits);
839         /* t5 = x1*y1^2 = A */
840         vli_mod_mult_fast(t5, x1, t4, curve_prime, ndigits);
841         /* t4 = y1^4 */
842         vli_mod_square_fast(t4, t4, curve_prime, ndigits);
843         /* t2 = y1*z1 = z3 */
844         vli_mod_mult_fast(y1, y1, z1, curve_prime, ndigits);
845         /* t3 = z1^2 */
846         vli_mod_square_fast(z1, z1, curve_prime, ndigits);
847
848         /* t1 = x1 + z1^2 */
849         vli_mod_add(x1, x1, z1, curve_prime, ndigits);
850         /* t3 = 2*z1^2 */
851         vli_mod_add(z1, z1, z1, curve_prime, ndigits);
852         /* t3 = x1 - z1^2 */
853         vli_mod_sub(z1, x1, z1, curve_prime, ndigits);
854         /* t1 = x1^2 - z1^4 */
855         vli_mod_mult_fast(x1, x1, z1, curve_prime, ndigits);
856
857         /* t3 = 2*(x1^2 - z1^4) */
858         vli_mod_add(z1, x1, x1, curve_prime, ndigits);
859         /* t1 = 3*(x1^2 - z1^4) */
860         vli_mod_add(x1, x1, z1, curve_prime, ndigits);
861         if (vli_test_bit(x1, 0)) {
862                 u64 carry = vli_add(x1, x1, curve_prime, ndigits);
863
864                 vli_rshift1(x1, ndigits);
865                 x1[ndigits - 1] |= carry << 63;
866         } else {
867                 vli_rshift1(x1, ndigits);
868         }
869         /* t1 = 3/2*(x1^2 - z1^4) = B */
870
871         /* t3 = B^2 */
872         vli_mod_square_fast(z1, x1, curve_prime, ndigits);
873         /* t3 = B^2 - A */
874         vli_mod_sub(z1, z1, t5, curve_prime, ndigits);
875         /* t3 = B^2 - 2A = x3 */
876         vli_mod_sub(z1, z1, t5, curve_prime, ndigits);
877         /* t5 = A - x3 */
878         vli_mod_sub(t5, t5, z1, curve_prime, ndigits);
879         /* t1 = B * (A - x3) */
880         vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits);
881         /* t4 = B * (A - x3) - y1^4 = y3 */
882         vli_mod_sub(t4, x1, t4, curve_prime, ndigits);
883
884         vli_set(x1, z1, ndigits);
885         vli_set(z1, y1, ndigits);
886         vli_set(y1, t4, ndigits);
887 }
888 EXPORT_SYMBOL_GPL(ecc_point_double_jacobian);
889
890 /* Modify (x1, y1) => (x1 * z^2, y1 * z^3) */
891 static void apply_z(u64 *x1, u64 *y1, u64 *z, u64 *curve_prime,
892                     unsigned int ndigits)
893 {
894         u64 t1[ndigits];
895
896         vli_mod_square_fast(t1, z, curve_prime, ndigits);    /* z^2 */
897         vli_mod_mult_fast(x1, x1, t1, curve_prime, ndigits); /* x1 * z^2 */
898         vli_mod_mult_fast(t1, t1, z, curve_prime, ndigits);  /* z^3 */
899         vli_mod_mult_fast(y1, y1, t1, curve_prime, ndigits); /* y1 * z^3 */
900 }
901
902 /* P = (x1, y1) => 2P, (x2, y2) => P' */
903 static void xycz_initial_double(u64 *x1, u64 *y1, u64 *x2, u64 *y2,
904                                 u64 *p_initial_z, u64 *curve_prime,
905                                 unsigned int ndigits)
906 {
907         u64 z[ndigits];
908
909         vli_set(x2, x1, ndigits);
910         vli_set(y2, y1, ndigits);
911
912         vli_clear(z, ndigits);
913         z[0] = 1;
914
915         if (p_initial_z)
916                 vli_set(z, p_initial_z, ndigits);
917
918         apply_z(x1, y1, z, curve_prime, ndigits);
919
920         ecc_point_double_jacobian(x1, y1, z, curve_prime, ndigits);
921
922         apply_z(x2, y2, z, curve_prime, ndigits);
923 }
924
925 /* Input P = (x1, y1, Z), Q = (x2, y2, Z)
926  * Output P' = (x1', y1', Z3), P + Q = (x3, y3, Z3)
927  * or P => P', Q => P + Q
928  */
929 static void xycz_add(u64 *x1, u64 *y1, u64 *x2, u64 *y2, u64 *curve_prime,
930                      unsigned int ndigits)
931 {
932         /* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
933         u64 t5[ndigits];
934
935         /* t5 = x2 - x1 */
936         vli_mod_sub(t5, x2, x1, curve_prime, ndigits);
937         /* t5 = (x2 - x1)^2 = A */
938         vli_mod_square_fast(t5, t5, curve_prime, ndigits);
939         /* t1 = x1*A = B */
940         vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits);
941         /* t3 = x2*A = C */
942         vli_mod_mult_fast(x2, x2, t5, curve_prime, ndigits);
943         /* t4 = y2 - y1 */
944         vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
945         /* t5 = (y2 - y1)^2 = D */
946         vli_mod_square_fast(t5, y2, curve_prime, ndigits);
947
948         /* t5 = D - B */
949         vli_mod_sub(t5, t5, x1, curve_prime, ndigits);
950         /* t5 = D - B - C = x3 */
951         vli_mod_sub(t5, t5, x2, curve_prime, ndigits);
952         /* t3 = C - B */
953         vli_mod_sub(x2, x2, x1, curve_prime, ndigits);
954         /* t2 = y1*(C - B) */
955         vli_mod_mult_fast(y1, y1, x2, curve_prime, ndigits);
956         /* t3 = B - x3 */
957         vli_mod_sub(x2, x1, t5, curve_prime, ndigits);
958         /* t4 = (y2 - y1)*(B - x3) */
959         vli_mod_mult_fast(y2, y2, x2, curve_prime, ndigits);
960         /* t4 = y3 */
961         vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
962
963         vli_set(x2, t5, ndigits);
964 }
965
966 /* Input P = (x1, y1, Z), Q = (x2, y2, Z)
967  * Output P + Q = (x3, y3, Z3), P - Q = (x3', y3', Z3)
968  * or P => P - Q, Q => P + Q
969  */
970 static void xycz_add_c(u64 *x1, u64 *y1, u64 *x2, u64 *y2, u64 *curve_prime,
971                        unsigned int ndigits)
972 {
973         /* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
974         u64 t5[ndigits];
975         u64 t6[ndigits];
976         u64 t7[ndigits];
977
978         /* t5 = x2 - x1 */
979         vli_mod_sub(t5, x2, x1, curve_prime, ndigits);
980         /* t5 = (x2 - x1)^2 = A */
981         vli_mod_square_fast(t5, t5, curve_prime, ndigits);
982         /* t1 = x1*A = B */
983         vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits);
984         /* t3 = x2*A = C */
985         vli_mod_mult_fast(x2, x2, t5, curve_prime, ndigits);
986         /* t4 = y2 + y1 */
987         vli_mod_add(t5, y2, y1, curve_prime, ndigits);
988         /* t4 = y2 - y1 */
989         vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
990
991         /* t6 = C - B */
992         vli_mod_sub(t6, x2, x1, curve_prime, ndigits);
993         /* t2 = y1 * (C - B) */
994         vli_mod_mult_fast(y1, y1, t6, curve_prime, ndigits);
995         /* t6 = B + C */
996         vli_mod_add(t6, x1, x2, curve_prime, ndigits);
997         /* t3 = (y2 - y1)^2 */
998         vli_mod_square_fast(x2, y2, curve_prime, ndigits);
999         /* t3 = x3 */
1000         vli_mod_sub(x2, x2, t6, curve_prime, ndigits);
1001
1002         /* t7 = B - x3 */
1003         vli_mod_sub(t7, x1, x2, curve_prime, ndigits);
1004         /* t4 = (y2 - y1)*(B - x3) */
1005         vli_mod_mult_fast(y2, y2, t7, curve_prime, ndigits);
1006         /* t4 = y3 */
1007         vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
1008
1009         /* t7 = (y2 + y1)^2 = F */
1010         vli_mod_square_fast(t7, t5, curve_prime, ndigits);
1011         /* t7 = x3' */
1012         vli_mod_sub(t7, t7, t6, curve_prime, ndigits);
1013         /* t6 = x3' - B */
1014         vli_mod_sub(t6, t7, x1, curve_prime, ndigits);
1015         /* t6 = (y2 + y1)*(x3' - B) */
1016         vli_mod_mult_fast(t6, t6, t5, curve_prime, ndigits);
1017         /* t2 = y3' */
1018         vli_mod_sub(y1, t6, y1, curve_prime, ndigits);
1019
1020         vli_set(x1, t7, ndigits);
1021 }
1022
1023 /* Point addition.
1024  * Add 2 distinct points on elliptic curve to get a new point.
1025  *
1026  * P = (x1,y1)and Q = (x2, y2) then P + Q = (x3,y3) where
1027  * x3 = ((y2-y1)/(x2-x1))^2 - x1 - x2
1028  * y3 = ((y2-y1)/(x2-x1))(x1-x3) - y1
1029  *
1030  * Q => P + Q
1031  */
1032 void ecc_point_add(u64 *x1, u64 *y1, u64 *x2, u64 *y2, u64 *curve_prime,
1033                    unsigned int ndigits)
1034 {
1035         /* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
1036         u64 t5[ndigits];
1037         u64 t6[ndigits];
1038         u64 t7[ndigits];
1039
1040         /* t6 = x2 - x1 */
1041         vli_mod_sub(t6, x2, x1, curve_prime, ndigits);
1042         /* t6 = (x2 - x1)^2 = A */
1043         vli_mod_square_fast(t6, t6, curve_prime, ndigits);
1044         vli_mod_inv(t7, t6, curve_prime, ndigits);
1045         /* t5 = x2 - x1 */
1046         vli_mod_sub(t5, x2, x1, curve_prime, ndigits);
1047         /* t5 = (x2 - x1)^2 = A */
1048         vli_mod_square_fast(t5, t5, curve_prime, ndigits);
1049         /* t1 = x1*A = B = x1*(x2-x1)^2*/
1050         vli_mod_mult_fast(x1, x1, t5, curve_prime, ndigits);
1051         /* t3 = x2*A = C = x2*(x2-x1)^2*/
1052         vli_mod_mult_fast(x2, x2, t5, curve_prime, ndigits);
1053         /* t4 = y2 - y1 */
1054         vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
1055         /* t5 = (y2 - y1)^2 = D */
1056         vli_mod_square_fast(t5, y2, curve_prime, ndigits);
1057
1058         /* t5 = D - B = (y2 - y1)^2 - x1*(x2-x1)^2 */
1059         vli_mod_sub(t5, t5, x1, curve_prime, ndigits);
1060         /* t5 = D - B - C = x3 = (y2 - y1)^2 - x1*(x2-x1)^2 - x2*(x2-x1)^2*/
1061         vli_mod_sub(t5, t5, x2, curve_prime, ndigits);
1062
1063         /* t3 = C - B = x2*(x2-x1)^2 - x1*(x2-x1)^2 */
1064         vli_mod_sub(x2, x2, x1, curve_prime, ndigits);
1065         /* t2 = y1*(C - B) = y1*(x2*(x2-x1)^2 - x1*(x2-x1)^2)*/
1066         vli_mod_mult_fast(y1, y1, x2, curve_prime, ndigits);
1067         /* t3 = B - x3 = x1*(x2-x1)^2 - x3*/
1068         vli_mod_sub(x2, x1, t5, curve_prime, ndigits);
1069         /* t4 = (y2 - y1)*(B - x3)  = (y2 - y1)*(x1*(x2-x1)^2 - x3)*/
1070         vli_mod_mult_fast(y2, y2, x2, curve_prime, ndigits);
1071         /* t4 = y3 = ((y2 - y1)*(x1*(x2-x1)^2 - x3)) - y1*/
1072         vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
1073
1074         vli_mod_mult_fast(t5, t5, t7,  curve_prime, ndigits);
1075         vli_set(x2, t5, ndigits);
1076 }
1077 EXPORT_SYMBOL_GPL(ecc_point_add);
1078
1079 void ecc_point_mult(struct ecc_point *result,
1080                     const struct ecc_point *point, const u64 *scalar,
1081                     u64 *initial_z, u64 *curve_prime,
1082                     unsigned int ndigits)
1083 {
1084         /* R0 and R1 */
1085         u64 rx[2][ndigits];
1086         u64 ry[2][ndigits];
1087         u64 z[ndigits];
1088         int i, nb;
1089         int num_bits = vli_num_bits(scalar, ndigits);
1090
1091         vli_set(rx[1], point->x, ndigits);
1092         vli_set(ry[1], point->y, ndigits);
1093
1094         xycz_initial_double(rx[1], ry[1], rx[0], ry[0], initial_z, curve_prime,
1095                             ndigits);
1096
1097         for (i = num_bits - 2; i > 0; i--) {
1098                 nb = !vli_test_bit(scalar, i);
1099                 xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve_prime,
1100                            ndigits);
1101                 xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve_prime,
1102                          ndigits);
1103         }
1104
1105         nb = !vli_test_bit(scalar, 0);
1106         xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve_prime,
1107                    ndigits);
1108
1109         /* Find final 1/Z value. */
1110         /* X1 - X0 */
1111         vli_mod_sub(z, rx[1], rx[0], curve_prime, ndigits);
1112         /* Yb * (X1 - X0) */
1113         vli_mod_mult_fast(z, z, ry[1 - nb], curve_prime, ndigits);
1114         /* xP * Yb * (X1 - X0) */
1115         vli_mod_mult_fast(z, z, point->x, curve_prime, ndigits);
1116
1117         /* 1 / (xP * Yb * (X1 - X0)) */
1118         vli_mod_inv(z, z, curve_prime, point->ndigits);
1119
1120         /* yP / (xP * Yb * (X1 - X0)) */
1121         vli_mod_mult_fast(z, z, point->y, curve_prime, ndigits);
1122         /* Xb * yP / (xP * Yb * (X1 - X0)) */
1123         vli_mod_mult_fast(z, z, rx[1 - nb], curve_prime, ndigits);
1124         /* End 1/Z calculation */
1125
1126         xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve_prime, ndigits);
1127
1128         apply_z(rx[0], ry[0], z, curve_prime, ndigits);
1129
1130         vli_set(result->x, rx[0], ndigits);
1131         vli_set(result->y, ry[0], ndigits);
1132 }
1133 EXPORT_SYMBOL_GPL(ecc_point_mult);
1134
1135 void ecc_swap_digits(const u64 *in, u64 *out, unsigned int ndigits)
1136 {
1137         int i;
1138
1139         for (i = 0; i < ndigits; i++)
1140                 out[i] = __swab64(in[ndigits - 1 - i]);
1141 }
1142 EXPORT_SYMBOL_GPL(ecc_swap_digits);
1143
1144 int ecc_is_key_valid(unsigned int curve_id, unsigned int ndigits,
1145                      const u8 *private_key, unsigned int private_key_len)
1146 {
1147         int nbytes;
1148         const struct ecc_curve *curve = ecc_get_curve(curve_id);
1149
1150         if (!private_key)
1151                 return -EINVAL;
1152
1153         nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT;
1154
1155         if (private_key_len != nbytes)
1156                 return -EINVAL;
1157
1158         if (vli_is_zero((const u64 *)&private_key[0], ndigits))
1159                 return -EINVAL;
1160
1161         /* Make sure the private key is in the range [1, n-1]. */
1162         if (vli_cmp(curve->n, (const u64 *)&private_key[0], ndigits) != 1)
1163                 return -EINVAL;
1164
1165         return 0;
1166 }
1167 EXPORT_SYMBOL_GPL(ecc_is_key_valid);
1168
1169 int ecc_is_pub_key_valid(unsigned int curve_id, unsigned int ndigits,
1170                          const u8 *pub_key, unsigned int pub_key_len)
1171 {
1172         const struct ecc_curve *curve = ecc_get_curve(curve_id);
1173         int nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT;
1174         struct ecc_point p;
1175
1176         if (!pub_key || pub_key_len != 2 * nbytes)
1177                 return -EINVAL;
1178
1179         p.x = (u64 *)pub_key;
1180         p.y = (u64 *)(pub_key + ECC_MAX_DIGIT_BYTES);
1181         p.ndigits = ndigits;
1182
1183         if (vli_cmp(curve->p, p.x, ndigits) != 1 ||
1184             vli_cmp(curve->p, p.y, ndigits) != 1)
1185                 return -EINVAL;
1186
1187         return 0;
1188 }
1189 EXPORT_SYMBOL_GPL(ecc_is_pub_key_valid);