-rw-r--r-- 8795 libntruprime-20240910/crypto_core/multsntrup653/avx/mult768.c raw
// 20240812 djb: more cryptoint usage

/*
 * Multiplication of a 653-coefficient polynomial with int16 coefficients
 * (reduced mod 4621) by a small polynomial with coefficients in {-1,0,1},
 * using AVX2.  Inputs are zero-padded to 768 coefficients, multiplied via
 * 512-point NTTs modulo the two primes 7681 and 10753 on three interleaved
 * sub-arrays, recombined modulo 4621, and finally folded back to fewer than
 * p coefficients in crypto_core().
 */

#include "ntt.h"
#include <immintrin.h>
#include "crypto_int8.h"

typedef int8_t int8;
typedef int16_t int16;

/* Short aliases for the AVX2 intrinsics operating on 16 lanes of int16. */
#define int16x16 __m256i
/* Unaligned 256-bit load/store. */
#define load_x16(p) _mm256_loadu_si256((int16x16 *) (p))
#define store_x16(p,v) _mm256_storeu_si256((int16x16 *) (p),(v))
#define const_x16 _mm256_set1_epi16
#define add_x16 _mm256_add_epi16
#define sub_x16 _mm256_sub_epi16
#define mullo_x16 _mm256_mullo_epi16
#define mulhi_x16 _mm256_mulhi_epi16
#define mulhrs_x16 _mm256_mulhrs_epi16
/* Arithmetic shift by 15: all-ones in lanes where x is negative, else 0. */
#define signmask_x16(x) _mm256_srai_epi16((x),15)

/*
 * squeeze_Q_x16: per lane, subtract a multiple of Q from x.  The multiple is
 * mulhrs(x,c), i.e. x*c/2^15 rounded, where c is chosen so that c/2^15 is
 * roughly 1/Q (7/32768 for 4621, 4/32768 for 7681, 3/32768 for 10753).
 * The result is congruent to x modulo Q and has small magnitude; it is not
 * necessarily the canonical representative.
 */
static inline int16x16 squeeze_4621_x16(int16x16 x)
{
  return sub_x16(x,mullo_x16(mulhrs_x16(x,const_x16(7)),const_x16(4621)));
}

static inline int16x16 squeeze_7681_x16(int16x16 x)
{
  return sub_x16(x,mullo_x16(mulhrs_x16(x,const_x16(4)),const_x16(7681)));
}

static inline int16x16 squeeze_10753_x16(int16x16 x)
{
  return sub_x16(x,mullo_x16(mulhrs_x16(x,const_x16(3)),const_x16(10753)));
}

/*
 * mulmod_Q_x16: per lane, signed Montgomery-style product.
 *   yqinv = y * Q^{-1} mod 2^16   (the constant is the inverse of Q mod 2^16)
 *   b     = high 16 bits of x*y
 *   d     = low 16 bits of x*yqinv, so d*Q has the same low 16 bits as x*y
 *   e     = high 16 bits of d*Q
 * b - e therefore equals (x*y - d*Q) / 2^16, which is congruent to
 * x*y*2^{-16} modulo Q.
 */
static inline int16x16 mulmod_4621_x16(int16x16 x,int16x16 y)
{
  int16x16 yqinv = mullo_x16(y,const_x16(-29499)); /* XXX: precompute */
  int16x16 b = mulhi_x16(x,y);
  int16x16 d = mullo_x16(x,yqinv);
  int16x16 e = mulhi_x16(d,const_x16(4621));
  return sub_x16(b,e);
}

static inline int16x16 mulmod_7681_x16(int16x16 x,int16x16 y)
{
  int16x16 yqinv = mullo_x16(y,const_x16(-7679)); /* XXX: precompute */
  int16x16 b = mulhi_x16(x,y);
  int16x16 d = mullo_x16(x,yqinv);
  int16x16 e = mulhi_x16(d,const_x16(7681));
  return sub_x16(b,e);
}

