-rw-r--r-- 12480 libntruprime-20240910/src/core/mult3sntrupP/avx/mult1280.c raw
// 20240812 djb: more cryptoint usage
#include "ntt.h"
#include <immintrin.h>
#include "crypto_int8.h"
typedef int8_t int8;
typedef int16_t int16;
#define int16x16 __m256i
#define load_x16(p) _mm256_loadu_si256((int16x16 *) (p))
#define store_x16(p,v) _mm256_storeu_si256((int16x16 *) (p),(v))
#define const_x16 _mm256_set1_epi16
#define add_x16 _mm256_add_epi16
#define sub_x16 _mm256_sub_epi16
#define mullo_x16 _mm256_mullo_epi16
#define mulhi_x16 _mm256_mulhi_epi16
#define mulhrs_x16 _mm256_mulhrs_epi16
#define signmask_x16(x) _mm256_srai_epi16((x),15)
static int16x16 squeeze_3_x16(int16x16 x)
{
return sub_x16(x,mullo_x16(mulhrs_x16(x,const_x16(10923)),const_x16(3)));
}
static int16x16 squeeze_7681_x16(int16x16 x)
{
return sub_x16(x,mullo_x16(mulhrs_x16(x,const_x16(4)),const_x16(7681)));
}
static int16x16 mulmod_7681_x16(int16x16 x,int16x16 y)
{
int16x16 yqinv = mullo_x16(y,const_x16(-7679)); /* XXX: precompute */
int16x16 b = mulhi_x16(x,y);
int16x16 d = mullo_x16(x,yqinv);
int16x16 e = mulhi_x16(d,const_x16(7681));
return sub_x16(b,e);
}
#define mask0 _mm256_set_epi16(-1,0,0,0,0,-1,0,0,0,0,-1,0,0,0,0,-1)
#define mask1 _mm256_set_epi16(0,0,-1,0,0,0,0,-1,0,0,0,0,-1,0,0,0)
#define mask2 _mm256_set_epi16(0,0,0,0,-1,0,0,0,0,-1,0,0,0,0,-1,0)
#define mask3 _mm256_set_epi16(0,-1,0,0,0,0,-1,0,0,0,0,-1,0,0,0,0)
#define mask4 _mm256_set_epi16(0,0,0,-1,0,0,0,0,-1,0,0,0,0,-1,0,0)
static void good(int16 fpad[5][512],const int16 f[1280])
{
int j;
int16x16 f0,f1,f2;
j = 0;
for (;;) {
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
f2 = load_x16(f+1024+j);
store_x16(&fpad[0][j],(f0&mask0)|(f1&mask1)|(f2&mask2));
store_x16(&fpad[1][j],(f0&mask1)|(f1&mask2)|(f2&mask3));
store_x16(&fpad[2][j],(f0&mask2)|(f1&mask3)|(f2&mask4));
store_x16(&fpad[3][j],(f0&mask3)|(f1&mask4)|(f2&mask0));
store_x16(&fpad[4][j],(f0&mask4)|(f1&mask0)|(f2&mask1));
j += 16;
if (j == 256) break;
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
f2 = load_x16(f+1024+j);
store_x16(&fpad[0][j],(f0&mask3)|(f1&mask4)|(f2&mask0));
store_x16(&fpad[1][j],(f0&mask4)|(f1&mask0)|(f2&mask1));
store_x16(&fpad[2][j],(f0&mask0)|(f1&mask1)|(f2&mask2));
store_x16(&fpad[3][j],(f0&mask1)|(f1&mask2)|(f2&mask3));
store_x16(&fpad[4][j],(f0&mask2)|(f1&mask3)|(f2&mask4));
j += 16;
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
f2 = load_x16(f+1024+j);
store_x16(&fpad[0][j],(f0&mask1)|(f1&mask2)|(f2&mask3));
store_x16(&fpad[1][j],(f0&mask2)|(f1&mask3)|(f2&mask4));
store_x16(&fpad[2][j],(f0&mask3)|(f1&mask4)|(f2&mask0));
store_x16(&fpad[3][j],(f0&mask4)|(f1&mask0)|(f2&mask1));
store_x16(&fpad[4][j],(f0&mask0)|(f1&mask1)|(f2&mask2));
j += 16;
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
f2 = load_x16(f+1024+j);
store_x16(&fpad[0][j],(f0&mask4)|(f1&mask0)|(f2&mask1));
store_x16(&fpad[1][j],(f0&mask0)|(f1&mask1)|(f2&mask2));
store_x16(&fpad[2][j],(f0&mask1)|(f1&mask2)|(f2&mask3));
store_x16(&fpad[3][j],(f0&mask2)|(f1&mask3)|(f2&mask4));
store_x16(&fpad[4][j],(f0&mask3)|(f1&mask4)|(f2&mask0));
j += 16;
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
f2 = load_x16(f+1024+j);
store_x16(&fpad[0][j],(f0&mask2)|(f1&mask3)|(f2&mask4));
store_x16(&fpad[1][j],(f0&mask3)|(f1&mask4)|(f2&mask0));
store_x16(&fpad[2][j],(f0&mask4)|(f1&mask0)|(f2&mask1));
store_x16(&fpad[3][j],(f0&mask0)|(f1&mask1)|(f2&mask2));
store_x16(&fpad[4][j],(f0&mask1)|(f1&mask2)|(f2&mask3));
j += 16;
}
for (;;) {
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
store_x16(&fpad[0][j],(f0&mask3)|(f1&mask4));
store_x16(&fpad[1][j],(f0&mask4)|(f1&mask0));
store_x16(&fpad[2][j],(f0&mask0)|(f1&mask1));
store_x16(&fpad[3][j],(f0&mask1)|(f1&mask2));
store_x16(&fpad[4][j],(f0&mask2)|(f1&mask3));
j += 16;
if (j == 512) break;
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
store_x16(&fpad[0][j],(f0&mask1)|(f1&mask2));
store_x16(&fpad[1][j],(f0&mask2)|(f1&mask3));
store_x16(&fpad[2][j],(f0&mask3)|(f1&mask4));
store_x16(&fpad[3][j],(f0&mask4)|(f1&mask0));
store_x16(&fpad[4][j],(f0&mask0)|(f1&mask1));
j += 16;
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
store_x16(&fpad[0][j],(f0&mask4)|(f1&mask0));
store_x16(&fpad[1][j],(f0&mask0)|(f1&mask1));
store_x16(&fpad[2][j],(f0&mask1)|(f1&mask2));
store_x16(&fpad[3][j],(f0&mask2)|(f1&mask3));
store_x16(&fpad[4][j],(f0&mask3)|(f1&mask4));
j += 16;
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
store_x16(&fpad[0][j],(f0&mask2)|(f1&mask3));
store_x16(&fpad[1][j],(f0&mask3)|(f1&mask4));
store_x16(&fpad[2][j],(f0&mask4)|(f1&mask0));
store_x16(&fpad[3][j],(f0&mask0)|(f1&mask1));
store_x16(&fpad[4][j],(f0&mask1)|(f1&mask2));
j += 16;
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
store_x16(&fpad[0][j],(f0&mask0)|(f1&mask1));
store_x16(&fpad[1][j],(f0&mask1)|(f1&mask2));
store_x16(&fpad[2][j],(f0&mask2)|(f1&mask3));
store_x16(&fpad[3][j],(f0&mask3)|(f1&mask4));
store_x16(&fpad[4][j],(f0&mask4)|(f1&mask0));
j += 16;
}
}
static void ungood(int16 f[2560],const int16 fpad[5][512])
{
int j;
int16x16 f0,f1,f2,f3,f4,g0,g1,g2,g3,g4;
j = 0;
for (;;) {
f0 = load_x16(&fpad[0][j]);
f1 = load_x16(&fpad[1][j]);
f2 = load_x16(&fpad[2][j]);
f3 = load_x16(&fpad[3][j]);
f4 = load_x16(&fpad[4][j]);
g0 = (f0&mask0)|(f1&mask1)|(f2&mask2)|(f3&mask3)|(f4&mask4);
g1 = (f0&mask1)|(f1&mask2)|(f2&mask3)|(f3&mask4)|(f4&mask0);
g2 = (f0&mask2)|(f1&mask3)|(f2&mask4)|(f3&mask0)|(f4&mask1);
g3 = (f0&mask3)|(f1&mask4)|(f2&mask0)|(f3&mask1)|(f4&mask2);
g4 = f0^f1^f2^f3^f4^g0^g1^g2^g3; /* same as (f0&mask4)|(f1&mask0)|(f2&mask1)|(f3&mask2)|(f4&mask3); */
store_x16(f+0+j,g0);
store_x16(f+512+j,g1);
store_x16(f+1024+j,g2);
store_x16(f+1536+j,g3);
store_x16(f+2048+j,g4);
j += 16;
f0 = load_x16(&fpad[0][j]);
f1 = load_x16(&fpad[1][j]);
f2 = load_x16(&fpad[2][j]);
f3 = load_x16(&fpad[3][j]);
f4 = load_x16(&fpad[4][j]);
g0 = (f0&mask3)|(f1&mask4)|(f2&mask0)|(f3&mask1)|(f4&mask2);
g1 = (f0&mask4)|(f1&mask0)|(f2&mask1)|(f3&mask2)|(f4&mask3);
g2 = (f0&mask0)|(f1&mask1)|(f2&mask2)|(f3&mask3)|(f4&mask4);
g3 = (f0&mask1)|(f1&mask2)|(f2&mask3)|(f3&mask4)|(f4&mask0);
g4 = f0^f1^f2^f3^f4^g0^g1^g2^g3;
store_x16(f+0+j,g0);
store_x16(f+512+j,g1);
store_x16(f+1024+j,g2);
store_x16(f+1536+j,g3);
store_x16(f+2048+j,g4);
j += 16;
if (j == 512) break;
f0 = load_x16(&fpad[0][j]);
f1 = load_x16(&fpad[1][j]);
f2 = load_x16(&fpad[2][j]);
f3 = load_x16(&fpad[3][j]);
f4 = load_x16(&fpad[4][j]);
g0 = (f0&mask1)|(f1&mask2)|(f2&mask3)|(f3&mask4)|(f4&mask0);
g1 = (f0&mask2)|(f1&mask3)|(f2&mask4)|(f3&mask0)|(f4&mask1);
g2 = (f0&mask3)|(f1&mask4)|(f2&mask0)|(f3&mask1)|(f4&mask2);
g3 = (f0&mask4)|(f1&mask0)|(f2&mask1)|(f3&mask2)|(f4&mask3);
g4 = f0^f1^f2^f3^f4^g0^g1^g2^g3;
store_x16(f+0+j,g0);
store_x16(f+512+j,g1);
store_x16(f+1024+j,g2);
store_x16(f+1536+j,g3);
store_x16(f+2048+j,g4);
j += 16;
f0 = load_x16(&fpad[0][j]);
f1 = load_x16(&fpad[1][j]);
f2 = load_x16(&fpad[2][j]);
f3 = load_x16(&fpad[3][j]);
f4 = load_x16(&fpad[4][j]);
g0 = (f0&mask4)|(f1&mask0)|(f2&mask1)|(f3&mask2)|(f4&mask3);
g1 = (f0&mask0)|(f1&mask1)|(f2&mask2)|(f3&mask3)|(f4&mask4);
g2 = (f0&mask1)|(f1&mask2)|(f2&mask3)|(f3&mask4)|(f4&mask0);
g3 = (f0&mask2)|(f1&mask3)|(f2&mask4)|(f3&mask0)|(f4&mask1);
g4 = f0^f1^f2^f3^f4^g0^g1^g2^g3;
store_x16(f+0+j,g0);
store_x16(f+512+j,g1);
store_x16(f+1024+j,g2);
store_x16(f+1536+j,g3);
store_x16(f+2048+j,g4);
j += 16;
f0 = load_x16(&fpad[0][j]);
f1 = load_x16(&fpad[1][j]);
f2 = load_x16(&fpad[2][j]);
f3 = load_x16(&fpad[3][j]);
f4 = load_x16(&fpad[4][j]);
g0 = (f0&mask2)|(f1&mask3)|(f2&mask4)|(f3&mask0)|(f4&mask1);
g1 = (f0&mask3)|(f1&mask4)|(f2&mask0)|(f3&mask1)|(f4&mask2);
g2 = (f0&mask4)|(f1&mask0)|(f2&mask1)|(f3&mask2)|(f4&mask3);
g3 = (f0&mask0)|(f1&mask1)|(f2&mask2)|(f3&mask3)|(f4&mask4);
g4 = f0^f1^f2^f3^f4^g0^g1^g2^g3;
store_x16(f+0+j,g0);
store_x16(f+512+j,g1);
store_x16(f+1024+j,g2);
store_x16(f+1536+j,g3);
store_x16(f+2048+j,g4);
j += 16;
}
}
#define ALIGNED __attribute((aligned(32)))
static void mult1280(int16 h[2560],const int16 f[1280],const int16 g[1280])
{
ALIGNED int16 fgpad[10][512];
#define fpad fgpad
#define gpad (fgpad+5)
#define hpad fpad
ALIGNED int16 h_7681[2560];
int i;
good(fpad,f);
good(gpad,g);
ntt512_7681(fgpad[0],10);
for (i = 0;i < 512;i += 16) {
int16x16 f0 = squeeze_7681_x16(load_x16(&fpad[0][i]));
int16x16 f1 = squeeze_7681_x16(load_x16(&fpad[1][i]));
int16x16 f2 = squeeze_7681_x16(load_x16(&fpad[2][i]));
int16x16 f3 = squeeze_7681_x16(load_x16(&fpad[3][i]));
int16x16 f4 = squeeze_7681_x16(load_x16(&fpad[4][i]));
int16x16 g0 = squeeze_7681_x16(load_x16(&gpad[0][i]));
int16x16 g1 = squeeze_7681_x16(load_x16(&gpad[1][i]));
int16x16 g2 = squeeze_7681_x16(load_x16(&gpad[2][i]));
int16x16 g3 = squeeze_7681_x16(load_x16(&gpad[3][i]));
int16x16 g4 = squeeze_7681_x16(load_x16(&gpad[4][i]));
int16x16 d0 = mulmod_7681_x16(f0,g0);
int16x16 d1 = mulmod_7681_x16(f1,g1);
int16x16 d2 = mulmod_7681_x16(f2,g2);
int16x16 d3 = mulmod_7681_x16(f3,g3);
int16x16 d4 = mulmod_7681_x16(f4,g4);
int16x16 dsum;
int16x16 h0,h1,h2,h3,h4;
dsum = add_x16(add_x16(d0,d1),d2);
dsum = squeeze_7681_x16(dsum);
dsum = add_x16(add_x16(d3,d4),dsum);
dsum = squeeze_7681_x16(dsum);
/* could save one mult here using sum hi = (sum fi)(sum gi) */
h0 = add_x16(dsum,mulmod_7681_x16(sub_x16(f4,f1),sub_x16(g1,g4)));
h1 = add_x16(dsum,mulmod_7681_x16(sub_x16(f1,f0),sub_x16(g0,g1)));
h2 = add_x16(dsum,mulmod_7681_x16(sub_x16(f2,f0),sub_x16(g0,g2)));
h3 = add_x16(dsum,mulmod_7681_x16(sub_x16(f3,f0),sub_x16(g0,g3)));
h4 = add_x16(dsum,mulmod_7681_x16(sub_x16(f4,f0),sub_x16(g0,g4)));
h0 = add_x16(h0,mulmod_7681_x16(sub_x16(f3,f2),sub_x16(g2,g3)));
h1 = add_x16(h1,mulmod_7681_x16(sub_x16(f4,f2),sub_x16(g2,g4)));
h2 = add_x16(h2,mulmod_7681_x16(sub_x16(f4,f3),sub_x16(g3,g4)));
h3 = add_x16(h3,mulmod_7681_x16(sub_x16(f2,f1),sub_x16(g1,g2)));
h4 = add_x16(h4,mulmod_7681_x16(sub_x16(f3,f1),sub_x16(g1,g3)));
store_x16(&hpad[0][i],squeeze_7681_x16(h0));
store_x16(&hpad[1][i],squeeze_7681_x16(h1));
store_x16(&hpad[2][i],squeeze_7681_x16(h2));
store_x16(&hpad[3][i],squeeze_7681_x16(h3));
store_x16(&hpad[4][i],squeeze_7681_x16(h4));
}
invntt512_7681(hpad[0],5);
ungood(h_7681,hpad);
for (i = 0;i < 2560;i += 16) {
int16x16 u = load_x16(&h_7681[i]);
u = mulmod_7681_x16(u,const_x16(956));
store_x16(&h[i],u);
}
}
#include "crypto_core.h"
#include "crypto_decode_{P}xint16.h"
#define crypto_decode_pxint16 crypto_decode_{P}xint16
#include "crypto_encode_{P}xint16.h"
#define crypto_encode_pxint16 crypto_encode_{P}xint16
#define p {P}
static inline int16x16 freeze_3_x16(int16x16 x)
{
int16x16 mask, x3;
x = add_x16(x,const_x16(3)&signmask_x16(x));
mask = signmask_x16(sub_x16(x,const_x16(2)));
x3 = sub_x16(x,const_x16(3));
x = _mm256_blendv_epi8(x3,x,mask);
return x;
}
void crypto_core(unsigned char *outbytes,const unsigned char *inbytes,const unsigned char *kbytes,const unsigned char *cbytes)
{
ALIGNED int16 f[1280];
ALIGNED int16 g[1280];
ALIGNED int16 fg[2560];
#define h f
int i;
int16x16 x;
x = const_x16(0);
for (i = p&~15;i < 1280;i += 16) store_x16(&f[i],x);
for (i = p&~15;i < 1280;i += 16) store_x16(&g[i],x);
for (i = 0;i < p;++i) {
int8 fi = inbytes[i];
int8 fi0 = crypto_int8_bottombit_01(fi);
f[i] = fi0-(fi&(fi0<<1));
}
for (i = 0;i < p;++i) {
int8 gi = kbytes[i];
int8 gi0 = crypto_int8_bottombit_01(gi);
g[i] = gi0-(gi&(gi0<<1));
}
mult1280(fg,f,g);
fg[0] -= fg[p-1];
for (i = 0;i < 1280;i += 16) {
int16x16 fgi = load_x16(&fg[i]);
int16x16 fgip = load_x16(&fg[i + p]);
int16x16 fgip1 = load_x16(&fg[i + p - 1]);
x = add_x16(fgi,add_x16(fgip,fgip1));
x = freeze_3_x16(squeeze_3_x16(x));
store_x16(&h[i],x);
}
for (i = 0;i < p;++i) outbytes[i] = h[i];
}