-rw-r--r-- 8795 libntruprime-20240825/crypto_core/multsntrup653/avx/mult768.c raw
// 20240812 djb: more cryptoint usage
#include "ntt.h"
#include <immintrin.h>
#include "crypto_int8.h"
/* Short scalar aliases used throughout this file. */
typedef int8_t int8;
typedef int16_t int16;
/* An AVX2 __m256i register treated as 16 lanes of int16, with thin wrappers
   around the matching AVX2 intrinsics. */
#define int16x16 __m256i
/* Unaligned 16-lane load/store (loadu/storeu: no alignment requirement on p). */
#define load_x16(p) _mm256_loadu_si256((int16x16 *) (p))
#define store_x16(p,v) _mm256_storeu_si256((int16x16 *) (p),(v))
/* Broadcast one int16 to all 16 lanes. */
#define const_x16 _mm256_set1_epi16
/* Lane-wise wrapping 16-bit add/sub. */
#define add_x16 _mm256_add_epi16
#define sub_x16 _mm256_sub_epi16
/* Low / high 16 bits of the signed 32-bit lane products. */
#define mullo_x16 _mm256_mullo_epi16
#define mulhi_x16 _mm256_mulhi_epi16
/* Rounded high multiply: round(a*b / 2^15) per lane. */
#define mulhrs_x16 _mm256_mulhrs_epi16
/* Each lane becomes -1 if it is negative, else 0 (arithmetic shift by 15). */
#define signmask_x16(x) _mm256_srai_epi16((x),15)
/* Partial reduction modulo 4621 in every lane: subtract 4621 times
   round(x*7/2^15), which approximates round(x/4621) since 2^15/4621 ~ 7.09.
   The result is congruent to x mod 4621 and closer to zero. */
static inline int16x16 squeeze_4621_x16(int16x16 x)
{
  int16x16 quotient = mulhrs_x16(x,const_x16(7));
  int16x16 multiple = mullo_x16(quotient,const_x16(4621));
  return sub_x16(x,multiple);
}
/* Partial reduction modulo 7681 in every lane: subtract 7681 times
   round(x*4/2^15), which approximates round(x/7681) since 2^15/7681 ~ 4.27.
   The result is congruent to x mod 7681 and closer to zero. */
static inline int16x16 squeeze_7681_x16(int16x16 x)
{
  int16x16 quotient = mulhrs_x16(x,const_x16(4));
  int16x16 multiple = mullo_x16(quotient,const_x16(7681));
  return sub_x16(x,multiple);
}
/* Partial reduction modulo 10753 in every lane: subtract 10753 times
   round(x*3/2^15), which approximates round(x/10753) since 2^15/10753 ~ 3.05.
   The result is congruent to x mod 10753 and closer to zero. */
static inline int16x16 squeeze_10753_x16(int16x16 x)
{
  int16x16 quotient = mulhrs_x16(x,const_x16(3));
  int16x16 multiple = mullo_x16(quotient,const_x16(10753));
  return sub_x16(x,multiple);
}
/* Lane-wise signed Montgomery multiplication modulo 4621.
   Returns a value congruent to x*y*2^-16 (mod 4621).
   -29499 is 4621^-1 mod 2^16, so m*4621 has the same low 16 bits as x*y.
   The two high halves therefore differ by an exact multiple of 2^16.
   The factor y*(-29499) is rebuilt on every call (XXX: precompute for
   fixed y). */
static inline int16x16 mulmod_4621_x16(int16x16 x,int16x16 y)
{
  int16x16 product_high = mulhi_x16(x,y);
  int16x16 m = mullo_x16(x,mullo_x16(y,const_x16(-29499)));
  int16x16 correction = mulhi_x16(m,const_x16(4621));
  return sub_x16(product_high,correction);
}
/* Lane-wise signed Montgomery multiplication modulo 7681.
   Returns a value congruent to x*y*2^-16 (mod 7681).
   -7679 is 7681^-1 mod 2^16, so m*7681 has the same low 16 bits as x*y.
   The two high halves therefore differ by an exact multiple of 2^16.
   The factor y*(-7679) is rebuilt on every call (XXX: precompute for
   fixed y). */