static inline int16x16 mulmod_10753_x16(int16x16 x,int16x16 y)
{
  int16x16 yqinv = mullo_x16(y,const_x16(-10751)); /* XXX: precompute */
  int16x16 b = mulhi_x16(x,y);
  int16x16 d = mullo_x16(x,yqinv);
  int16x16 e = mulhi_x16(d,const_x16(10753));
  return sub_x16(b,e);
}

/*
 * Lane-selection masks.  _mm256_set_epi16 lists lanes from 15 down to 0, so
 * maskR has all-ones exactly in the lanes whose index is congruent to R mod 3.
 */
#define mask0 _mm256_set_epi16(-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1)
#define mask1 _mm256_set_epi16(0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0)
#define mask2 _mm256_set_epi16(0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0)

/*
 * Scatter the 768 coefficients of f into three 512-entry arrays:
 * f[i] is placed at fpad[i mod 3][i mod 512]; every other entry of fpad is
 * written as 0.  Because 16 is congruent to 1 mod 3, the mask assignment
 * rotates with period three 16-lane chunks, hence the three-way unrolling.
 */
static void good(int16 fpad[3][512],const int16 f[768])
{
  int j;
  int16x16 f0,f1;

  j = 0;
  /* j in [0,256): both f[j..] and f[512+j..] contribute to column j. */
  for (;;) {
    f0 = load_x16(f+j);
    f1 = load_x16(f+512+j);
    store_x16(&fpad[0][j],(f0&mask0)|(f1&mask1));
    store_x16(&fpad[1][j],(f0&mask1)|(f1&mask2));
    store_x16(&fpad[2][j],(f0&mask2)|(f1&mask0));
    j += 16;
    if (j == 256) break; /* 16 chunks = 3*5+1, so exit after the first of a triple */
    f0 = load_x16(f+j);
    f1 = load_x16(f+512+j);
    store_x16(&fpad[0][j],(f0&mask2)|(f1&mask0));
    store_x16(&fpad[1][j],(f0&mask0)|(f1&mask1));
    store_x16(&fpad[2][j],(f0&mask1)|(f1&mask2));
    j += 16;
    f0 = load_x16(f+j);
    f1 = load_x16(f+512+j);
    store_x16(&fpad[0][j],(f0&mask1)|(f1&mask2));
    store_x16(&fpad[1][j],(f0&mask2)|(f1&mask0));
    store_x16(&fpad[2][j],(f0&mask0)|(f1&mask1));
    j += 16;
  }
  /* j in [256,512): only f[j..] exists (f has 768 entries), so one source. */
  for (;;) {
    f0 = load_x16(f+j);
    store_x16(&fpad[0][j],f0&mask2);
    store_x16(&fpad[1][j],f0&mask0);
    store_x16(&fpad[2][j],f0&mask1);
    j += 16;
    if (j == 512) break;
    f0 = load_x16(f+j);
    store_x16(&fpad[0][j],f0&mask1);
    store_x16(&fpad[1][j],f0&mask2);
    store_x16(&fpad[2][j],f0&mask0);
    j += 16;
    f0 = load_x16(f+j);
    store_x16(&fpad[0][j],f0&mask0);
    store_x16(&fpad[1][j],f0&mask1);
    store_x16(&fpad[2][j],f0&mask2);
    j += 16;
  }
}

/*
 * Inverse of the index map used by good(), extended to 1536 outputs:
 * f[i] = fpad[i mod 3][i mod 512] for 0 <= i < 1536.
 * The third output vector of each chunk is obtained by XOR: g0 and g1 each
 * take a disjoint third of the lanes of f0,f1,f2, so XORing them out of
 * f0^f1^f2 leaves exactly the remaining selection.
 */
