-rw-r--r-- 4898 libntruprime-20241008/src/core/invsntrupP/avx/recip.c raw
// 20240806 djb: some automated conversion to cryptoint #include <immintrin.h> #include "crypto_int8.h" #include "crypto_int16.h" #include "crypto_int32.h" #include "crypto_int64.h" #define int8 crypto_int8 #define int16 crypto_int16 #define int32 crypto_int32 #include "crypto_core.h" #include "params.h" /* ----- arithmetic mod q */ typedef int8 small; typedef int16 Fq; /* always represented as -(q-1)/2...(q-1)/2 */ /* works for -7000000 < x < 7000000 if q in 4591, 4621, 5167, 6343, 7177, 7879 */ static Fq Fq_freeze(int32 x) { x -= q*((q18*x)>>18); x -= q*((q27*x+67108864)>>27); return x; } static Fq Fq_bigfreeze(int32 x) { x -= q*((q14*x)>>14); x -= q*((q18*x)>>18); x -= q*((q27*x+67108864)>>27); x -= q*((q27*x+67108864)>>27); return x; } /* nonnegative e */ static Fq Fq_pow(Fq a,int e) { if (e == 0) return 1; if (e == 1) return a; if (crypto_int64_bottombit_01(e)) return Fq_bigfreeze(a*(int32)Fq_pow(a,e-1)); a = Fq_bigfreeze(a*(int32)a); return Fq_pow(a,e>>1); } static Fq Fq_recip(Fq a) { return Fq_pow(a,q-2); } /* ----- more */ #define qvec _mm256_set1_epi16(q) #define qinvvec _mm256_set1_epi16(qinv) static inline __m256i montproduct(__m256i x,__m256i y,__m256i yqinv) { __m256i hi,d,e; d = _mm256_mullo_epi16(x,yqinv); hi = _mm256_mulhi_epi16(x,y); e = _mm256_mulhi_epi16(d,qvec); return _mm256_sub_epi16(hi,e); } static inline void vectormodq_swapeliminate(Fq *f,Fq *g,int len,const Fq f0,const Fq g0,int mask) { __m256i f0vec = _mm256_set1_epi16(f0); __m256i g0vec = _mm256_set1_epi16(g0); __m256i f0vecqinv = _mm256_mullo_epi16(f0vec,qinvvec); __m256i g0vecqinv = _mm256_mullo_epi16(g0vec,qinvvec); __m256i maskvec = _mm256_set1_epi32(mask); while (len > 0) { __m256i fi = _mm256_loadu_si256((__m256i *) f); __m256i gi = _mm256_loadu_si256((__m256i *) g); __m256i finew = _mm256_blendv_epi8(fi,gi,maskvec); __m256i ginew = _mm256_blendv_epi8(gi,fi,maskvec); ginew = _mm256_sub_epi16(montproduct(ginew,f0vec,f0vecqinv),montproduct(finew,g0vec,g0vecqinv)); _mm256_storeu_si256((__m256i *) f,finew); _mm256_storeu_si256((__m256i *) (g-1),ginew); f += 16; g += 16; len -= 16; } } static inline void vectormodq_xswapeliminate(Fq *f,Fq *g,int len,const Fq f0,const Fq g0,int mask) { __m256i f0vec = _mm256_set1_epi16(f0); __m256i g0vec = _mm256_set1_epi16(g0); __m256i f0vecqinv = _mm256_mullo_epi16(f0vec,qinvvec); __m256i g0vecqinv = _mm256_mullo_epi16(g0vec,qinvvec); __m256i maskvec = _mm256_set1_epi32(mask); f += len + (-len & 15); g += len + (-len & 15); while (len > 0) { f -= 16; g -= 16; len -= 16; __m256i fi = _mm256_loadu_si256((__m256i *) f); __m256i gi = _mm256_loadu_si256((__m256i *) g); __m256i finew = _mm256_blendv_epi8(fi,gi,maskvec); __m256i ginew = _mm256_blendv_epi8(gi,fi,maskvec); ginew = _mm256_sub_epi16(montproduct(ginew,f0vec,f0vecqinv),montproduct(finew,g0vec,g0vecqinv)); _mm256_storeu_si256((__m256i *) (f+1),finew); _mm256_storeu_si256((__m256i *) g,ginew); } } void crypto_core(unsigned char *outbytes,const unsigned char *inbytes,const unsigned char *kbytes,const unsigned char *cbytes) { small *in = (void *) inbytes; int loop; Fq out[p],f[ppad],g[ppad],v[ppad],r[ppad]; Fq f0,g0; Fq scale; int i; int delta = 1; int minusdelta; int fgflip; int swap; for (i = 0;i < ppad;++i) f[i] = 0; f[0] = 1; f[p-1] = -1; f[p] = -1; /* generalization: initialize f to reversal of any deg-p polynomial m */ for (i = 0;i < p;++i) g[i] = in[p-1-i]; for (i = p;i < ppad;++i) g[i] = 0; for (i = 0;i < ppad;++i) r[i] = 0; r[0] = Fq_recip(3); for (i = 0;i < ppad;++i) v[i] = 0; for (loop = 0;loop < p;++loop) { g0 = Fq_freeze(g[0]); f0 = f[0]; if (q > 5167) f0 = Fq_freeze(f0); minusdelta = -delta; swap = crypto_int16_negative_mask(minusdelta) & crypto_int16_nonzero_mask(g0); delta ^= swap & (delta ^ minusdelta); delta += 1; fgflip = swap & (f0 ^ g0); f0 ^= fgflip; g0 ^= fgflip; f[0] = f0; vectormodq_swapeliminate(f+1,g+1,p,f0,g0,swap); vectormodq_xswapeliminate(v,r,loop+1,f0,g0,swap); } for (loop = p-1;loop > 0;--loop) { g0 = Fq_freeze(g[0]); f0 = f[0]; if (q > 5167) f0 = Fq_freeze(f0); minusdelta = -delta; swap = crypto_int16_negative_mask(minusdelta) & crypto_int16_nonzero_mask(g0); delta ^= swap & (delta ^ minusdelta); delta += 1; fgflip = swap & (f0 ^ g0); f0 ^= fgflip; g0 ^= fgflip; f[0] = f0; vectormodq_swapeliminate(f+1,g+1,loop,f0,g0,swap); vectormodq_xswapeliminate(v,r,p,f0,g0,swap); } scale = Fq_recip(Fq_freeze(f[0])); for (i = 0;i < p;++i) out[i] = Fq_bigfreeze(scale*(int32)Fq_freeze(v[p-i])); crypto_encode_pxint16(outbytes,out); outbytes[2*p] = crypto_int16_nonzero_mask(delta); }