static inline int16x16 mulmod_7681_x16(int16x16 x,int16x16 y)
{
  int16x16 product_high = mulhi_x16(x,y);
  int16x16 m = mullo_x16(x,mullo_x16(y,const_x16(-7679)));
  int16x16 correction = mulhi_x16(m,const_x16(7681));
  return sub_x16(product_high,correction);
}
/* Lane-wise signed Montgomery multiplication modulo 10753.
   Returns a value congruent to x*y*2^-16 (mod 10753).
   -10751 is 10753^-1 mod 2^16, so m*10753 has the same low 16 bits as x*y.
   The two high halves therefore differ by an exact multiple of 2^16.
   The factor y*(-10751) is rebuilt on every call (XXX: precompute for
   fixed y). */
static inline int16x16 mulmod_10753_x16(int16x16 x,int16x16 y)
{
  int16x16 product_high = mulhi_x16(x,y);
  int16x16 m = mullo_x16(x,mullo_x16(y,const_x16(-10751)));
  int16x16 correction = mulhi_x16(m,const_x16(10753));
  return sub_x16(product_high,correction);
}
/* Lane-selection masks for the 3-way (mod 3) split used by good()/ungood().
   _mm256_set_epi16 lists lanes from 15 down to 0, so maskK is all-ones exactly
   in the lanes l with l%3 == K:
     mask0: lanes 0,3,6,9,12,15   mask1: lanes 1,4,7,10,13
     mask2: lanes 2,5,8,11,14 */
#define mask0 _mm256_set_epi16(-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1)
#define mask1 _mm256_set_epi16(0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0)
#define mask2 _mm256_set_epi16(0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0)
/*
 * good: Good's-trick scatter of a 768-coefficient polynomial into 3 rows of 512.
 * fpad[r][k] receives f[i] for the index i in [0,768) with i%3 == r and
 * i%512 == k, or 0 if there is no such i. Every entry of fpad is written.
 *   - For k < 256 the candidates are f[k] and f[512+k]. They land in different
 *     rows because 512%3 == 2.
 *   - For k >= 256 only f[k] exists (512+k >= 768).
 * Because 16%3 == 1, the lane-to-row assignment rotates every 16 entries.
 * Each loop is therefore unrolled over three 16-entry phases with fixed masks.
 */
static void good(int16 fpad[3][512],const int16 f[768])
{
int j;
int16x16 f0,f1;
j = 0;
/* Columns [0,256): combine f[j..j+15] with f[512+j..512+j+15].
   256 = 48*5 + 16, so the loop exits after the first phase of its 6th pass. */
for (;;) {
/* phase j%3 == 0 */
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
store_x16(&fpad[0][j],(f0&mask0)|(f1&mask1));
store_x16(&fpad[1][j],(f0&mask1)|(f1&mask2));
store_x16(&fpad[2][j],(f0&mask2)|(f1&mask0));
j += 16;
if (j == 256) break;
/* phase j%3 == 1 */
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
store_x16(&fpad[0][j],(f0&mask2)|(f1&mask0));
store_x16(&fpad[1][j],(f0&mask0)|(f1&mask1));
store_x16(&fpad[2][j],(f0&mask1)|(f1&mask2));
j += 16;
/* phase j%3 == 2 */
f0 = load_x16(f+j);
f1 = load_x16(f+512+j);
store_x16(&fpad[0][j],(f0&mask1)|(f1&mask2));
store_x16(&fpad[1][j],(f0&mask2)|(f1&mask0));
store_x16(&fpad[2][j],(f0&mask0)|(f1&mask1));
j += 16;
}
/* Columns [256,512): only f[j..j+15] contributes. The loop starts at
   j = 256 (256%3 == 1), so its phases run in the order 1, 2, 0.
   512 = 256 + 48*5 + 16, so it again exits after the first phase. */
for (;;) {
/* phase j%3 == 1 */
f0 = load_x16(f+j);
store_x16(&fpad[0][j],f0&mask2);
store_x16(&fpad[1][j],f0&mask0);
store_x16(&fpad[2][j],f0&mask1);
j += 16;
if (j == 512) break;
/* phase j%3 == 2 */
f0 = load_x16(f+j);
store_x16(&fpad[0][j],f0&mask1);
store_x16(&fpad[1][j],f0&mask2);
store_x16(&fpad[2][j],f0&mask0);
j += 16;
/* phase j%3 == 0 */
f0 = load_x16(f+j);
store_x16(&fpad[0][j],f0&mask0);
store_x16(&fpad[1][j],f0&mask1);
store_x16(&fpad[2][j],f0&mask2);
j += 16;
}
}
/*
 * ungood: inverse of the Good's-trick split for a 1536-coefficient result.
 * Sets f[i] = fpad[i%3][i%512] for every i in [0,1536).
 * For column j, f[j], f[512+j] and f[1024+j] come from rows j%3, (j+2)%3
 * and (j+1)%3 respectively (512%3 == 2, 1024%3 == 1).
 * The loop is unrolled over three 16-entry phases (j%3 == 0, 1, 2).
 * Since 512 = 48*10 + 32, it exits after the second phase of its 11th pass.
 */