static void ungood(int16 f[1536],const int16 fpad[3][512])
{
  int j;
  int16x16 f0,f1,f2,g0,g1,g2;

  j = 0;
  for (;;) {
    f0 = load_x16(&fpad[0][j]);
    f1 = load_x16(&fpad[1][j]);
    f2 = load_x16(&fpad[2][j]);
    g0 = (f0&mask0)|(f1&mask1)|(f2&mask2);
    g1 = (f0&mask1)|(f1&mask2)|(f2&mask0);
    g2 = f0^f1^f2^g0^g1; /* same as (f0&mask2)|(f1&mask0)|(f2&mask1) */
    store_x16(f+0+j,g0);
    store_x16(f+512+j,g1);
    store_x16(f+1024+j,g2);
    j += 16;

    f0 = load_x16(&fpad[0][j]);
    f1 = load_x16(&fpad[1][j]);
    f2 = load_x16(&fpad[2][j]);
    g0 = (f0&mask2)|(f1&mask0)|(f2&mask1);
    g1 = (f0&mask0)|(f1&mask1)|(f2&mask2);
    g2 = f0^f1^f2^g0^g1; /* same as (f0&mask1)|(f1&mask2)|(f2&mask0) */
    store_x16(f+0+j,g0);
    store_x16(f+512+j,g1);
    store_x16(f+1024+j,g2);
    j += 16;
    if (j == 512) break; /* 32 chunks = 3*10+2, so exit after the second of a triple */

    f0 = load_x16(&fpad[0][j]);
    f1 = load_x16(&fpad[1][j]);
    f2 = load_x16(&fpad[2][j]);
    g0 = (f0&mask1)|(f1&mask2)|(f2&mask0);
    g1 = (f0&mask2)|(f1&mask0)|(f2&mask1);
    g2 = f0^f1^f2^g0^g1; /* same as (f0&mask0)|(f1&mask1)|(f2&mask2) */
    store_x16(f+0+j,g0);
    store_x16(f+512+j,g1);
    store_x16(f+1024+j,g2);
    j += 16;
  }
}

/* Requests 512-byte alignment for the work buffers below. */
#define ALIGNED __attribute((aligned(512)))

/*
 * h = f * g as polynomials: 768 input coefficients each, 1536 output
 * coefficients.  The product is computed twice, once through
 * ntt512_7681/invntt512_7681 and once through ntt512_10753/invntt512_10753,
 * and the two results are combined lane by lane in the final loop.
 *
 * NOTE(review): the ntt512_* and invntt512_* routines are declared in ntt.h, which is
 * not visible here; the second argument is presumably the number of
 * consecutive 512-entry arrays to transform (6 = fpad and gpad together,
 * 3 = hpad) -- confirm against ntt.h.
 */
