-rwxr-xr-x 11911 libntruprime-20241008/src/core/inv3sntrupP/avx/r3_recipgen.py raw
#!/usr/bin/env python3
import re
import sys
from math import ceil
# p is the degree of the ring polynomial x^p - x - 1, given as the only
# command-line argument.
if len(sys.argv) < 2:
  sys.exit('usage: %s p' % sys.argv[0])
p = int(sys.argv[1])
# Coefficients are packed 256 per vec256; numvec vectors give ppad >= p slots.
numvec = ceil(p/256)
ppad = 256*numvec
# f below has p+1 coefficients and must fit in ppad slots.  That fails exactly
# when p is a multiple of 256: f would be padded to 256*(numvec+1) entries and
# the generated C would assign F0[numvec]/F1[numvec], past the end of the
# arrays declared as vec256 F0[numvec]/F1[numvec].  For p < 2, [0]*(p-2) is
# empty and f would not be x^p - x - 1.
if p < 2 or p % 256 == 0:
  sys.exit('%s: unsupported p=%d (need p >= 2, p not a multiple of 256)' % (sys.argv[0],p))
# f = x^p - x - 1, listed so that f[i] is the coefficient of x^(p-i)
# (see comment below about order of coefficients)
f = [-1,-1] + [0]*(p-2) + [1]
f = list(reversed(f))
while len(f)%256: f += [0]
# Reorder each 256-coefficient chunk into the vec256 layout:
# position j*256 + i*64 + k holds coefficient j*256 + i + 4*k.
f0 = [f[j*256+i+k*4] for j in range(len(f)//256) for i in range(4) for k in range(64)]
# Two bit-planes per coefficient: f0 marks nonzero, f1 marks -1,
# so 0 -> (0,0), 1 -> (1,0), -1 -> (1,1).
f1 = [1 if f0i<0 else 0 for f0i in f0]
f0 = [f0i&1 for f0i in f0]
# ---------- utility functions
# `out` accumulates the generated C source; this string is its fixed prelude.
# P, PPAD and NUMVEC are placeholders: the re.sub calls at the end of the
# script replace these whole words with the values of p, ppad and numvec.
# Each coefficient in {-1,0,1} is held in two bit-planes (X0,X1):
#   0 -> (0,0), 1 -> (1,0), -1 -> (1,1)
# (vec256_init builds this encoding; vec256_final decodes it as
# x0 + 2*x1 - 4*(x0&x1)).
# C helpers defined here:
#   vec256_frombits / vec256_tobits - convert between arrays of 0/1 bytes
#     (256 bytes per vector) and the packed lane order described in the
#     C comment inside the string
#   vec256_init  - take the p input coefficients in reversed order, zero-pad
#     to ppad, and split them into the two bit-planes
#   vec256_final - recombine the bit-planes and write p coefficients in
#     reversed order, undoing the reversal done by vec256_init
#   vec256_swap  - exchange f[i] and g[i] in the bits selected by mask, using
#     only xor/and (no branch depends on the mask)
#   vec256_scale - multiply each coefficient of (f0,f1) by the constant whose
#     bit-planes are the masks (c0,c1): c0 = 0 clears, c1 = all-ones negates
#   vec256_eliminate - replace (g0,g1) by g - c*f, coefficients mod 3, with
#     c given by the masks (c0,c1) as in vec256_scale
#   vec256_bit0mask - crypto_int32_bottombit_mask of the low 32 bits of f[0]
out = """\
// 20240812 djb: more cryptoint usage
#include "crypto_core.h"
#include <immintrin.h>
#include "crypto_int8.h"
#define int8 crypto_int8
typedef int8 small;
#include "crypto_int32.h"
#include "crypto_int64.h"
#include "crypto_uint64.h"
#define p P
#define ppad PPAD
#define numvec NUMVEC
typedef __m256i vec256;
/*
This code stores PPAD-coeff poly as vec256[NUMVEC].
Order of 256 coefficients in each vec256
is optimized in light of costs of vector instructions:
0,4,...,252 in 64-bit word;
1,5,...,253 in 64-bit word;
2,6,...,254 in 64-bit word;
3,7,...,255 in 64-bit word.
*/
static inline void vec256_frombits(vec256 *v,const small *b)
{
int i;
for (i = 0;i < numvec;++i) {
vec256 b0 = _mm256_loadu_si256((vec256 *) b); b += 32; /* 0,1,...,31 */
vec256 b1 = _mm256_loadu_si256((vec256 *) b); b += 32; /* 32,33,... */
vec256 b2 = _mm256_loadu_si256((vec256 *) b); b += 32;
vec256 b3 = _mm256_loadu_si256((vec256 *) b); b += 32;
vec256 b4 = _mm256_loadu_si256((vec256 *) b); b += 32;
vec256 b5 = _mm256_loadu_si256((vec256 *) b); b += 32;
vec256 b6 = _mm256_loadu_si256((vec256 *) b); b += 32;
vec256 b7 = _mm256_loadu_si256((vec256 *) b); b += 32;
vec256 c0 = _mm256_unpacklo_epi32(b0,b1); /* 0 1 2 3 32 33 34 35 4 5 6 7 36 37 38 39 ... 55 */
vec256 c1 = _mm256_unpackhi_epi32(b0,b1); /* 8 9 10 11 40 41 42 43 ... 63 */
vec256 c2 = _mm256_unpacklo_epi32(b2,b3);
vec256 c3 = _mm256_unpackhi_epi32(b2,b3);
vec256 c4 = _mm256_unpacklo_epi32(b4,b5);
vec256 c5 = _mm256_unpackhi_epi32(b4,b5);
vec256 c6 = _mm256_unpacklo_epi32(b6,b7);
vec256 c7 = _mm256_unpackhi_epi32(b6,b7);
vec256 d0 = c0 | _mm256_slli_epi32(c1,2); /* 0 8, 1 9, 2 10, 3 11, 32 40, 33 41, ..., 55 63 */
vec256 d2 = c2 | _mm256_slli_epi32(c3,2);
vec256 d4 = c4 | _mm256_slli_epi32(c5,2);
vec256 d6 = c6 | _mm256_slli_epi32(c7,2);
vec256 e0 = _mm256_unpacklo_epi64(d0,d2);
vec256 e2 = _mm256_unpackhi_epi64(d0,d2);
vec256 e4 = _mm256_unpacklo_epi64(d4,d6);
vec256 e6 = _mm256_unpackhi_epi64(d4,d6);
vec256 f0 = e0 | _mm256_slli_epi32(e2,1);
vec256 f4 = e4 | _mm256_slli_epi32(e6,1);
vec256 g0 = _mm256_permute2x128_si256(f0,f4,0x20);
vec256 g4 = _mm256_permute2x128_si256(f0,f4,0x31);
vec256 h = g0 | _mm256_slli_epi32(g4,4);
#define TRANSPOSE _mm256_set_epi8( 31,27,23,19, 30,26,22,18, 29,25,21,17, 28,24,20,16, 15,11,7,3, 14,10,6,2, 13,9,5,1, 12,8,4,0 )
h = _mm256_shuffle_epi8(h,TRANSPOSE);
h = _mm256_permute4x64_epi64(h,0xd8);
h = _mm256_shuffle_epi32(h,0xd8);
*v++ = h;
}
}
static inline void vec256_tobits(const vec256 *v,small *b)
{
int i;
for (i = 0;i < numvec;++i) {
vec256 h = *v++;
h = _mm256_shuffle_epi32(h,0xd8);
h = _mm256_permute4x64_epi64(h,0xd8);
h = _mm256_shuffle_epi8(h,TRANSPOSE);
vec256 g0 = h & _mm256_set1_epi8(15);
vec256 g4 = _mm256_srli_epi32(h,4) & _mm256_set1_epi8(15);
vec256 f0 = _mm256_permute2x128_si256(g0,g4,0x20);
vec256 f4 = _mm256_permute2x128_si256(g0,g4,0x31);
vec256 e0 = f0 & _mm256_set1_epi8(5);
vec256 e2 = _mm256_srli_epi32(f0,1) & _mm256_set1_epi8(5);
vec256 e4 = f4 & _mm256_set1_epi8(5);
vec256 e6 = _mm256_srli_epi32(f4,1) & _mm256_set1_epi8(5);
vec256 d0 = _mm256_unpacklo_epi32(e0,e2);
vec256 d2 = _mm256_unpackhi_epi32(e0,e2);
vec256 d4 = _mm256_unpacklo_epi32(e4,e6);
vec256 d6 = _mm256_unpackhi_epi32(e4,e6);
vec256 c0 = d0 & _mm256_set1_epi8(1);
vec256 c1 = _mm256_srli_epi32(d0,2) & _mm256_set1_epi8(1);
vec256 c2 = d2 & _mm256_set1_epi8(1);
vec256 c3 = _mm256_srli_epi32(d2,2) & _mm256_set1_epi8(1);
vec256 c4 = d4 & _mm256_set1_epi8(1);
vec256 c5 = _mm256_srli_epi32(d4,2) & _mm256_set1_epi8(1);
vec256 c6 = d6 & _mm256_set1_epi8(1);
vec256 c7 = _mm256_srli_epi32(d6,2) & _mm256_set1_epi8(1);
vec256 b0 = _mm256_unpacklo_epi64(c0,c1);
vec256 b1 = _mm256_unpackhi_epi64(c0,c1);
vec256 b2 = _mm256_unpacklo_epi64(c2,c3);
vec256 b3 = _mm256_unpackhi_epi64(c2,c3);
vec256 b4 = _mm256_unpacklo_epi64(c4,c5);
vec256 b5 = _mm256_unpackhi_epi64(c4,c5);
vec256 b6 = _mm256_unpacklo_epi64(c6,c7);
vec256 b7 = _mm256_unpackhi_epi64(c6,c7);
_mm256_storeu_si256((vec256 *) b,b0); b += 32;
_mm256_storeu_si256((vec256 *) b,b1); b += 32;
_mm256_storeu_si256((vec256 *) b,b2); b += 32;
_mm256_storeu_si256((vec256 *) b,b3); b += 32;
_mm256_storeu_si256((vec256 *) b,b4); b += 32;
_mm256_storeu_si256((vec256 *) b,b5); b += 32;
_mm256_storeu_si256((vec256 *) b,b6); b += 32;
_mm256_storeu_si256((vec256 *) b,b7); b += 32;
}
}
static void vec256_init(vec256 *G0,vec256 *G1,const small *s)
{
int i;
small srev[ppad+(ppad-p)];
small si;
small g0[ppad];
small g1[ppad];
for (i = 0;i < p;++i) srev[ppad-1-i] = s[i];
for (i = 0;i < ppad-p;++i) srev[i] = 0;
for (i = p;i < ppad;++i) srev[i+ppad-p] = 0;
for (i = 0;i < ppad;++i) {
si = srev[i+ppad-p];
g0[i] = crypto_int8_bottombit_01(si);
g1[i] = (si >> 1) & g0[i];
}
vec256_frombits(G0,g0);
vec256_frombits(G1,g1);
}
static void vec256_final(small *out,const vec256 *V0,const vec256 *V1)
{
int i;
small v0[ppad];
small v1[ppad];
small v[ppad];
small vrev[ppad+(ppad-p)];
vec256_tobits(V0,v0);
vec256_tobits(V1,v1);
for (i = 0;i < ppad;++i)
v[i] = v0[i] + 2*v1[i] - 4*(v0[i]&v1[i]);
for (i = 0;i < ppad;++i) vrev[i] = v[ppad-1-i];
for (i = ppad;i < ppad+(ppad-p);++i) vrev[i] = 0;
for (i = 0;i < p;++i) out[i] = vrev[i+ppad-p];
}
static inline void vec256_swap(vec256 *f,vec256 *g,int len,vec256 mask)
{
vec256 flip;
int i;
for (i = 0;i < len;++i) {
flip = mask & (f[i] ^ g[i]);
f[i] ^= flip;
g[i] ^= flip;
}
}
static inline void vec256_scale(vec256 *f0,vec256 *f1,const vec256 c0,const vec256 c1)
{
int i;
for (i = 0;i < numvec;++i) {
vec256 f0i = f0[i];
vec256 f1i = f1[i];
f0i &= c0;
f1i ^= c1;
f1i &= f0i;
f0[i] = f0i;
f1[i] = f1i;
}
}
static inline void vec256_eliminate(vec256 *f0,vec256 *f1,vec256 *g0,vec256 *g1,int len,const vec256 c0,const vec256 c1)
{
int i;
for (i = 0;i < len;++i) {
vec256 f0i = f0[i];
vec256 f1i = f1[i];
vec256 g0i = g0[i];
vec256 g1i = g1[i];
vec256 t;
f0i &= c0;
f1i ^= c1;
f1i &= f0i;
t = g0i ^ f0i;
g0[i] = t | (g1i ^ f1i);
g1[i] = (g1i ^ f0i) & (f1i ^ t);
}
}
static inline int vec256_bit0mask(vec256 *f)
{
return crypto_int32_bottombit_mask(_mm_cvtsi128_si32(_mm256_castsi256_si128(f[0])));
}
"""
# ---------- divx
# Emit vec256_divx_1 .. vec256_divx_NUMVEC.  vec256_divx_n(f) works on the
# packed polynomial in f[0..n-1]: the low 64-bit word of each vector is shifted
# right one bit, ORed with crypto_uint64_shlmod(<next vector's low word>,63)
# (presumably moving that word's bit 0 into bit 63), written back with
# blend 0x3, and then the four 64-bit words are rotated down one place
# (permute 0x39).
for n in range(1,numvec+1):
  idx = range(n)
  shift = []
  for i in idx:
    if i+1 < n:
      shift.append(' low%d = (low%d >> 1) | crypto_uint64_shlmod(low%d,63);\n' % (i,i,i+1))
    else:
      # topmost vector: there is no next vector to carry a bit in from
      shift.append(' low%d = low%d >> 1;\n' % (i,i))
  sections = [
    ''.join(' vec256 f%d = f[%d];\n' % (i,i) for i in idx),
    ''.join(' unsigned long long low%d = _mm_cvtsi128_si64(_mm256_castsi256_si128(f%d));\n' % (i,i) for i in idx),
    ''.join(shift),
    ''.join(' f%d = _mm256_blend_epi32(f%d,_mm256_set_epi64x(0,0,0,low%d),0x3);\n' % (i,i,i) for i in idx),
    ''.join(' f[%d] = _mm256_permute4x64_epi64(f%d,0x39);\n' % (i,i) for i in idx),
  ]
  # sections are separated by blank lines; each function is followed by one
  out += ('static inline void vec256_divx_%d(vec256 *f)\n' % n) + '{\n' + '\n'.join(sections) + '}\n' + '\n'
# ----------
# Emit vec256_timesx_1 .. vec256_timesx_NUMVEC.  vec256_timesx_n(f) works on
# the packed polynomial in f[0..n-1]: the four 64-bit words of each vector are
# rotated up one place (permute 0x93), then the new low word is shifted left
# one bit and ORed with crypto_int64_negative_01 of the previous vector's low
# word (presumably its top bit).
# For (n,i) in (3,0),(3,1) the low word is read/written through a pointer cast
# instead of cvtsi128/blend.  XXX: search through subsets for optimal timings
for n in range(1,numvec+1):
  idx = range(n)
  viamem = [(n,i) in [(3,0),(3,1)] for i in idx]
  loads = []
  stores = []
  for i in idx:
    if viamem[i]:
      loads.append(' unsigned long long low%d = *(unsigned long long *) &f%d;\n' % (i,i))
      stores.append(' *(unsigned long long *) &f%d = low%d;\n' % (i,i))
    else:
      loads.append(' unsigned long long low%d = _mm_cvtsi128_si64(_mm256_castsi256_si128(f%d));\n' % (i,i))
      stores.append(' f%d = _mm256_blend_epi32(f%d,_mm256_set_epi64x(0,0,0,low%d),0x3);\n' % (i,i,i))
  # emitted from the top vector down, so low(i-1) is still unshifted when it
  # is used to compute low(i)
  shifts = []
  for i in reversed(idx):
    if i:
      shifts.append(' low%d = (low%d << 1) | crypto_int64_negative_01(low%d);\n' % (i,i,i-1))
    else:
      shifts.append(' low%d = low%d << 1;\n' % (i,i))
  sections = [
    ''.join(' vec256 f%d = _mm256_permute4x64_epi64(f[%d],0x93);\n' % (i,i) for i in idx),
    ''.join(loads),
    ''.join(shifts),
    ''.join(stores),
    ''.join(' f[%d] = f%d;\n' % (i,i) for i in idx),
  ]
  # sections are separated by blank lines; each function is followed by one
  out += ('static inline void vec256_timesx_%d(vec256 *f)\n' % n) + '{\n' + '\n'.join(sections) + '}\n' + '\n'
# ---------- reciprocal initialization
out += """
void crypto_core(unsigned char *outbytes,const unsigned char *inbytes,const unsigned char *kbytes,const unsigned char *cbytes)
{
small *out = (void *) outbytes;
small *in = (void *) inbytes;
vec256 F0[numvec];
vec256 F1[numvec];
vec256 G0[numvec];
vec256 G1[numvec];
vec256 V0[numvec];
vec256 V1[numvec];
vec256 R0[numvec];
vec256 R1[numvec];
vec256 c0vec,c1vec;
int loop;
int c0,c1;
int minusdelta = -1;
int swapmask;
vec256 swapvec;
vec256_init(G0,G1,in);
"""
# F0/F1 become compile-time constants holding the bit-planes f0/f1 computed
# at the top of the script.  words[w] packs bits 32*w .. 32*w+31 of the
# 256-bit chunk; _mm256_set_epi32 lists its arguments from the most
# significant 32-bit word down, hence reversed(words).
for name,bits in (('F0',f0),('F1',f1)):
  for j in range(len(bits)//256):
    chunk = bits[256*j:256*j+256]
    words = [sum(bit<<b for b,bit in enumerate(chunk[w:w+32])) for w in range(0,256,32)]
    if any(words):
      # _mm256_set_epi32 takes int arguments: write words >= 2^31 as
      # their negative two's-complement value
      args = ','.join(str(w-2**32 if w >= 2**31 else w) for w in reversed(words))
      out += ' %s[%d] = _mm256_set_epi32(%s);\n' % (name,j,args)
    else:
      out += ' %s[%d] = _mm256_set1_epi32(0);\n' % (name,j)
out += '\n'
# V starts at zero.
for i in range(numvec):
  out += ' V0[%d] = _mm256_set1_epi32(0);\n' % i
  out += ' V1[%d] = _mm256_set1_epi32(0);\n' % i
out += '\n'
# R starts with only bit 0 of R0[0] set (packed coefficient 0 equal to 1).
for i in range(numvec):
  r0 = '_mm256_set_epi32(0,0,0,0,0,0,0,1)' if i == 0 else '_mm256_set1_epi32(0)'
  out += ' R0[%d] = %s;\n' % (i,r0)
  out += ' R1[%d] = _mm256_set1_epi32(0);\n' % i
# ---------- reciprocal main loop
# The C code runs 2p-1 iterations.  Iteration `step` works on the first vvec
# vectors of V/R and the first gvec vectors of F/G, with vvec growing and gvec
# shrinking as step advances.  Consecutive iterations sharing the same
# (vvec,gvec) pair are emitted as a single counted C loop.
runs = []
for step in range(2*p-1):
  sizes = (min(numvec,1+step//256),min(numvec,1+(2*p-2-step)//256))
  if runs and runs[-1][0] == sizes:
    runs[-1][1] += 1
  else:
    runs.append([sizes,1])
for (vvec,gvec),loops in runs:
  out += """
for (loop = %d;loop > 0;--loop) {
vec256_timesx_%d(V0);
vec256_timesx_%d(V1);
swapmask = crypto_int32_negative_mask(minusdelta) & vec256_bit0mask(G0);
c0 = vec256_bit0mask(F0) & vec256_bit0mask(G0);
c1 = vec256_bit0mask(F1) ^ vec256_bit0mask(G1);
c1 &= c0;
minusdelta ^= swapmask & (minusdelta ^ -minusdelta);
minusdelta -= 1;
swapvec = _mm256_set1_epi32(swapmask);
vec256_swap(F0,G0,%d,swapvec);
vec256_swap(F1,G1,%d,swapvec);
c0vec = _mm256_set1_epi32(c0);
c1vec = _mm256_set1_epi32(c1);
vec256_eliminate(F0,F1,G0,G1,%d,c0vec,c1vec);
vec256_divx_%d(G0);
vec256_divx_%d(G1);
vec256_swap(V0,R0,%d,swapvec);
vec256_swap(V1,R1,%d,swapvec);
vec256_eliminate(V0,V1,R0,R1,%d,c0vec,c1vec);
}
""" % (loops,vvec,vvec,gvec,gvec,gvec,gvec,gvec,vvec,vvec,vvec)
# ---------- reciprocal finalization
out += """
c0vec = _mm256_set1_epi32(vec256_bit0mask(F0));
c1vec = _mm256_set1_epi32(vec256_bit0mask(F1));
vec256_scale(V0,V1,c0vec,c1vec);
vec256_final(out,V0,V1);
out[p] = crypto_int32_negative_mask(minusdelta);
}\
"""
# Substitute the numeric parameters for the placeholder names.  The patterns
# match whole words only, so the P rule leaves PPAD intact.  Then print the
# finished C file.
substitutions = (
  (r'\bP\b',p),
  (r'\bPPAD\b',ppad),
  (r'\bNUMVEC\b',numvec),
)
for pattern,value in substitutions:
  out = re.sub(pattern,str(value),out)
print(out)