-rwxr-xr-x 7956 libntruprime-20241021/src/encode/PxQ/avx/encodegen.py raw
#!/usr/bin/env python3
import sys
# Code-generation parameters, taken positionally from the command line:
#   argv[1] Mlen    number of input entries                  (default 761)
#   argv[2] m0      modulus of every entry but the last       (default 4591)
#   argv[3] m1      modulus of the last entry                 (default m0)
#   argv[4] offset  added to each input entry before masking  (default m0//2)
#   argv[5] div3    'True' to divide each input entry by 3    (default off)
#   argv[6] round   'True' to map each input entry to (roughly) the nearest
#                   multiple of 3                             (default off)
_args = sys.argv[1:]
Mlen = int(_args[0]) if len(_args) > 0 else 761
m0 = int(_args[1]) if len(_args) > 1 else 4591
m1 = int(_args[2]) if len(_args) > 2 else m0
# M is Mlen-1 copies of m0, plus 1 copy of m1
offset = int(_args[3]) if len(_args) > 3 else m0//2
div3 = len(_args) > 4 and _args[4] == 'True'
# NOTE: shadows the builtin round(); the other code reads it by this name.
round = len(_args) > 5 and _args[5] == 'True'
# Name of the C array consumed by the next pass; the first pass reads the
# caller's int16 input.
reading = 'R0'
# Intermediate values are kept below this bound (2**14) between passes.
top = 16384
# Emit the C prologue: includes, type aliases, the crypto_encode()
# signature and the local variables used by the generated body.
print('\n'.join([
    '/* auto-generated; do not edit */',
    '',
    '#include <immintrin.h>',
    '#include "crypto_encode.h"',
    '#include "crypto_int16.h"',
    '#include "crypto_uint16.h"',
    '#include "crypto_uint32.h"',
    '#define int16 crypto_int16',
    '#define uint16 crypto_uint16',
    '#define uint32 crypto_uint32',
    '',
    'void crypto_encode(unsigned char *out,const void *v)',
    '{',
    ' const int16 *%s = v;' % reading,
]))

# R receives the pairwise-combined values; (Mlen+1)//2 entries hold the
# first (largest) pass.  Nothing is paired when Mlen <= 1.
if Mlen > 1:
    print(' /* XXX: caller could overlap R with input */')
    Rlen = (Mlen+1)//2
    print(' uint16 R[%d];' % Rlen)

# Loop counter, widened to long long for very large Mlen.
print(' long long i;' if Mlen > 2**29 else ' long i;')
print('\n'.join([
    ' const uint16 *reading;',
    ' uint16 *writing;',
    ' uint16 r0,r1;',
    ' uint32 r2;',
]))
def access(reading,pos):
    """Return the C expression that reads element `pos` of array `reading`.

    Only the caller's input array 'R0' is transformed, using the module-level
    settings in this order:
      * if `round`: x -> 3*((10923*x+16384)>>15), roughly the nearest
        multiple of 3;
      * always: x -> ((x+offset)&16383), i.e. add `offset` and keep 14 bits;
      * if `div3`: x -> (x*10923)>>15, a multiply-shift division by 3.
    Any other array name is returned as a plain subscript.
    """
    expr = '%s[%s]' % (reading,pos)
    if reading != 'R0':
        return expr
    if round:
        expr = '3*((10923*%s+16384)>>15)' % expr
    expr = '((%s+%d)&16383)' % (expr,offset)
    if div3:
        expr = '(%s*10923)>>15' % expr
    return expr
