mongoose/src/tls_x25519.c

247 lines
7.0 KiB
C
Raw Normal View History

2024-03-25 08:34:05 +00:00
/**
* Adapted from STROBE: https://strobe.sourceforge.io/
* Copyright (c) 2015-2016 Cryptography Research, Inc.
* Author: Mike Hamburg
* License: MIT License
*/
#include "tls_x25519.h"
// Encoded base point: first byte is 9, all remaining bytes are zero
const uint8_t X25519_BASE_POINT[X25519_BYTES] = {9};
// Width of a single limb, in bits
#define X25519_WBITS 32
// One limb of a multi-precision integer
typedef uint32_t limb_t;
// Double-width unsigned type: holds limb * limb + limb + limb without overflow
typedef uint64_t dlimb_t;
// Double-width signed type, used for borrow chains
typedef int64_t sdlimb_t;
// Expands a 64-bit constant into two comma-separated 32-bit values, low half
// first
#define LIMB(x) (uint32_t)(x##ull), (uint32_t) ((x##ull) >> 32)
// Number of limbs in a 256-bit value
#define NLIMBS (256 / X25519_WBITS)
// Field element: NLIMBS limbs, least significant limb at index 0 (carries in
// the arithmetic helpers run from index 0 upwards)
typedef limb_t fe[NLIMBS];
// Multiply-accumulate with carry.
// Computes mand * mier + acc + *carry in double width, stores the upper limb
// back into *carry and returns the lower limb. The sum cannot overflow the
// double-width type: (2^32-1)^2 + 2*(2^32-1) = 2^64-1.
static limb_t umaal(limb_t *carry, limb_t acc, limb_t mand, limb_t mier) {
  dlimb_t wide = (dlimb_t) mand * mier;
  wide += acc;
  wide += *carry;
  *carry = (limb_t) (wide >> X25519_WBITS);
  return (limb_t) wide;
}
// Add-with-carry helpers (adc, adc0). Upstream note: on ARM these are
// implemented in terms of umaal.
//
// adc: returns the low limb of acc + mand + *carry and writes the carry-out
// (upper limb of the sum) back into *carry.
static limb_t adc(limb_t *carry, limb_t acc, limb_t mand) {
  dlimb_t sum = (dlimb_t) *carry;
  sum += acc;
  sum += mand;
  *carry = (limb_t) (sum >> X25519_WBITS);
  return (limb_t) sum;
}
// Carry-propagation step: returns the low limb of acc + *carry and writes the
// carry-out (upper limb of the sum) back into *carry.
static limb_t adc0(limb_t *carry, limb_t acc) {
  dlimb_t sum = (dlimb_t) *carry;
  sum += acc;
  *carry = (limb_t) (sum >> X25519_WBITS);
  return (limb_t) sum;
}
// Weak reduction of x in place. `over` is the overflow above 2^256; together
// with bit 255 of x it forms a count of 2^255 units, which is folded back into
// the low limbs as a multiple of 19 (2^255 == 19 mod 2^255-19).
// - Precondition: `over` is small.
// - Invariant: the result is < 2^255 + 1 word, in particular less than 2p.
// - Also, output x >= min(x,19)
static void propagate(fe x, limb_t over) {
  const limb_t top_bit = (limb_t) 1 << (X25519_WBITS - 1);
  limb_t carry;
  unsigned k;
  // Number of 2^255 units: bit 255 of x plus twice the overflow
  over = (x[NLIMBS - 1] >> (X25519_WBITS - 1)) | (over << 1);
  x[NLIMBS - 1] &= ~top_bit;
  carry = over * 19;
  for (k = 0; k < NLIMBS; k++) {
    x[k] = adc0(&carry, x[k]);
  }
}
// Field addition: out = a + b, weakly reduced by propagate().
// Limbs are read and written at the same index, so out may alias a or b.
static void add(fe out, const fe a, const fe b) {
  limb_t carry = 0;
  unsigned k = 0;
  while (k < NLIMBS) {
    out[k] = adc(&carry, a[k], b[k]);
    k++;
  }
  // The carry out of the top limb is handed to propagate() as overflow
  propagate(out, carry);
}
// Field subtraction: out = a - b, weakly reduced by propagate().
// The borrow chain starts at -38 and `1 + borrow` is passed as overflow, i.e.
// 2^256 - 38 (== 0 mod 2^255-19) is added so the overflow handed to
// propagate() is not negative for inputs within the propagate() invariant.
static void sub(fe out, const fe a, const fe b) {
  sdlimb_t borrow = -38;
  unsigned k;
  for (k = 0; k < NLIMBS; k++) {
    borrow += a[k];
    borrow -= b[k];
    out[k] = (limb_t) borrow;
    borrow >>= X25519_WBITS;
  }
  propagate(out, (limb_t) (1 + borrow));
}
// Field multiplication: out = a * b, weakly reduced by propagate().
// - `a` has NLIMBS limbs and may be misaligned (each limb is fetched with
//   memcpy); `b` has `nb` limbs, which may be fewer than NLIMBS, hence the
//   `limb_t *` type instead of `fe` (avoids build warnings).
// - The full product is built in a local buffer first, so `out` may alias `a`
//   or `b`.
static void mul(fe out, const fe a, const limb_t *b, unsigned nb) {
  limb_t wide[2 * NLIMBS] = {0};
  limb_t hi;
  unsigned row, col;
  // Schoolbook product: one row per limb of b
  for (row = 0; row < nb; row++) {
    const limb_t factor = b[row];
    limb_t *dst = &wide[row];
    hi = 0;
    for (col = 0; col < NLIMBS; col++) {
      limb_t a_limb;
      memcpy(&a_limb, &a[col], sizeof(a_limb));  // aligned copy of a[col]
      dst[col] = umaal(&hi, dst[col], factor, a_limb);
    }
    dst[NLIMBS] = hi;
  }
  // Fold the upper half into the lower half: 2^256 == 38 mod 2^255-19
  hi = 0;
  for (col = 0; col < NLIMBS; col++) {
    out[col] = umaal(&hi, wide[col], 38, wide[col + NLIMBS]);
  }
  propagate(out, hi);
}
// Field squaring: out = a * a. mul() buffers the product, so out may alias a.
static void sqr(fe out, const fe a) {
  const limb_t *operand = a;
  mul(out, operand, operand, NLIMBS);
}
// In-place field multiplication: out = a * out.
static void mul1(fe out, const fe a) {
  const limb_t *rhs = out;
  mul(out, a, rhs, NLIMBS);
}
// In-place field squaring: a = a * a.
static void sqr1(fe a) {
  mul(a, a, a, NLIMBS);
}
// Branch-free conditional swap of two 2*NLIMBS-limb buffers.
// `doswap` is used as a bit mask: with all bits set the buffers are exchanged,
// with 0 they are left unchanged. The same instructions run in both cases.
static void condswap(limb_t a[2 * NLIMBS], limb_t b[2 * NLIMBS],
                     limb_t doswap) {
  unsigned k;
  // Limbs are independent, so the iteration direction does not matter
  for (k = 2 * NLIMBS; k-- > 0;) {
    const limb_t diff = (a[k] ^ b[k]) & doswap;
    a[k] ^= diff;
    b[k] ^= diff;
  }
}
// Canonicalize a field element x in place, reducing it to the least residue
// congruent to it mod 2^255-19.
// - Precondition: x < 2^255 + 1 word
// - Returns an all-ones limb if the canonical value is zero, 0 otherwise.
static limb_t canon(fe x) {
  limb_t add_carry = 19;
  limb_t nonzero = 0;
  sdlimb_t borrow = -19;
  unsigned k;
  // Step 1: add 19, then fold any overflow back in.
  for (k = 0; k < NLIMBS; k++) {
    x[k] = adc0(&add_carry, x[k]);
  }
  propagate(x, add_carry);
  // Now 19 <= x < 2^255:
  // - 19 was added, and propagate() can only add further multiples of 19, so
  //   the value cannot drop below 19.
  // - The high bit is clear: either the input was about 2^255 + one word + 19
  //   (which propagates to at most 2 words) or it was < 2^255.
  // Step 2: subtract the 19 again, which lands in [0, 2^255-19). While doing
  // so, OR all result limbs together to detect an all-zero result.
  for (k = 0; k < NLIMBS; k++) {
    borrow += x[k];
    x[k] = (limb_t) borrow;
    nonzero |= x[k];
    borrow >>= X25519_WBITS;
  }
  // nonzero == 0 wraps to 2^64-1, whose upper limb is all ones; any other
  // value leaves the upper limb at 0.
  return (limb_t) (((dlimb_t) nonzero - 1) >> X25519_WBITS);
}
// Ladder constant a24, stored as a single limb (RFC 7748 specifies
// a24 = 121665 for X25519)
static const limb_t a24[1] = {121665};
// First half of one Montgomery ladder step.
// xs holds five field elements: xs[0]=x2, xs[1]=z2, xs[2]=x3, xs[3]=z3 and
// xs[4]=t1 (scratch). All updates happen in place, so the statement order
// matters. On return:
//   t1 = AA, x2 = E, z2 = E*a24 + AA, x3 = DA+CB, z3 = DA-CB
static void ladder_part1(fe xs[5]) {
limb_t *x2 = xs[0], *z2 = xs[1], *x3 = xs[2], *z3 = xs[3], *t1 = xs[4];
add(t1, x2, z2); // t1 = A = x2 + z2
sub(z2, x2, z2); // z2 = B = x2 - z2
add(x2, x3, z3); // x2 = C = x3 + z3
sub(z3, x3, z3); // z3 = D = x3 - z3
mul1(z3, t1); // z3 = DA
mul1(x2, z2); // x2 = CB
add(x3, z3, x2); // x3 = DA+CB
sub(z3, z3, x2); // z3 = DA-CB
sqr1(t1); // t1 = AA
sqr1(z2); // z2 = BB
sub(x2, t1, z2); // x2 = E = AA-BB
mul(z2, x2, a24, sizeof(a24) / sizeof(a24[0])); // z2 = E*a24
add(z2, z2, t1); // z2 = E*a24 + AA
}
// Second half of one Montgomery ladder step; consumes the state left by
// ladder_part1() (t1 = AA, x2 = E, z2 = E*a24 + AA, x3 = DA+CB, z3 = DA-CB).
// x1 is the input u-coordinate as limbs; it is only read. On return:
//   x2 = AA*BB, z2 = E*(E*a24+AA), x3 = (DA+CB)^2, z3 = x1*(DA-CB)^2
static void ladder_part2(fe xs[5], const fe x1) {
limb_t *x2 = xs[0], *z2 = xs[1], *x3 = xs[2], *z3 = xs[3], *t1 = xs[4];
sqr1(z3); // z3 = (DA-CB)^2
mul1(z3, x1); // z3 = x1 * (DA-CB)^2
sqr1(x3); // x3 = (DA+CB)^2
mul1(z2, x2); // z2 = E*(E*a24+AA)
sub(x2, t1, x2); // x2 = AA - E = BB again
mul1(x2, t1); // x2 = AA*BB
}
// Runs the 255-step Montgomery ladder.
// - xs:     working state; xs[0..3] (x2, z2, x3, z3) are initialised here,
//           xs[4] is scratch that ladder_part1() writes before reading.
// - scalar: X25519_BYTES-byte scalar, bit i taken from scalar[i / 8].
// - x1:     input u-coordinate as raw bytes; no alignment is required.
// - clamp:  if nonzero, the scalar is clamped on the fly: the low 3 bits of
//           the first byte and the top bit of the last byte are cleared, and
//           bit 6 of the last byte is set.
// On return xs[0]/xs[1] hold the projective result x2/z2.
static void x25519_core(fe xs[5], const uint8_t scalar[X25519_BYTES],
                        const uint8_t *x1, int clamp) {
  int i;
  fe x1_limbs;
  limb_t swap = 0;
  limb_t *x2 = xs[0], *x3 = xs[2], *z3 = xs[3];
  // Copy the input into an aligned field element once. Casting the byte
  // pointer to `const limb_t *` is undefined when x1 is not limb-aligned.
  // NOTE(review): limbs are loaded in host byte order, so this matches a
  // little-endian wire encoding only on little-endian hosts - confirm that
  // big-endian targets are out of scope.
  memcpy(x1_limbs, x1, sizeof(fe));
  memset(xs, 0, 4 * sizeof(fe));
  x2[0] = z3[0] = 1;
  memcpy(x3, x1_limbs, sizeof(fe));
  for (i = 255; i >= 0; i--) {
    uint8_t bytei = scalar[i / 8];
    limb_t doswap;
    if (clamp) {
      if (i / 8 == 0) {
        bytei &= (uint8_t) ~7U;
      } else if (i / 8 == X25519_BYTES - 1) {
        bytei &= 0x7F;
        bytei |= 0x40;
      }
    }
    // All-ones mask if scalar bit i is set, 0 otherwise
    doswap = 0 - (limb_t) ((bytei >> (i % 8)) & 1);
    // Swap only when this bit differs from the previous one
    condswap(x2, x3, swap ^ doswap);
    swap = doswap;
    ladder_part1(xs);
    ladder_part2(xs, x1_limbs);
  }
  // Undo the swap left pending by the last processed bit
  condswap(x2, x3, swap);
}
int mg_tls_x25519(uint8_t out[X25519_BYTES], const uint8_t scalar[X25519_BYTES],
const uint8_t x1[X25519_BYTES], int clamp) {
int i, ret;
fe xs[5];
limb_t *x2, *z2, *z3, *prev;
static const struct {
uint8_t a, c, n;
} steps[13] = {{2, 1, 1}, {2, 1, 1}, {4, 2, 3}, {2, 4, 6}, {3, 1, 1},
{3, 2, 12}, {4, 3, 25}, {2, 3, 25}, {2, 4, 50}, {3, 2, 125},
{3, 1, 2}, {3, 1, 2}, {3, 1, 1}};
x25519_core(xs, scalar, x1, clamp);
// Precomputed inversion chain
x2 = xs[0];
z2 = xs[1];
z3 = xs[3];
prev = z2;
for (i = 0; i < 13; i++) {
int j;
limb_t *a = xs[steps[i].a];
for (j = steps[i].n; j > 0; j--) {
sqr(a, prev);
prev = a;
}
mul1(a, xs[steps[i].c]);
}
// Here prev = z3
// x2 /= z2
mul((limb_t *) out, x2, z3, NLIMBS);
ret = (int) canon((limb_t *) out);
if (!clamp) ret = 0;
return ret;
}