static void mult768(int16 h[1536],const int16 f[768],const int16 g[768])
{
  ALIGNED int16 fgpad[6][512];
  /* fpad = rows 0..2, gpad = rows 3..5; hpad reuses fpad, so the pointwise
     products overwrite the transformed f in place. */
#define fpad fgpad
#define gpad (fgpad+3)
#define hpad fpad
  ALIGNED int16 h_7681[1536];
  ALIGNED int16 h_10753[1536];
  int i;

  /* ---- pass 1: arithmetic modulo 7681 ---- */
  good(fpad,f);
  good(gpad,g);

  ntt512_7681(fgpad[0],6);

  for (i = 0;i < 512;i += 16) {
    int16x16 f0 = squeeze_7681_x16(load_x16(&fpad[0][i]));
    int16x16 f1 = squeeze_7681_x16(load_x16(&fpad[1][i]));
    int16x16 f2 = squeeze_7681_x16(load_x16(&fpad[2][i]));
    int16x16 g0 = squeeze_7681_x16(load_x16(&gpad[0][i]));
    int16x16 g1 = squeeze_7681_x16(load_x16(&gpad[1][i]));
    int16x16 g2 = squeeze_7681_x16(load_x16(&gpad[2][i]));
    int16x16 d0 = mulmod_7681_x16(f0,g0);
    int16x16 d1 = mulmod_7681_x16(f1,g1);
    int16x16 d2 = mulmod_7681_x16(f2,g2);
    int16x16 dsum = add_x16(add_x16(d0,d1),d2);
    /* Length-3 cyclic convolution of (f0,f1,f2) with (g0,g1,g2) using six
       multiplications instead of nine:
         h0 = f0g0 + f1g2 + f2g1 = dsum + (f2-f1)(g1-g2)
         h1 = f0g1 + f1g0 + f2g2 = dsum + (f1-f0)(g0-g1)
         h2 = f0g2 + f1g1 + f2g0 = dsum + (f0-f2)(g2-g0) */
    int16x16 h0 = add_x16(dsum,mulmod_7681_x16(sub_x16(f2,f1),sub_x16(g1,g2)));
    int16x16 h1 = add_x16(dsum,mulmod_7681_x16(sub_x16(f1,f0),sub_x16(g0,g1)));
    int16x16 h2 = add_x16(dsum,mulmod_7681_x16(sub_x16(f0,f2),sub_x16(g2,g0)));
    store_x16(&hpad[0][i],squeeze_7681_x16(h0));
    store_x16(&hpad[1][i],squeeze_7681_x16(h1));
    store_x16(&hpad[2][i],squeeze_7681_x16(h2));
  }

  invntt512_7681(hpad[0],3);
  ungood(h_7681,hpad);

  /* ---- pass 2: arithmetic modulo 10753 ----
     fpad was overwritten by pass 1 (hpad aliases it), so the inputs are
     scattered again from f and g. */
  good(fpad,f);
  good(gpad,g);

  ntt512_10753(fgpad[0],6);

  for (i = 0;i < 512;i += 16) {
    int16x16 f0 = squeeze_10753_x16(load_x16(&fpad[0][i]));
    int16x16 f1 = squeeze_10753_x16(load_x16(&fpad[1][i]));
    int16x16 f2 = squeeze_10753_x16(load_x16(&fpad[2][i]));
    int16x16 g0 = squeeze_10753_x16(load_x16(&gpad[0][i]));
    int16x16 g1 = squeeze_10753_x16(load_x16(&gpad[1][i]));
    int16x16 g2 = squeeze_10753_x16(load_x16(&gpad[2][i]));
    int16x16 d0 = mulmod_10753_x16(f0,g0);
    int16x16 d1 = mulmod_10753_x16(f1,g1);
    int16x16 d2 = mulmod_10753_x16(f2,g2);
    int16x16 dsum = add_x16(add_x16(d0,d1),d2);
    /* Same six-multiplication length-3 cyclic convolution as in pass 1. */
    int16x16 h0 = add_x16(dsum,mulmod_10753_x16(sub_x16(f2,f1),sub_x16(g1,g2)));
    int16x16 h1 = add_x16(dsum,mulmod_10753_x16(sub_x16(f1,f0),sub_x16(g0,g1)));
    int16x16 h2 = add_x16(dsum,mulmod_10753_x16(sub_x16(f0,f2),sub_x16(g2,g0)));
    store_x16(&hpad[0][i],squeeze_10753_x16(h0));
    store_x16(&hpad[1][i],squeeze_10753_x16(h1));
    store_x16(&hpad[2][i],squeeze_10753_x16(h2));
  }

  invntt512_10753(hpad[0],3);
  ungood(h_10753,hpad);

  /* ---- combine the two residues into h ----
     u1, u2: rescaled residues mod 10753 and mod 7681.
     t = (u2-u1)*const mod 7681, then h = u1 + t*const (product taken via the
     mod-4621 Montgomery multiply).
     NOTE(review): the constants 1268, 956, -2539 and 1487 presumably fold in
     the NTT and Montgomery scaling factors together with the CRT
     coefficients; their derivation is not shown in this file.  The stored
     value is not reduced to a canonical range here; the caller applies
     squeeze/freeze mod 4621 afterwards. */
  for (i = 0;i < 1536;i += 16) {
    int16x16 u1 = load_x16(&h_10753[i]);
    int16x16 u2 = load_x16(&h_7681[i]);
    int16x16 t;
    u1 = mulmod_10753_x16(u1,const_x16(1268));
    u2 = mulmod_7681_x16(u2,const_x16(956));
    t = mulmod_7681_x16(sub_x16(u2,u1),const_x16(-2539));
    t = add_x16(u1,mulmod_4621_x16(t,const_x16(1487)));
    store_x16(&h[i],t);
  }
}