static void ungood(int16 f[1536],const int16 fpad[3][512])
{
int j;
int16x16 f0,f1,f2,g0,g1,g2;
j = 0;
for (;;) {
/* phase j%3 == 0 */
f0 = load_x16(&fpad[0][j]);
f1 = load_x16(&fpad[1][j]);
f2 = load_x16(&fpad[2][j]);
g0 = (f0&mask0)|(f1&mask1)|(f2&mask2);
g1 = (f0&mask1)|(f1&mask2)|(f2&mask0);
/* In each lane g0 and g1 took values from two different rows, so XORing all
   three rows with g0 and g1 leaves the value of the remaining row. */
g2 = f0^f1^f2^g0^g1; /* same as (f0&mask2)|(f1&mask0)|(f2&mask1) */
store_x16(f+0+j,g0);
store_x16(f+512+j,g1);
store_x16(f+1024+j,g2);
j += 16;
/* phase j%3 == 1 */
f0 = load_x16(&fpad[0][j]);
f1 = load_x16(&fpad[1][j]);
f2 = load_x16(&fpad[2][j]);
g0 = (f0&mask2)|(f1&mask0)|(f2&mask1);
g1 = (f0&mask0)|(f1&mask1)|(f2&mask2);
g2 = f0^f1^f2^g0^g1; /* same as (f0&mask1)|(f1&mask2)|(f2&mask0) */
store_x16(f+0+j,g0);
store_x16(f+512+j,g1);
store_x16(f+1024+j,g2);
j += 16;
if (j == 512) break;
/* phase j%3 == 2 */
f0 = load_x16(&fpad[0][j]);
f1 = load_x16(&fpad[1][j]);
f2 = load_x16(&fpad[2][j]);
g0 = (f0&mask1)|(f1&mask2)|(f2&mask0);
g1 = (f0&mask2)|(f1&mask0)|(f2&mask1);
g2 = f0^f1^f2^g0^g1; /* same as (f0&mask0)|(f1&mask1)|(f2&mask2) */
store_x16(f+0+j,g0);
store_x16(f+512+j,g1);
store_x16(f+1024+j,g2);
j += 16;
}
}
/* GCC/Clang attribute requesting 512-byte alignment for the buffers declared
   with it. The load_x16/store_x16 wrappers are unaligned, so this is
   presumably a performance choice. NOTE(review): confirm whether the
   ntt512_* routines rely on it. */
#define ALIGNED __attribute((aligned(512)))
/*
 * mult768: product of two 768-coefficient int16 polynomials f and g.
 * All 1536 entries of h are written.
 *
 * Method:
 *   1. good() scatters coefficient i of f (and of g) to row i%3, column i%512
 *      of a 3x512 array. The 1536-coefficient product then splits into a
 *      length-3 cyclic part (rows) and a length-512 part (columns).
 *   2. ntt512_7681 is applied to the six rows (3 for f, 3 for g).
 *      NOTE(review): the second argument is taken to be the number of
 *      512-entry rows; confirm against ntt.h.
 *   3. In each column the three row values are multiplied in Z[y]/(y^3-1):
 *        h0 = f0g0 + f1g2 + f2g1
 *        h1 = f0g1 + f1g0 + f2g2
 *        h2 = f0g2 + f1g1 + f2g0
 *      This uses six multiplications: d0, d1, d2 plus three difference
 *      products.
 *   4. invntt512_7681 is applied to the three product rows. ungood() then
 *      gathers them back into 1536 coefficients.
 *   5. Steps 1-4 are repeated modulo 10753.
 *   6. The two residues are combined per coefficient with Montgomery
 *      multiplications by fixed constants. This is presumably a Garner-style
 *      CRT step whose result crypto_core reduces modulo 4621. The constants
 *      1268, 956, -2539, 1487 are not derived here.
 *
 * The row aliases fpad/gpad/hpad are local macros and are #undef'd at the end,
 * so they do not leak into the rest of the translation unit.
 */
