"""
Pure-Python data structure for working with Ed25519 (and Ristretto)
field elements and operations.
"""
from __future__ import annotations
from typing import Tuple, Sequence
import doctest
_TWO_TO_64 = 2 ** 64
_TWO_TO_128 = 2 ** 128
[docs]class fe25519:
"""
Class for creating and operating on field elements. The public
interface of this class is determined primarily by the needs of
the `ge25519 <https://pypi.org/project/ge25519>`__ library.
However, use of some built-in Python operators is supported via
special methods.
"""
# Precomputed static constants.
d = None
d2 = None
sqrtm1 = None
invsqrtamd = None
onemsqd = None
sqdmone = None
sqrtadm1 = None
curve25519_A = None
[docs] @staticmethod
def zero() -> fe25519:
"""
Constant corresponding to the zero element.
>>> fe25519.zero() + fe25519.one() == fe25519.one()
True
"""
return fe25519([0, 0, 0, 0, 0])
[docs] @staticmethod
def one() -> fe25519:
"""
Constant corresponding to the multiplicative identity element.
>>> fe25519.one() * fe25519.one() == fe25519.one()
True
"""
return fe25519([1, 0, 0, 0, 0])
def __init__(self: fe25519, ns: Sequence[int]):
"""Create field element using a list of five 64-bit integers."""
self.ns = ns # pylint: disable=invalid-name
[docs] def copy(self: fe25519) -> fe25519:
"""
Create a copy of this element instance.
>>> fe25519.one().copy() == fe25519.one()
True
"""
return fe25519(list(self.ns))
[docs] def reduce(self: fe25519) -> fe25519:
"""
Reduce this element to a canonical representation.
>>> (~fe25519.one()).reduce()
fe25519([1, 0, 0, 0, 0])
"""
t = self.ns # 128-bit integers.
mask = 2251799813685247
t[1] = (t[1] + (t[0] >> 51)) % _TWO_TO_128
t[0] &= mask
t[2] = (t[2] + (t[1] >> 51)) % _TWO_TO_128
t[1] &= mask
t[3] = (t[3] + (t[2] >> 51)) % _TWO_TO_128
t[2] &= mask
t[4] = (t[4] + (t[3] >> 51)) % _TWO_TO_128
t[3] &= mask
t[0] = (t[0] + 19 * (t[4] >> 51)) % _TWO_TO_128
t[4] &= mask
t[1] = (t[1] + (t[0] >> 51)) % _TWO_TO_128
t[0] &= mask
t[2] = (t[2] + (t[1] >> 51)) % _TWO_TO_128
t[1] &= mask
t[3] = (t[3] + (t[2] >> 51)) % _TWO_TO_128
t[2] &= mask
t[4] = (t[4] + (t[3] >> 51)) % _TWO_TO_128
t[3] &= mask
t[0] = (t[0] + 19 * (t[4] >> 51)) % _TWO_TO_128
t[4] &= mask
# Now t is between 0 and 2^255-1, properly carried.
# Сase 1: between 0 and 2^255-20. Case 2: between 2^255-19 and 2^255-1.
t[0] = (t[0] + 19) % _TWO_TO_128
t[1] = (t[1] + (t[0] >> 51)) % _TWO_TO_128
t[0] &= mask
t[2] = (t[2] + (t[1] >> 51)) % _TWO_TO_128
t[1] &= mask
t[3] = (t[3] + (t[2] >> 51)) % _TWO_TO_128
t[2] &= mask
t[4] = (t[4] + (t[3] >> 51)) % _TWO_TO_128
t[3] &= mask
t[0] = (t[0] + 19 * (t[4] >> 51)) % _TWO_TO_128
t[4] &= mask
# Now between 19 and 2^255-1 in both cases, and offset by 19.
t[0] = (t[0] + 2251799813685248 - 19) % _TWO_TO_128
t[1] = (t[1] + 2251799813685248 - 1) % _TWO_TO_128
t[2] = (t[2] + 2251799813685248 - 1) % _TWO_TO_128
t[3] = (t[3] + 2251799813685248 - 1) % _TWO_TO_128
t[4] = (t[4] + 2251799813685248 - 1) % _TWO_TO_128
# Now between 2^255 and 2^256-20, and offset by 2^255.
t[1] = (t[1] + (t[0] >> 51)) % _TWO_TO_128
t[0] &= mask
t[2] = (t[2] + (t[1] >> 51)) % _TWO_TO_128
t[1] &= mask
t[3] = (t[3] + (t[2] >> 51)) % _TWO_TO_128
t[2] &= mask
t[4] = (t[4] + (t[3] >> 51)) % _TWO_TO_128
t[3] &= mask
t[4] &= mask
return fe25519(t)
[docs] def __add__(self: fe25519, other: fe25519) -> fe25519:
"""
Compute the sum of this element and another element.
>>> fe25519.zero() + fe25519.zero() == fe25519.zero()
True
"""
return fe25519([(m+n)%_TWO_TO_64 for (m, n) in zip(self.ns, other.ns)])
[docs] def __neg__(self: fe25519) -> fe25519:
"""
Compute the negation of this element.
>>> fe25519.one().cneg(1) != fe25519.one()
True
"""
return fe25519.zero() - self
[docs] def cmov(self: fe25519, g: fe25519, b: int) -> fe25519:
"""Conditionally select this element or another based on a boolean integer."""
mask = _TWO_TO_64 - b
pairs = zip(self.ns, g.ns)
return fe25519([fi ^ ((fi ^ gi) & mask) for (fi, gi) in pairs])
[docs] def cneg(self: fe25519, b: int) -> fe25519:
"""
Compute the conditional negation of this element.
>>> fe25519.one().cneg(0) == fe25519.one()
True
>>> (fe25519.one().cneg(1) + fe25519.one()).is_zero()
1
"""
return self.copy().cmov(-self, b)
[docs] def __abs__(self: fe25519) -> fe25519:
"""
Compute the absolute value of this element.
>>> abs(-fe25519.one()).is_negative()
0
"""
return self.cneg(self.is_negative())
[docs] def __sub__(self: fe25519, other: fe25519) -> fe25519:
"""
Compute the result of subtracting another element from this element.
>>> fe25519.zero() - fe25519.one() == fe25519.one().cneg(1)
True
"""
mask = 2251799813685247
(h0, h1, h2, h3, h4) = other.ns
h1 = (h1 + (h0 >> 51)) % _TWO_TO_64
h0 &= mask
h2 = (h2 + (h1 >> 51)) % _TWO_TO_64
h1 &= mask
h3 = (h3 + (h2 >> 51)) % _TWO_TO_64
h2 &= mask
h4 = (h4 + (h3 >> 51)) % _TWO_TO_64
h3 &= mask
h0 = (h0 + 19 * (h4 >> 51)) % _TWO_TO_64
h4 &= mask
return fe25519([
((self.ns[0] + 4503599627370458) - h0) % _TWO_TO_64,
((self.ns[1] + 4503599627370494) - h1) % _TWO_TO_64,
((self.ns[2] + 4503599627370494) - h2) % _TWO_TO_64,
((self.ns[3] + 4503599627370494) - h3) % _TWO_TO_64,
((self.ns[4] + 4503599627370494) - h4) % _TWO_TO_64
])
[docs] def __mul__(self: fe25519, other: fe25519) -> fe25519:
"""
Compute the product of this element and another element.
>>> fe25519.one() * fe25519.zero() == fe25519.zero()
True
"""
mask = 2251799813685247 # 64-bit integer.
(f, g) = (self.ns, other.ns) # 64-bit integers.
r = [None, None, None, None, None] # 128-bit integers.
carry = None # 128-bit integer.
r0 = [None, None, None, None, None] # 64-bit integers.
f1_19 = (19 * f[1]) % _TWO_TO_64
f2_19 = (19 * f[2]) % _TWO_TO_64
f3_19 = (19 * f[3]) % _TWO_TO_64
f4_19 = (19 * f[4]) % _TWO_TO_64
r[0] = (f[0]*g[0] + f1_19*g[4] + f2_19*g[3] + f3_19*g[2] + f4_19*g[1]) % _TWO_TO_128
r[1] = (f[0]*g[1] + f[1]*g[0] + f2_19*g[4] + f3_19*g[3] + f4_19*g[2]) % _TWO_TO_128
r[2] = (f[0]*g[2] + f[1]*g[1] + f[2]*g[0] + f3_19*g[4] + f4_19*g[3]) % _TWO_TO_128
r[3] = (f[0]*g[3] + f[1]*g[2] + f[2]*g[1] + f[3]*g[0] + f4_19*g[4]) % _TWO_TO_128
r[4] = (f[0]*g[4] + f[1]*g[3] + f[2]*g[2] + f[3]*g[1] + f[4]*g[0]) % _TWO_TO_128
r0[0] = (r[0] % _TWO_TO_64) & mask
r[1] = (r[1] + (r[0] >> 51)) % _TWO_TO_128
r0[1] = (r[1] % _TWO_TO_64) & mask
r[2] = (r[2] + (r[1] >> 51)) % _TWO_TO_128
r0[2] = (r[2] % _TWO_TO_64) & mask
r[3] = (r[3] + (r[2] >> 51)) % _TWO_TO_128
r0[3] = (r[3] % _TWO_TO_64) & mask
r[4] = (r[4] + (r[3] >> 51)) % _TWO_TO_128
r0[4] = (r[4] % _TWO_TO_64) & mask
r0[0] = (r0[0] + (19*((r[4] >> 51) % _TWO_TO_64))) % _TWO_TO_64
carry = r0[0] >> 51
r0[0] &= mask
r0[1] = (r0[1] + (carry % _TWO_TO_64)) % _TWO_TO_64
carry = r0[1] >> 51
r0[1] &= mask
r0[2] = (r0[2] + (carry % _TWO_TO_64)) % _TWO_TO_64
return fe25519(r0)
[docs] def sq(self: fe25519) -> fe25519: # pylint: disable=invalid-name
"""
Compute the square of this element.
>>> two = fe25519.one() + fe25519.one()
>>> four = two + two
>>> two.sq() == four
True
"""
mask = 2251799813685247 # 64-bit integer.
f = self.ns # 64-bit integers.
r = [None, None, None, None, None] # 128-bit integers.
carry = None # 128-bit integer.
r0 = [None, None, None, None, None] # 64-bit integers.
f0_2 = (f[0] << 1) % _TWO_TO_64
f1_2 = (f[1] << 1) % _TWO_TO_64
f1_38 = (38 * f[1]) % _TWO_TO_64
f2_38 = (38 * f[2]) % _TWO_TO_64
f3_38 = (38 * f[3]) % _TWO_TO_64
f3_19 = (19 * f[3]) % _TWO_TO_64
f4_19 = (19 * f[4]) % _TWO_TO_64
r[0] = (f[0]*f[0] + f1_38*f[4] + f2_38*f[3]) % _TWO_TO_128
r[1] = (f0_2*f[1] + f2_38*f[4] + f3_19*f[3]) % _TWO_TO_128
r[2] = (f0_2*f[2] + f[1]*f[1] + f3_38*f[4]) % _TWO_TO_128
r[3] = (f0_2*f[3] + f1_2*f[2] + f4_19*f[4]) % _TWO_TO_128
r[4] = (f0_2*f[4] + f1_2*f[3] + f[2]*f[2]) % _TWO_TO_128
r0[0] = (r[0] % _TWO_TO_64) & mask
r[1] = (r[1] + (r[0] >> 51)) % _TWO_TO_128
r0[1] = (r[1] % _TWO_TO_64) & mask
r[2] = (r[2] + (r[1] >> 51)) % _TWO_TO_128
r0[2] = (r[2] % _TWO_TO_64) & mask
r[3] = (r[3] + (r[2] >> 51)) % _TWO_TO_128
r0[3] = (r[3] % _TWO_TO_64) & mask
r[4] = (r[4] + (r[3] >> 51)) % _TWO_TO_128
r0[4] = (r[4] % _TWO_TO_64) & mask
r0[0] = (r0[0] + (19*((r[4] >> 51) % _TWO_TO_64))) % _TWO_TO_64
carry = r0[0] >> 51
r0[0] &= mask
r0[1] = (r0[1] + (carry % _TWO_TO_64)) % _TWO_TO_64
carry = r0[1] >> 51
r0[1] &= mask
r0[2] = (r0[2] + (carry % _TWO_TO_64)) % _TWO_TO_64
return fe25519(r0)
[docs] def sq2(self: fe25519) -> fe25519:
"""
Compute the element that is twice the square of this element.
>>> two = fe25519.one() + fe25519.one()
>>> two.sq2() == two.sq() + two.sq()
True
"""
mask = 2251799813685247
f = self.ns # 64-bit integers.
r = [None, None, None, None, None] # 128-bit integers.
carry = None # 128-bit integer.
r0 = [None, None, None, None, None] # 64-bit integers.
f0_2 = (f[0] << 1) % _TWO_TO_64
f1_2 = (f[1] << 1) % _TWO_TO_64
f1_38 = (38 * f[1]) % _TWO_TO_64
f2_38 = (38 * f[2]) % _TWO_TO_64
f3_38 = (38 * f[3]) % _TWO_TO_64
f3_19 = (19 * f[3]) % _TWO_TO_64
f4_19 = (19 * f[4]) % _TWO_TO_64
r[0] = (f[0]*f[0] + f1_38*f[4] + f2_38*f[3]) % _TWO_TO_128
r[1] = (f0_2*f[1] + f2_38*f[4] + f3_19*f[3]) % _TWO_TO_128
r[2] = (f0_2*f[2] + f[1]*f[1] + f3_38*f[4]) % _TWO_TO_128
r[3] = (f0_2*f[3] + f1_2*f[2] + f4_19*f[4]) % _TWO_TO_128
r[4] = (f0_2*f[4] + f1_2*f[3] + f[2]*f[2]) % _TWO_TO_128
r[0] <<= 1
r[1] <<= 1
r[2] <<= 1
r[3] <<= 1
r[4] <<= 1
r0[0] = (r[0] % _TWO_TO_64) & mask
carry = r[0] >> 51
r[1] = (r[1] + carry) % _TWO_TO_128
r0[1] = (r[1] % _TWO_TO_64) & mask
carry = r[1] >> 51
r[2] = (r[2] + carry) % _TWO_TO_128
r0[2] = (r[2] % _TWO_TO_64) & mask
carry = r[2] >> 51
r[3] = (r[3] + carry) % _TWO_TO_128
r0[3] = (r[3] % _TWO_TO_64) & mask
carry = r[3] >> 51
r[4] = (r[4] + carry) % _TWO_TO_128
r0[4] = (r[4] % _TWO_TO_64) & mask
carry = r[4] >> 51
r0[0] = (r0[0] + 19*carry) % _TWO_TO_64
carry = r0[0] >> 51
r0[0] &= mask
r0[1] = (r0[1] + (carry % _TWO_TO_64)) % _TWO_TO_64
carry = r0[1] >> 51
r0[1] &= mask
r0[2] = (r0[2] + (carry % _TWO_TO_64)) % _TWO_TO_64
return fe25519(r0)
[docs] def pow22523(self: fe25519) -> fe25519:
"""
Compute the result of the exponentiation of this element by a
special fixed exponent.
"""
z = self.copy()
t0 = z.sq()
t1 = t0.sq()
t1 = t1.sq()
t1 = z * t1
t0 = t0 * t1
t0 = t0.sq()
t0 = t1 * t0
t1 = t0.sq()
for _ in range(1, 5):
t1 = t1.sq()
t0 = t1 * t0
t1 = t0.sq()
for _ in range(1, 10):
t1 = t1.sq()
t1 = t1 * t0
t2 = t1.sq()
for _ in range(1, 20):
t2 = t2.sq()
t1 = t2 * t1
t1 = t1.sq()
for _ in range(1, 10):
t1 = t1.sq()
t0 = t1 * t0
t1 = t0.sq()
for _ in range(1, 50):
t1 = t1.sq()
t1 = t1 * t0
t2 = t1.sq()
for _ in range(1, 100):
t2 = t2.sq()
t1 = t2 * t1
t1 = t1.sq()
for _ in range(1, 50):
t1 = t1.sq()
t0 = t1 * t0
t0 = t0.sq()
t0 = t0.sq()
return t0 * z
[docs] def invert(self: fe25519) -> fe25519:
"""
Compute the multiplicative inverse of this element.
>>> two = fe25519.one() + fe25519.one()
>>> (two.invert() * two).reduce() == fe25519.one()
True
"""
z = self.copy()
t0 = z.sq()
t1 = t0.sq()
t1 = t1.sq()
t1 = z * t1
t0 = t0 * t1
t2 = t0.sq()
t1 = t1 * t2
t2 = t1.sq()
for _ in range(1, 5):
t2 = t2.sq()
t1 = t2 * t1
t2 = t1.sq()
for _ in range(1, 10):
t2 = t2.sq()
t2 = t2 * t1
t3 = t2.sq()
for _ in range(1, 20):
t3 = t3.sq()
t2 = t3 * t2
t2 = t2.sq()
for _ in range(1, 10):
t2 = t2.sq()
t1 = t2 * t1
t2 = t1.sq()
for _ in range(1, 50):
t2 = t2.sq()
t2 = t2 * t1
t3 = t2.sq()
for _ in range(1, 100):
t3 = t3.sq()
t2 = t3 * t2
t2 = t2.sq()
for _ in range(1, 50):
t2 = t2.sq()
t1 = t2 * t1
t1 = t1.sq()
for _ in range(1, 5):
t1 = t1.sq()
return t1 * t0
[docs] def __invert__(self: fe25519) -> fe25519:
"""
Compute the multiplicative inverse of this element.
>>> two = fe25519.one() + fe25519.one()
>>> (((~two) * two) - fe25519.one()).is_zero()
1
"""
return self.invert()
[docs] def __pow__(self: fe25519, e: int) -> fe25519:
"""
Exponentiation is a synonym for squaring and inversion.
>>> two = fe25519.one() + fe25519.one()
>>> two**2 == two * two == two.sq()
True
>>> ~fe25519.one() == fe25519.one() ** (-1)
True
"""
if e == 2: # Squaring.
return self.sq()
if e == -1: # Inversion.
return self.invert()
# Supplied exponent is not supported.
return None
[docs] def sqrt_ratio_m1_ristretto255(self: fe25519, v: fe25519) -> Tuple[fe25519, int]:
"""
Compute the result of a specialized root operation.
"""
u = self
v3 = v.sq()
v3 = v3 * v # v3 = v^3
x = v3.sq()
x = x * v
x = x * u # x = uv^7
x = x.pow22523() # x = (uv^7)^((q-5)/8)
x = x * v3
x = x * u # x = uv^3(uv^7)^((q-5)/8)
vxx = x.sq()
vxx = vxx * v # vx^2
m_root_check = vxx - u # vx^2-u
p_root_check = vxx + u # vx^2+u
f_root_check = u * fe25519.sqrtm1 # u*sqrt(-1)
f_root_check = vxx + f_root_check # vx^2+u*sqrt(-1)
has_m_root = m_root_check.is_zero()
has_p_root = p_root_check.is_zero()
has_f_root = f_root_check.is_zero()
x_sqrtm1 = x * fe25519.sqrtm1 # x*sqrt(-1)
x = x.cmov(x_sqrtm1, has_p_root | has_f_root)
x = abs(x)
return (x, has_m_root | has_p_root)
[docs] def chi25519(self: fe25519) -> fe25519:
"""
Compute the result of a specialized root operation (for elligator).
"""
t0 = self.sq()
t1 = t0 * self
t0 = t1.sq()
t2 = t0.sq()
t2 = t2.sq()
t2 = t2 * t0
t1 = t2 * self
t2 = t1.sq()
for _ in range(1, 5):
t2 = t2.sq()
t1 = t2 * t1
t2 = t1.sq()
for _ in range(1, 10):
t2 = t2.sq()
t2 = t2 * t1
t3 = t2.sq()
for _ in range(1, 20):
t3 = t3.sq()
t2 = t3 * t2
t2 = t2.sq()
for _ in range(1, 10):
t2 = t2.sq()
t1 = t2 * t1
t2 = t1.sq()
for _ in range(1, 50):
t2 = t2.sq()
t2 = t2 * t1
t3 = t2.sq()
for _ in range(1, 100):
t3 = t3.sq()
t2 = t3 * t2
t2 = t2.sq()
for _ in range(1, 50):
t2 = t2.sq()
t1 = t2 * t1
t1 = t1.sq()
for _ in range(1, 4):
t1 = t1.sq()
return t1 * t0
[docs] def __eq__(self: fe25519, other: fe25519) -> bool:
"""
Determine whether this element and another are equivalent.
>>> fe25519.zero() == fe25519.one()
False
>>> fe25519.one() == fe25519.one()
True
"""
return self.ns == other.ns
[docs] def is_zero(self: fe25519) -> int:
"""
Determine whether this element is zero.
>>> fe25519.zero().is_zero()
1
>>> fe25519.one().is_zero()
0
"""
bs = self.to_bytes()
d = 0
for b in bs:
d |= b
return 1 & ((d - 1) >> 8)
[docs] def is_negative(self: fe25519) -> int:
"""
Determine whether the negation bit is set in this element.
>>> fe25519.zero().is_negative()
0
"""
bs = self.to_bytes()
return bs[0] & 1
[docs] @staticmethod
def from_bytes(bs: bytes) -> fe25519:
"""
Assemble an element instance from its byte representation.
>>> s = '0100000000000000000000000000000000000000000000000000000000000000'
>>> fe25519.from_bytes(bytes.fromhex(s))
fe25519([1, 0, 0, 0, 0])
"""
mask = 2251799813685247
def load64_le(bs):
w = bs[0]
w |= bs[1] << 8
w |= bs[2] << 16
w |= bs[3] << 24
w |= bs[4] << 32
w |= bs[5] << 40
w |= bs[6] << 48
w |= bs[7] << 56
return w
return fe25519([
(load64_le(bs[0:8])) & mask,
(load64_le(bs[6:14]) >> 3) & mask,
(load64_le(bs[12:20]) >> 6) & mask,
(load64_le(bs[19:27]) >> 1) & mask,
(load64_le(bs[24:32]) >> 12) & mask
])
[docs] def to_bytes(self: fe25519) -> bytes:
"""
Build the byte representation of this element.
>>> fe25519.one().to_bytes().hex()
'0100000000000000000000000000000000000000000000000000000000000000'
"""
t = self.reduce().ns
t0 = t[0] | ((t[1] << 51) % _TWO_TO_64)
t1 = (t[1] >> 13) | ((t[2] << 38) % _TWO_TO_64)
t2 = (t[2] >> 26) | ((t[3] << 25) % _TWO_TO_64)
t3 = (t[3] >> 39) | ((t[4] << 12) % _TWO_TO_64)
bs = bytearray()
bs.extend(t0.to_bytes(8, 'little'))
bs.extend(t1.to_bytes(8, 'little'))
bs.extend(t2.to_bytes(8, 'little'))
bs.extend(t3.to_bytes(8, 'little'))
return bytes(bs)
[docs] def __bytes__(self: fe25519) -> bytes:
"""
Build the byte representation of this element.
>>> bytes(fe25519.one()).hex()
'0100000000000000000000000000000000000000000000000000000000000000'
"""
return self.to_bytes()
[docs] def __str__(self: fe25519) -> str:
"""
Obtain the string representation of an element.
>>> str(fe25519.one())
'fe25519([1, 0, 0, 0, 0])'
"""
return 'fe25519(' + str(self.ns) + ')'
[docs] def __repr__(self: fe25519) -> str:
"""
Obtain the string representation of an element.
"""
return str(self) # pragma: no cover
# Precomputed static constants.
fe25519.d = fe25519([
929955233495203, 466365720129213, 1662059464998953, 2033849074728123, 1442794654840575
])
fe25519.d2 = fe25519([
1859910466990425, 932731440258426, 1072319116312658, 1815898335770999, 633789495995903
])
fe25519.sqrtm1 = fe25519([
1718705420411056, 234908883556509, 2233514472574048, 2117202627021982, 765476049583133
])
fe25519.invsqrtamd = fe25519([
278908739862762, 821645201101625, 8113234426968, 1777959178193151, 2118520810568447
])
fe25519.onemsqd = fe25519([
1136626929484150, 1998550399581263, 496427632559748, 118527312129759, 45110755273534
])
fe25519.sqdmone = fe25519([
1507062230895904, 1572317787530805, 683053064812840, 317374165784489, 1572899562415810
])
fe25519.sqrtadm1 = fe25519([
2241493124984347, 425987919032274, 2207028919301688, 1220490630685848, 974799131293748
])
fe25519.curve25519_A = fe25519([486662, 0, 0, 0, 0])
if __name__ == '__main__':
doctest.testmod() # pragma: no cover