#include "crypto_core.h"
#include "crypto_decode_653xint16.h"
#define crypto_decode_pxint16 crypto_decode_653xint16
#include "crypto_encode_653xint16.h"
#define crypto_encode_pxint16 crypto_encode_653xint16

#define p 653
#define q 4621

/*
 * Per lane: add q where x is negative, then subtract q where the result is
 * at least (q+1)/2.  For an input strictly between -q and q this yields the
 * representative in [-(q-1)/2,(q-1)/2].  Branch-free: selection is done with
 * sign masks and _mm256_blendv_epi8 (which takes x where the mask's high bit
 * is set, xq otherwise).
 */
static inline int16x16 freeze_4621_x16(int16x16 x)
{
  int16x16 mask, xq;

  x = add_x16(x,const_x16(q)&signmask_x16(x));
  mask = signmask_x16(sub_x16(x,const_x16((q+1)/2))); /* all-ones where x < (q+1)/2 */
  xq = sub_x16(x,const_x16(q));
  x = _mm256_blendv_epi8(xq,x,mask);

  return x;
}

/*
 * Entry point: outbytes = encode( decode(inbytes) * small(kbytes) ), with
 * coefficients reduced mod q = 4621 and the polynomial folded using
 * x^p = x + 1 (p = 653).
 *
 *   outbytes  encoded result, written by crypto_encode_653xint16
 *   inbytes   encoded first factor, read by crypto_decode_653xint16
 *   kbytes    p bytes; each is mapped to a coefficient in {-1,0,1}
 *   cbytes    unused in this function
 *
 * Returns nothing.
 */
void crypto_core(unsigned char *outbytes,const unsigned char *inbytes,const unsigned char *kbytes,const unsigned char *cbytes)
{
  ALIGNED int16 f[768];
  ALIGNED int16 g[768];
  ALIGNED int16 fg[1536];
  /* The result h reuses f's storage. */
#define h f
  int i;
  int16x16 x;

  /* Zero the padding.  Starting at p&~15 (= 640) keeps the stores on 16-lane
     boundaries; entries 640..p-1 are overwritten below by the decode and by
     the g loop, leaving p..767 at zero. */
  x = const_x16(0);
  for (i = p&~15;i < 768;i += 16) store_x16(&f[i],x);
  for (i = p&~15;i < 768;i += 16) store_x16(&g[i],x);

  crypto_decode_pxint16(f,inbytes);

  /* Bring every coefficient of f into the centered range mod q. */
  for (i = 0;i < 768;i += 16) {
    x = load_x16(&f[i]);
    x = freeze_4621_x16(squeeze_4621_x16(x));
    store_x16(&f[i],x);
  }
  /* Map each key byte to {-1,0,1} without branching.  With gi0 the bottom
     bit as 0/1 (per the helper's name; declared in crypto_int8.h):
       bottom bit 0          -> 0
       bottom bit 1, bit1 0  -> 1
       bottom bit 1, bit1 1  -> 1 - 2 = -1 */
  for (i = 0;i < p;++i) {
    int8 gi = kbytes[i];
    int8 gi0 = crypto_int8_bottombit_01(gi);
    g[i] = gi0-(gi&(gi0<<1));
  }

  mult768(fg,f,g);

  /* Fold the product using x^p = x + 1: x^(p+i) contributes to x^i and
     x^(i+1), so h[i] = fg[i] + fg[i+p] + fg[i+p-1].  For i = 0 the term
     fg[p-1] is a genuine low-order coefficient and must not be added, so it
     is cancelled in advance by subtracting it from fg[0]. */
  fg[0] -= fg[p-1];
  for (i = 0;i < 768;i += 16) {
    int16x16 fgi = load_x16(&fg[i]);
    int16x16 fgip = load_x16(&fg[i + p]);
    int16x16 fgip1 = load_x16(&fg[i + p - 1]);
    x = add_x16(fgi,add_x16(fgip,fgip1));
    x = freeze_4621_x16(squeeze_4621_x16(x));
    store_x16(&h[i],x);
  }

  crypto_encode_pxint16(outbytes,h);
}