def _printloop_avx_1byte(looplen,reading,todo,m0):
    """Emit an AVX2 loop that combines 8 pairs per iteration, 1 output byte each."""
    tail = looplen%8
    print(' reading = (uint16 *) %s;' % reading)
    print(' writing = %s;' % todo)
    print(' i = %d;' % ((looplen+7)//8))
    print(' while (i > 0) {')
    print(' __m256i x,y;')
    print(' --i;')
    if tail:
        # The last iteration steps back so that its 8 lanes end exactly at
        # the final pair; the overlapping pairs are recomputed with the same
        # results and rewritten in place.
        print(' if (!i) {')
        print(' reading -= %d;' % (2*(8-tail)))
        print(' writing -= %d;' % (8-tail))
        print(' out -= %d;' % (8-tail))
        print(' }')
    print(' x = _mm256_loadu_si256((__m256i *) reading);')
    if reading == 'R0':
        # Vector form of the input transformation done by access().
        if round:
            # (x*10923 + 2**14) >> 15, then x+x+x: the scalar
            # 3*((10923*x+16384)>>15).
            print(' x = _mm256_mulhrs_epi16(x,_mm256_set1_epi16(10923));')
            print(' x = _mm256_add_epi16(x,_mm256_add_epi16(x,x));')
        if offset != 0:
            print(' x = _mm256_add_epi16(x,_mm256_set1_epi16(%d));' % offset)
        print(' x &= _mm256_set1_epi16(16383);')
        if div3:
            # (x*21846)>>16 == (x*10923)>>15, as in access().
            print(' x = _mm256_mulhi_epi16(x,_mm256_set1_epi16(21846));')
    # Each 32-bit lane holds one pair: r2 = low16 + high16*m0.
    print(' y = x & _mm256_set1_epi32(65535);')
    print(' x = _mm256_srli_epi32(x,16);')
    print(' x = _mm256_mullo_epi32(x,_mm256_set1_epi32(%d));' % m0)
    print(' x = _mm256_add_epi32(y,x);')
    # Per 128-bit half: bytes 1-2 of each r2 (the value kept in todo) go to
    # the low 8 bytes, byte 0 (the output byte) to bytes 8-11.
    print(' x = _mm256_shuffle_epi8(x,_mm256_set_epi8(')
    print(' 12,8,4,0,12,8,4,0,14,13,10,9,6,5,2,1,')
    print(' 12,8,4,0,12,8,4,0,14,13,10,9,6,5,2,1')
    print(' ));')
    # Gather the todo values in the low 128 bits, output bytes above.
    print(' x = _mm256_permute4x64_epi64(x,0xd8);')
    print(' _mm_storeu_si128((__m128i *) writing,_mm256_extractf128_si256(x,0));')
    print(' *((uint32 *) (out+0)) = _mm256_extract_epi32(x,4);')
    print(' *((uint32 *) (out+4)) = _mm256_extract_epi32(x,6);')
    print(' reading += 16;')
    print(' writing += 8;')
    print(' out += 8;')
    print(' }')


def _printloop_avx_2byte(looplen,reading,todo,m0):
    """Emit an AVX2 loop that combines 16 pairs per iteration, 2 output bytes each."""
    tail = looplen%16
    print(' reading = (uint16 *) %s;' % reading)
    print(' writing = %s;' % todo)
    print(' i = %d;' % ((looplen+15)//16))
    print(' while (i > 0) {')
    print(' __m256i x,x2,y,y2;')
    print(' --i;')
    if tail:
        # Same step-back trick as the 1-byte loop, 16 pairs wide.
        print(' if (!i) {')
        print(' reading -= %d;' % (2*(16-tail)))
        print(' writing -= %d;' % (16-tail))
        print(' out -= %d;' % (2*(16-tail)))
        print(' }')
    print(' x = _mm256_loadu_si256((__m256i *) (reading+0));')
    print(' x2 = _mm256_loadu_si256((__m256i *) (reading+16));')
    if reading == 'R0':
        # Vector form of the input transformation done by access().
        if round:
            print(' x = _mm256_mulhrs_epi16(x,_mm256_set1_epi16(10923));')
            print(' x2 = _mm256_mulhrs_epi16(x2,_mm256_set1_epi16(10923));')
            print(' x = _mm256_add_epi16(x,_mm256_add_epi16(x,x));')
            print(' x2 = _mm256_add_epi16(x2,_mm256_add_epi16(x2,x2));')
        if offset != 0:
            print(' x = _mm256_add_epi16(x,_mm256_set1_epi16(%d));' % offset)
            print(' x2 = _mm256_add_epi16(x2,_mm256_set1_epi16(%d));' % offset)
        print(' x &= _mm256_set1_epi16(16383);')
        print(' x2 &= _mm256_set1_epi16(16383);')
        if div3:
            print(' x = _mm256_mulhi_epi16(x,_mm256_set1_epi16(21846));')
            print(' x2 = _mm256_mulhi_epi16(x2,_mm256_set1_epi16(21846));')
    # Each 32-bit lane holds one pair: r2 = low16 + high16*m0.
    print(' y = x & _mm256_set1_epi32(65535);')
    print(' y2 = x2 & _mm256_set1_epi32(65535);')
    print(' x = _mm256_srli_epi32(x,16);')
    print(' x2 = _mm256_srli_epi32(x2,16);')
    print(' x = _mm256_mullo_epi32(x,_mm256_set1_epi32(%d));' % m0)
    print(' x2 = _mm256_mullo_epi32(x2,_mm256_set1_epi32(%d));' % m0)
    print(' x = _mm256_add_epi32(y,x);')
    print(' x2 = _mm256_add_epi32(y2,x2);')
    # Per 128-bit half: low 16 bits of each r2 (two output bytes) to the low
    # 8 bytes, high 16 bits (the value kept in todo) to the high 8 bytes.
    print(' x = _mm256_shuffle_epi8(x,_mm256_set_epi8(')
    print(' 15,14,11,10,7,6,3,2,13,12,9,8,5,4,1,0,')
    print(' 15,14,11,10,7,6,3,2,13,12,9,8,5,4,1,0')
    print(' ));')
    print(' x2 = _mm256_shuffle_epi8(x2,_mm256_set_epi8(')
    print(' 15,14,11,10,7,6,3,2,13,12,9,8,5,4,1,0,')
    print(' 15,14,11,10,7,6,3,2,13,12,9,8,5,4,1,0')
    print(' ));')
    print(' x = _mm256_permute4x64_epi64(x,0xd8);')
    print(' x2 = _mm256_permute4x64_epi64(x2,0xd8);')
    # High halves of x,x2 -> 16 todo values; low halves -> 32 output bytes.
    print(' _mm256_storeu_si256((__m256i *) writing,_mm256_permute2f128_si256(x,x2,0x31));')
    print(' _mm256_storeu_si256((__m256i *) out,_mm256_permute2f128_si256(x,x2,0x20));')
    print(' reading += 32;')
    print(' writing += 16;')
    print(' out += 32;')
    print(' }')


def printloop(looplen,reading,todo,m0,bytes):
    """Emit C code combining `looplen` adjacent pairs of `reading` into `todo`.

    For each pair index i the generated code computes
    r2 = reading[2*i] + reading[2*i+1]*m0 as a uint32. It writes the low
    `bytes` bytes of r2 to `out` (least significant first) and stores what
    is left of r2 into todo[i]. Elements of the input array 'R0' are
    transformed on the fly: through access() in the scalar paths, and with
    the equivalent vector instructions in the AVX2 paths.

    Code shape:
      * looplen <= 0: nothing is emitted.
      * looplen == 1: straight-line code for the single pair.
      * bytes == 1 and looplen >= 12: AVX2 loop, 8 pairs per iteration.
      * bytes == 2 and looplen >= 24: AVX2 loop, 16 pairs per iteration.
      * otherwise: a scalar C for-loop.

    The AVX2 loops finish with a backed-up iteration that overlaps pairs
    already done. The minimum lengths 12 and 24 guarantee that, when
    `reading` and `todo` name the same array, that iteration only loads
    entries which earlier iterations have not yet overwritten.

    Args:
      looplen: number of pairs to combine.
      reading: name of the C array to read ('R0' is the caller's input).
      todo: name of the C array that receives the combined values.
      m0: modulus of the first element of each pair (multiplier of the
        second).
      bytes: output bytes emitted per pair (name shadows the builtin; kept
        for interface compatibility).
    """
    if looplen <= 0:
        return
    if looplen == 1:
        print(' r0 = %s;' % access(reading,0))
        print(' r1 = %s;' % access(reading,1))
        print(' r2 = r0+r1*(uint32)%d;' % m0)
        for _ in range(bytes):
            print(' *out++ = r2; r2 >>= 8;')
        print(' %s[0] = r2;' % todo)
        return
    if looplen >= 12 and bytes == 1:
        _printloop_avx_1byte(looplen,reading,todo,m0)
        return
    if looplen >= 24 and bytes == 2:
        _printloop_avx_2byte(looplen,reading,todo,m0)
        return
    print(' for (i = 0;i < %d;++i) {' % looplen)
    print(' r0 = %s;' % access(reading,'2*i'))
    print(' r1 = %s;' % access(reading,'2*i+1'))
    print(' r2 = r0+r1*(uint32)%d;' % m0)
    for _ in range(bytes):
        print(' *out++ = r2; r2 >>= 8;')
    print(' %s[i] = r2;' % todo)
    print(' }')
def _peel_bytes(bound):
    """Return (nbytes, rest) for a value known to be below `bound`.

    nbytes is how many low bytes must be emitted before the remaining high
    part is below `top`; rest is the resulting bound on that high part.
    """
    nbytes = 0
    while bound >= top:
        nbytes += 1
        bound = (bound+255)>>8
    return nbytes, bound


# Repeatedly combine adjacent pairs until one element is left.  Each pass
# reads `reading` ('R0' first, then 'R') and writes its results into R.
todo = 'R'
while Mlen > 1:
    print(' ')
    bytes0, n0 = _peel_bytes(m0*m0)
    half = Mlen//2
    if Mlen%2:
        # Odd length: pair up the first Mlen-1 entries and carry the last
        # (modulus m1) entry over unchanged.
        printloop(half,reading,todo,m0,bytes0)
        print(' %s[%d] = %s;' % (todo,half,access(reading,2*half)))
        n1 = m1
    else:
        bytes1, n1 = _peel_bytes(m0*m1)
        if bytes1 == bytes0:
            # The final (m0,m1) pair emits as many bytes as the others.
            printloop(half,reading,todo,m0,bytes0)
        else:
            # The final pair needs a different byte count: loop over the
            # others, then emit that pair on its own.
            last = half-1
            printloop(last,reading,todo,m0,bytes0)
            print(' r0 = %s;' % access(reading,2*last))
            print(' r1 = %s;' % access(reading,2*last+1))
            print(' r2 = r0+r1*(uint32)%d;' % m0)
            for _ in range(bytes1):
                print(' *out++ = r2; r2 >>= 8;')
            print(' %s[%d] = r2;' % (todo,last))
    m0,m1,Mlen = n0,n1,(Mlen+1)//2
    reading = todo

# Write out the single remaining value, one byte at a time, until its
# bound m1 is exhausted; then close the C function.
print(' ')
print(' r0 = %s;' % access(reading,0))
while m1 > 1:
    print(' *out++ = r0; r0 >>= 8;')
    m1 = (m1+255)>>8
print('}')