static void mult768(int16 h[1536],const int16 f[768],const int16 g[768])
{
  /* Rows 0-2 hold f, rows 3-5 hold g; the pointwise product overwrites rows 0-2. */
  ALIGNED int16 fgpad[6][512];
#define fpad fgpad
#define gpad (fgpad+3)
#define hpad fpad
  ALIGNED int16 h_7681[1536];
  ALIGNED int16 h_10753[1536];
  int i;

  /* ---- product modulo 7681 ---- */
  good(fpad,f);
  good(gpad,g);
  ntt512_7681(fgpad[0],6);
  for (i = 0;i < 512;i += 16) {
    int16x16 f0 = squeeze_7681_x16(load_x16(&fpad[0][i]));
    int16x16 f1 = squeeze_7681_x16(load_x16(&fpad[1][i]));
    int16x16 f2 = squeeze_7681_x16(load_x16(&fpad[2][i]));
    int16x16 g0 = squeeze_7681_x16(load_x16(&gpad[0][i]));
    int16x16 g1 = squeeze_7681_x16(load_x16(&gpad[1][i]));
    int16x16 g2 = squeeze_7681_x16(load_x16(&gpad[2][i]));
    int16x16 d0 = mulmod_7681_x16(f0,g0);
    int16x16 d1 = mulmod_7681_x16(f1,g1);
    int16x16 d2 = mulmod_7681_x16(f2,g2);
    int16x16 dsum = add_x16(add_x16(d0,d1),d2);
    int16x16 h0 = add_x16(dsum,mulmod_7681_x16(sub_x16(f2,f1),sub_x16(g1,g2)));
    int16x16 h1 = add_x16(dsum,mulmod_7681_x16(sub_x16(f1,f0),sub_x16(g0,g1)));
    int16x16 h2 = add_x16(dsum,mulmod_7681_x16(sub_x16(f0,f2),sub_x16(g2,g0)));
    /* hpad aliases fpad; column i of all six rows has already been read. */
    store_x16(&hpad[0][i],squeeze_7681_x16(h0));
    store_x16(&hpad[1][i],squeeze_7681_x16(h1));
    store_x16(&hpad[2][i],squeeze_7681_x16(h2));
  }
  invntt512_7681(hpad[0],3);
  ungood(h_7681,hpad);

  /* ---- product modulo 10753 (fgpad was overwritten above, so rebuild it) ---- */
  good(fpad,f);
  good(gpad,g);
  ntt512_10753(fgpad[0],6);
  for (i = 0;i < 512;i += 16) {
    int16x16 f0 = squeeze_10753_x16(load_x16(&fpad[0][i]));
    int16x16 f1 = squeeze_10753_x16(load_x16(&fpad[1][i]));
    int16x16 f2 = squeeze_10753_x16(load_x16(&fpad[2][i]));
    int16x16 g0 = squeeze_10753_x16(load_x16(&gpad[0][i]));
    int16x16 g1 = squeeze_10753_x16(load_x16(&gpad[1][i]));
    int16x16 g2 = squeeze_10753_x16(load_x16(&gpad[2][i]));
    int16x16 d0 = mulmod_10753_x16(f0,g0);
    int16x16 d1 = mulmod_10753_x16(f1,g1);
    int16x16 d2 = mulmod_10753_x16(f2,g2);
    int16x16 dsum = add_x16(add_x16(d0,d1),d2);
    int16x16 h0 = add_x16(dsum,mulmod_10753_x16(sub_x16(f2,f1),sub_x16(g1,g2)));
    int16x16 h1 = add_x16(dsum,mulmod_10753_x16(sub_x16(f1,f0),sub_x16(g0,g1)));
    int16x16 h2 = add_x16(dsum,mulmod_10753_x16(sub_x16(f0,f2),sub_x16(g2,g0)));
    store_x16(&hpad[0][i],squeeze_10753_x16(h0));
    store_x16(&hpad[1][i],squeeze_10753_x16(h1));
    store_x16(&hpad[2][i],squeeze_10753_x16(h2));
  }
  invntt512_10753(hpad[0],3);
  ungood(h_10753,hpad);

  /* ---- combine the two residues per coefficient into h ---- */
  for (i = 0;i < 1536;i += 16) {
    int16x16 u1 = load_x16(&h_10753[i]);
    int16x16 u2 = load_x16(&h_7681[i]);
    int16x16 t;
    u1 = mulmod_10753_x16(u1,const_x16(1268));
    u2 = mulmod_7681_x16(u2,const_x16(956));
    t = mulmod_7681_x16(sub_x16(u2,u1),const_x16(-2539));
    t = add_x16(u1,mulmod_4621_x16(t,const_x16(1487)));
    store_x16(&h[i],t);
  }
#undef hpad
#undef gpad
#undef fpad
}
#include "crypto_core.h"
#include "crypto_decode_653xint16.h"
#define crypto_decode_pxint16 crypto_decode_653xint16
#include "crypto_encode_653xint16.h"
#define crypto_encode_pxint16 crypto_encode_653xint16
/* p: number of polynomial coefficients handled by crypto_core (matches the
   653 in crypto_decode/encode_653xint16). q: coefficient modulus used by
   squeeze_4621_x16/freeze_4621_x16. */
#define p 653
#define q 4621
/* Canonical centered reduction modulo q = 4621 in every lane.
   Assuming each lane lies in [-q, q) (NOTE(review): confirm this bound for
   squeeze_4621_x16 output), the result is the representative in
   [-(q-1)/2, (q-1)/2] = [-2310, 2310].
   - Step 1: add q to negative lanes, giving [0, q).
   - Step 2: subtract q from lanes >= (q+1)/2.
   Both steps use masks and a blend, with no data-dependent branches. */
static inline int16x16 freeze_4621_x16(int16x16 x)
{
  int16x16 nonneg = add_x16(x,signmask_x16(x)&const_x16(q));
  int16x16 below_half = signmask_x16(sub_x16(nonneg,const_x16((q+1)/2)));
  int16x16 shifted = sub_x16(nonneg,const_x16(q));
  /* blendv takes the second operand where the mask is set (lane < (q+1)/2). */
  return _mm256_blendv_epi8(shifted,nonneg,below_half);
}
/*
 * crypto_core: computes h = f*g in (Z/q)[x]/(x^p - x - 1), with p = 653 and
 * q = 4621.
 *   outbytes: receives crypto_encode_653xint16 of h.
 *   inbytes:  decoded by crypto_decode_653xint16 into the coefficients of f.
 *   kbytes:   p bytes; byte i determines g[i] in {-1, 0, 1} (see below).
 *   cbytes:   not read by this implementation.
 * f and g are held in zero-padded 768-entry buffers so that mult768 and the
 * loops below can work on whole 16-lane vectors.
 */
void crypto_core(unsigned char *outbytes,const unsigned char *inbytes,const unsigned char *kbytes,const unsigned char *cbytes)
{
ALIGNED int16 f[768];
ALIGNED int16 g[768];
ALIGNED int16 fg[1536];
/* f is no longer needed once fg has been computed, so the result reuses it. */
#define h f
int i;
int16x16 x;
x = const_x16(0);
/* Zero [p&~15, 768) = [640, 768) of f and g. The decoder / key loop below
   then fill [0, p). NOTE(review): assumes the decoder writes only p entries. */
for (i = p&~15;i < 768;i += 16) store_x16(&f[i],x);
for (i = p&~15;i < 768;i += 16) store_x16(&g[i],x);
crypto_decode_pxint16(f,inbytes);
/* Reduce every entry of f to its centered representative mod q. */
for (i = 0;i < 768;i += 16) {
x = load_x16(&f[i]);
x = freeze_4621_x16(squeeze_4621_x16(x));
store_x16(&f[i],x);
}
/* Map byte i of kbytes to g[i] without branching on the byte:
   - bit 0 clear -> 0;
   - bit 0 set, bit 1 clear -> 1;
   - bits 0 and 1 set -> -1.
   (gi0 is presumably gi&1 as 0/1, per the helper's name.) */
for (i = 0;i < p;++i) {
int8 gi = kbytes[i];
int8 gi0 = crypto_int8_bottombit_01(gi);
g[i] = gi0-(gi&(gi0<<1));
}
mult768(fg,f,g);
/* Reduce modulo x^p - x - 1. Since x^(p+k) = x^(k+1) + x^k,
   h[i] = fg[i] + fg[p+i] + fg[p+i-1] for 1 <= i < p.
   The vector loop applies the same formula at i = 0, where fg[p-1] must not
   be added; this subtraction cancels it in advance. */
fg[0] -= fg[p-1];
/* Lanes i >= p are computed too and stored into h[p..767]. The highest read
   is fg[752+p+15] = fg[1420], which is within fg. */
for (i = 0;i < 768;i += 16) {
int16x16 fgi = load_x16(&fg[i]);
int16x16 fgip = load_x16(&fg[i + p]);
int16x16 fgip1 = load_x16(&fg[i + p - 1]);
x = add_x16(fgi,add_x16(fgip,fgip1));
x = freeze_4621_x16(squeeze_4621_x16(x));
store_x16(&h[i],x);
}
crypto_encode_pxint16(outbytes,h);
}