-rwxr-xr-x 12849 libntruprime-20241021/src/decode/PxQ/avx/decodegen.py raw
#!/usr/bin/env python3
import sys
from math import floor
# Optional positional command-line parameters, each falling back to a
# default when absent:
#   argv[1] Mlen    number of moduli in the list M (default 761)
#   argv[2] m0      modulus used for all but the last entry (default 4591)
#   argv[3] m1      modulus of the last entry (default: same as m0)
#   argv[4] offset  constant subtracted from every final output (default m0//2)
#   argv[5] div3    the literal string 'True' makes final outputs get
#                   multiplied by 3 before the offset is subtracted
Mlen = int(sys.argv[1]) if len(sys.argv) > 1 else 761
m0 = int(sys.argv[2]) if len(sys.argv) > 2 else 4591
m1 = int(sys.argv[3]) if len(sys.argv) > 3 else m0
# M is Mlen-1 copies of m0, plus 1 copy of m1
offset = int(sys.argv[4]) if len(sys.argv) > 4 else m0//2
div3 = len(sys.argv) > 5 and sys.argv[5] == 'True'
# Name of the C array currently being written; 'R0' is the caller's output.
todo = 'R0'
# Moduli are shrunk byte by byte until they drop below this bound.
top = 16384
# ---------------------------------------------------------------------------
# Emit the fixed part of the generated C file: banner, includes, the scalar
# helpers (mullo/mulhi) and the AVX2 helpers used by the generated stanzas,
# then the opening of crypto_decode() with its local declarations.
# ---------------------------------------------------------------------------
print('/* auto-generated; do not edit */')
print('/* 20240812 djb: more cryptoint usage */')
print('')
print('#include <immintrin.h>')
print('#include "crypto_decode.h"')
print('#include "crypto_int16.h"')
print('#include "crypto_int32.h"')
print('#define int16 crypto_int16')
print('#define int32 crypto_int32')
print("""
static inline int16 mullo(int16 x,int16 y)
{
return x*y;
}
static inline int16 mulhi(int16 x,int16 y)
{
return (x*(int32)y)>>16;
}
static inline __m256i add(__m256i x,__m256i y)
{
return _mm256_add_epi16(x,y);
}
static inline __m256i sub(__m256i x,__m256i y)
{
return _mm256_sub_epi16(x,y);
}
static inline __m256i shiftleftconst(__m256i x,int16 y)
{
return _mm256_slli_epi16(x,y);
}
static inline __m256i signedshiftrightconst(__m256i x,int16 y)
{
return _mm256_srai_epi16(x,y);
}
static inline __m256i subconst(__m256i x,int16 y)
{
return sub(x,_mm256_set1_epi16(y));
}
static inline __m256i mulloconst(__m256i x,int16 y)
{
return _mm256_mullo_epi16(x,_mm256_set1_epi16(y));
}
static inline __m256i mulhiconst(__m256i x,int16 y)
{
return _mm256_mulhi_epi16(x,_mm256_set1_epi16(y));
}
static inline __m256i ifgesubconst(__m256i x,int16 y)
{
__m256i y16 = _mm256_set1_epi16(y);
__m256i top16 = _mm256_set1_epi16(y-1);
return sub(x,_mm256_cmpgt_epi16(x,top16) & y16);
}
static inline __m256i ifnegaddconst(__m256i x,int16 y)
{
return add(x,signedshiftrightconst(x,15) & _mm256_set1_epi16(y));
}
""")
print('void crypto_decode(void *v,const unsigned char *s)')
print('{')
# The output pointer is exposed under the name of the first array to fill.
print(' int16 *%s = v;' % todo)

# Declare one temporary int16 array per layer: R1 holds ceil(Mlen/2)
# entries, R2 half of that (rounded up), and so on down to length 1.
# tmparrays stays None when Mlen <= 1, in which case nothing is declared.
tmparrays = None
layer = 1
x = Mlen
while x > 1:
    x = (x+1)//2
    if tmparrays is None:
        tmparrays = 'int16 '
    else:
        tmparrays += ','
    tmparrays += 'R%d[%d]' % (layer,x)
    layer += 1
if tmparrays:
    print(' %s;' % tmparrays)

# Scratch variables referenced by the scalar and vector stanzas.
print(' long long i;')
print(' int16 a0,a1,a2;')
print(' __m256i A0,A1,A2,S0,S1,B0,B1,C0,C1;')
def poke(todo,pos,contents):
    """Return the C statement that stores `contents` into `todo[pos]`.

    Stores into the final output array 'R0' are adjusted using the
    module-level settings: when `div3` is set the value is multiplied by 3
    and `offset` is subtracted; otherwise `offset` is subtracted if it is
    nonzero.  Stores into any other array are emitted unchanged.
    """
    if todo != 'R0':
        template,extra = '%s[%s] = %s;',()
    elif div3:
        template,extra = '%s[%s] = 3*%s%+d;',(-offset,)
    elif offset != 0:
        template,extra = '%s[%s] = %s%+d;',(-offset,)
    else:
        template,extra = '%s[%s] = %s;',()
    return template % ((todo,pos,contents) + extra)
def mulmoddata(c,q):
    """Precompute the constants for a 16-bit "multiply by c, reduce mod q".

    Returns the tuple (y, t, u, uinv, z) where
      y     is c*2^16 mod q, shifted into the range with 2*y < q,
      t     is the number of trailing zero bits of q,
      u     is q >> t (the odd part of q),
      uinv  is the inverse of u modulo 2^16, as a signed 16-bit value,
      z     is a signed 16-bit value with z*u == (y - c*2^16) >> t mod 2^16.

    The generated C code uses y, z and q together in expressions of the form
    mulhi(a,y)-mulhi(mullo(a,z),q), and uinv to divide exactly by u.
    Both internal assertions re-check the inverse relations.
    """
    scaled = c << 16
    y = scaled % q
    if 2*y >= q:
        y -= q

    # Split q into 2^t * u with u odd: q & -q isolates the lowest set bit.
    t = (q & -q).bit_length() - 1
    u = q >> t

    # Inverse of u modulo 2^16, recentred into -32768...32767.
    uinv = (pow(u,16383,65536) + 32768) % 65536 - 32768
    assert (u*uinv)%65536 == 1

    quotient = (y - scaled) >> t
    z = (quotient*uinv + 32768) % 65536 - 32768
    assert (z*u - quotient)%(2**16) == 0
    return y,t,u,uinv,z
def inner(indent,reading,inpos,m0,m1,bytes,outpos0,outpos1):
    """Return scalar C code that splits one entry of `reading` in two.

    The generated code loads `reading[inpos]`, folds in `bytes` bytes taken
    from the C pointer `s` (indexed through the C variable `i`), leaves the
    remainder modulo `m0` in a0 and the matching quotient in a1, applies one
    conditional subtraction of `m1` to a1, and stores a0/a1 through poke()
    into positions `outpos0`/`outpos1` of the array named by the
    module-level `todo`.

    Parameters:
      indent   string prepended to every emitted line
      reading  name of the C array being read
      inpos    index expression (string or int) into `reading`
      m0, m1   moduli for a0 and a1 respectively
      bytes    number of input bytes consumed per entry: 0, 1 or 2
      outpos0, outpos1  index expressions for the two stores

    Raises AssertionError if `bytes` is not 0, 1 or 2.

    The lower/upper values tracked here are bounds on a0 at each step; they
    are written into the C comments and decide how many correction steps
    are emitted.
    """
    stanza = ''
    y,t,u,uinv,z = mulmoddata(256,m0)
    stanza += indent + 'a2 = a0 = %s[%s];\n' % (reading,inpos)
    a0lower,a0upper = 0,1<<14
    for loop in range(bytes):
        # Multiply-and-reduce step: scale the interval by y (swapping the
        # ends when y is not positive), widen by m0/2 on each side.
        if y > 0:
            a0lower,a0upper = a0lower*y,a0upper*y
        else:
            a0lower,a0upper = a0upper*y,a0lower*y
        a0lower,a0upper = a0lower-(m0<<15),a0upper+(m0<<15)
        a0lower,a0upper = a0lower>>16,a0upper>>16
        stanza += indent + 'a0 = mulhi(a0,%d)-mulhi(mullo(a0,%d),%d); /* %d...%d */\n' % (y,z,m0,a0lower,a0upper)
        # Bytes are consumed from the highest index downwards.
        a0upper += 255
        stanza += indent + 'a0 += s[%d*i+%d]; /* %d...%d */\n' % (bytes,bytes-1-loop,a0lower,a0upper)
    if a0upper >= 2*m0:
        # Still far above m0: emit one more reduction, using the constants
        # for multiplier 1 instead of 256.
        y1,t1,u1,uinv1,z1 = mulmoddata(1,m0)
        if y1 > 0:
            a0lower,a0upper = a0lower*y1,a0upper*y1
        else:
            a0lower,a0upper = a0upper*y1,a0lower*y1
        a0lower,a0upper = a0lower-(m0<<15),a0upper+(m0<<15)
        a0lower,a0upper = a0lower>>16,a0upper>>16
        stanza += indent + 'a0 = mulhi(a0,%d)-mulhi(mullo(a0,%d),%d); /* %d...%d */\n' % (y1,z1,m0,a0lower,a0upper)
    while a0upper >= m0:
        a0lower,a0upper = a0lower-m0,a0upper-m0
        stanza += indent + 'a0 -= %d; /* %d..>%d */\n' % (m0,a0lower,a0upper)
    while a0lower < 0:
        a0lower,a0upper = min(0,a0lower+m0),max(m0-1,a0upper)
        stanza += indent + 'a0 += %d&crypto_int16_negative_mask(a0); /* %d...%d */\n' % (m0,a0lower,a0upper)

    # a1 = (original value - a0) / 2^t; the remaining division by the odd
    # part u of m0 is done below by multiplying with uinv.
    if bytes == 0:
        stanza += indent + 'a1 = (a2-a0)>>%d;\n' % t
    elif bytes == 1:
        if t == 0:
            stanza += indent + 'a1 = (a2<<8)+s[i]-a0;\n'
        elif t == 8:
            stanza += indent + 'a1 = a2+((s[i]-a0)>>8);\n'
        elif t < 8:
            stanza += indent + 'a1 = (a2<<%d)+((s[i]-a0)>>%d);\n' % (8-t,t)
        else:
            # Parenthesised: "fmt % t-8" would format first and then try to
            # subtract 8 from the resulting string, raising TypeError.
            stanza += indent + 'a1 = (a2+((s[i]-a0)>>8))>>%d;\n' % (t-8)
    else:
        assert bytes == 2
        if t == 0:
            stanza += indent + 'a1 = (s[2*i+1]<<8)+s[2*i]-a0;\n'
        elif t == 8:
            stanza += indent + 'a1 = (a2<<8)+s[2*i+1]+((s[2*i]-a0)>>8);\n'
        elif t < 8:
            stanza += indent + 'a1 = (a2<<%d)+(s[2*i+1]<<%d)+((s[2*i]-a0)>>%d);\n' % (16-t,8-t,t)
        else:
            stanza += indent + 'a1 = ((((int32)a2)<<16)+(s[2*i+1]<<8)+s[2*i]-a0)>>%d;\n' % t
    stanza += indent + 'a1 = mullo(a1,%d);\n' % uinv
    stanza += '\n'
    stanza += indent + '/* invalid inputs might need reduction mod %d */\n' % m1
    stanza += indent + 'a1 -= %d;\n' % m1
    stanza += indent + 'a1 += %d&crypto_int16_negative_mask(a1);\n' % m1
    stanza += '\n'
    stanza += indent + '%s\n' % poke(todo,outpos0,'a0')
    stanza += indent + '%s\n' % poke(todo,outpos1,'a1')
    return stanza
# XXX: caller must ensure that outpos1 is outpos0+1
def inner16(indent,reading,inpos,m0,m1,bytes,outpos0,outpos1):
    """Return AVX2 C code that splits 16 consecutive entries of `reading`.

    Vector counterpart of inner(): the generated code loads 16 int16 values
    starting at `reading[inpos]`, folds in `bytes` bytes per value from `s`,
    produces remainders modulo `m0` in A0 and quotients in A1 (reduced once
    modulo `m1`), interleaves them and stores 32 int16 values starting at
    `outpos0` in the array named by the module-level `todo`.  `outpos1` is
    not used; the interleaved store assumes it equals outpos0+1.

    When the target is 'R0', the module-level `div3` and `offset` settings
    are applied to both vectors before the store.

    Raises Exception when `bytes` is 1 or 2 and m0 has more than 8 trailing
    zero bits, and AssertionError when `bytes` is not 0, 1 or 2.
    """
    reduce_fmt = 'A0 = sub(mulhiconst(A0,%d),mulhiconst(mulloconst(A0,%d),%d)); /* %d...%d */\n'
    parts = []

    def emit(text):
        parts.append(indent + text)

    def after_reduce(lo,hi,mult):
        # Interval of mulhi(a,mult) minus the correction term, for a in lo...hi.
        if mult > 0:
            lo,hi = lo*mult,hi*mult
        else:
            lo,hi = hi*mult,lo*mult
        return (lo-(m0<<15))>>16,(hi+(m0<<15))>>16

    y,t,u,uinv,z = mulmoddata(256,m0)
    emit('A2 = A0 = _mm256_loadu_si256((__m256i *) &%s[%s]);\n' % (reading,inpos))
    lower,upper = 0,1<<14

    # Load the input bytes: S0 holds the low byte of each pair, S1 the high.
    if bytes == 1:
        emit('S0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i *) (s+i)));\n')
    if bytes == 2:
        emit('S0 = _mm256_loadu_si256((__m256i *) (s+2*i));\n')
        emit('S1 = _mm256_srli_epi16(S0,8);\n')
        emit('S0 &= _mm256_set1_epi16(255);\n')

    # Fold the bytes in, highest-numbered S register first.
    for loop in range(bytes-1,-1,-1):
        lower,upper = after_reduce(lower,upper,y)
        emit(reduce_fmt % (y,z,m0,lower,upper))
        upper += 255
        emit('A0 = add(A0,S%d); /* %d...%d */\n' % (loop,lower,upper))

    if upper >= 2*m0:
        y1,t1,u1,uinv1,z1 = mulmoddata(1,m0)
        lower,upper = after_reduce(lower,upper,y1)
        emit(reduce_fmt % (y1,z1,m0,lower,upper))
    while upper >= m0:
        lower,upper = lower-m0,upper-m0
        emit('A0 = subconst(A0,%d); /* %d...%d */\n' % (m0,lower,upper))
    while lower < 0:
        lower,upper = min(0,lower+m0),max(m0-1,upper)
        emit('A0 = ifnegaddconst(A0,%d); /* %d...%d */\n' % (m0,lower,upper))

    # A1 = (original value - A0) >> t, computed in 16-bit lanes.
    if bytes == 0:
        emit('A1 = signedshiftrightconst(sub(A2,A0),%d);\n' % t)
    elif bytes == 1:
        if t == 0:
            emit('A1 = add(shiftleftconst(A2,8),sub(S0,A0));\n')
        elif t == 8:
            emit('A1 = add(A2,signedshiftrightconst(sub(S0,A0),8));\n')
        elif t < 8:
            emit('A1 = add(shiftleftconst(A2,%d),signedshiftrightconst(sub(S0,A0),%d));\n' % (8-t,t))
        else:
            raise Exception('shift distances above 8 unimplemented')
    else:
        assert bytes == 2
        if t == 0:
            emit('A1 = add(shiftleftconst(S1,8),sub(S0,A0));\n')
        elif t == 8:
            emit('A1 = add(add(shiftleftconst(A2,8),S1),signedshiftrightconst(sub(S0,A0),8));\n')
        elif t < 8:
            emit('A1 = add(add(shiftleftconst(A2,%d),shiftleftconst(S1,%d)),signedshiftrightconst(sub(S0,A0),%d));\n' % (16-t,8-t,t))
        else:
            raise Exception('shift distances above 8 unimplemented')
    emit('A1 = mulloconst(A1,%d);\n' % uinv)
    parts.append('\n')

    # One compare-and-subtract does the job of a subconst followed by an
    # ifnegaddconst.
    emit('/* invalid inputs might need reduction mod %d */\n' % m1)
    emit('A1 = ifgesubconst(A1,%d);\n' % m1)
    parts.append('\n')

    final_output = todo == 'R0'
    if final_output and div3:
        for reg in ('A0','A1'):
            emit('%s = mulloconst(%s,3);\n' % (reg,reg))
    if final_output and offset != 0:
        for reg in ('A0','A1'):
            emit('%s = subconst(%s,%d);\n' % (reg,reg,offset))

    # Interleave A0/A1 so that the 32 results land in index order.
    for line in (
        '/* A0: r0r2r4r6r8r10r12r14 r16r18r20r22r24r26r28r30 */\n',
        '/* A1: r1r3r5r7r9r11r13r15 r17r19r21r23r25r27r29r31 */\n',
        'B0 = _mm256_unpacklo_epi16(A0,A1);\n',
        'B1 = _mm256_unpackhi_epi16(A0,A1);\n',
        '/* B0: r0r1r2r3r4r5r6r7 r16r17r18r19r20r21r22r23 */\n',
        '/* B1: r8r9r10r11r12r13r14r15 r24r25r26r27r28r29r30r31 */\n',
        'C0 = _mm256_permute2x128_si256(B0,B1,0x20);\n',
        'C1 = _mm256_permute2x128_si256(B0,B1,0x31);\n',
        '/* C0: r0r1r2r3r4r5r6r7 r8r9r10r11r12r13r14r15 */\n',
        '/* C1: r16r17r18r19r20r21r22r23 r24r25r26r27r28r29r30r31 */\n',
    ):
        emit(line)
    emit('_mm256_storeu_si256((__m256i *) (&%s[%s]),C0);\n' % (todo,outpos0))
    emit('_mm256_storeu_si256((__m256i *) (16+&%s[%s]),C1);\n' % (todo,outpos0))
    return ''.join(parts)
def stanzaloop(looplen,reading,todo,m0,bytes):
    """Return C code that processes entries 0...looplen-1 of `reading`.

    The code first moves the input pointer `s` back by bytes*looplen.  Then,
    depending on `looplen`:
      - below 1: nothing more is emitted;
      - a multiple of 16: a plain loop of inner16() blocks stepping by 16;
      - at least 16 otherwise: inner16() blocks starting at looplen-16 and
        then stepping i down to the next lower multiple of 16 until 0, so
        the first block overlaps the following one;
      - below 16: a scalar loop of inner() stanzas, one entry at a time.
    Every loop runs from the highest index down to 0.
    """
    pieces = [' s -= %d;\n' % (bytes*looplen)]
    if looplen >= 1:
        if looplen % 16 == 0:
            pieces.append(' for (i = %d;i >= 0;i -= 16) {\n' % (looplen-16))
            pieces.append(inner16(' ',reading,'i',m0,m0,bytes,'2*i','2*i+1'))
            pieces.append(' }\n')
        elif looplen >= 16:
            pieces.append(' i = %d;\n' % (looplen-16))
            pieces.append(' for (;;) {\n')
            pieces.append(inner16(' ',reading,'i',m0,m0,bytes,'2*i','2*i+1'))
            pieces.append(' if (!i) break;\n')
            pieces.append(' i = -16-((~15)&-i);\n')
            pieces.append(' }\n')
        else:
            pieces.append(' for (i = %d;i >= 0;--i) {\n' % (looplen-1))
            pieces.append(inner(' ',reading,'i',m0,m0,bytes,'2*i','2*i+1'))
            pieces.append(' }\n')
    return ''.join(pieces)
# ---------------------------------------------------------------------------
# Main driver.  Stanzas are built from the output layer (R0) upwards and
# printed in reverse at the end, so the generated C code runs from the
# smallest array down to R0.  Each pass halves Mlen (rounding up) and
# rebinds the module-level m0, m1, Mlen and todo that the helpers read.
# ---------------------------------------------------------------------------
stanzas = []
layer = 1
reading = 'R%d' % layer
while Mlen > 1:
    # A pair of mod-m0 entries combines into one value mod m0*m0; strip
    # whole bytes (bytes0 of them) until the modulus n0 is below top.
    n0 = m0*m0
    bytes0 = 0
    while n0 >= top:
        bytes0 += 1
        n0 = (n0+255)>>8
    if Mlen&1:
        # Odd length: the last entry has no partner and is copied straight
        # through, keeping its modulus m1.
        looplen = Mlen//2
        r0 = '%s[%s]' % (reading,looplen)
        stanza = ' %s\n' % poke(todo,2*looplen,r0)
        stanza += stanzaloop(looplen,reading,todo,m0,bytes0)
        n1 = m1
        stanzas += [stanza]
    else:
        # Even length: the last pair is (m0, m1); compute its combined
        # modulus n1 and byte count bytes1 the same way as above.
        n1 = m0*m1
        bytes1 = 0
        while n1 >= top:
            bytes1 += 1
            n1 = (n1+255)>>8
        if m1 == m0:
            # All pairs are alike, so one loop covers them.
            looplen = (Mlen+1)//2
            stanza = stanzaloop(looplen,reading,todo,m0,bytes0)
        else:
            # The last pair needs its own scalar stanza with modulus m1;
            # the remaining pairs go through the common loop.
            looplen = (Mlen-1)//2
            stanza = ' i = 0;\n'
            stanza += ' s -= %d;\n' % bytes1
            stanza += inner(' ',reading,looplen,m0,m1,bytes1,2*looplen,2*looplen+1)
            stanza += stanzaloop(looplen,reading,todo,m0,bytes0)
        stanzas += [stanza]
    # Descriptive C comment for this layer (printed before its code because
    # the list is reversed on output).
    if m0 == m1:
        stanzas += [' /* %s ------> %s: reconstruct mod %d*[%d] */\n' % (reading,todo,Mlen,m0)]
    else:
        stanzas += [' /* %s ------> %s: reconstruct mod %d*[%d]+[%d] */\n' % (reading,todo,Mlen-1,m0,m1)]
    # Move up one layer: what was read now becomes the array to fill.
    m0,m1,Mlen = n0,n1,(Mlen+1)//2
    layer += 1
    todo = reading
    reading = 'R%d' % layer

# Top layer: a single value modulo q (the remaining m1), read byte by byte
# from the end of the input string into a1.
stanza = ''
stanza += ' s += crypto_decode_STRBYTES;\n'
stanza += ' a1 = 0;\n'
q = m1
y,t,u,uinv,z = mulmoddata(256,m1)
# Bounds on a1, used for the C comments and to size the final corrections.
a1lower = 0
a1upper = 0
while m1 > 1:
    # The multiply-and-reduce step is skipped on the first pass, where m1
    # still equals q and a1 is 0.
    if m1 != q:
        stanza += ' a1 = mulhi(a1,%d)-mulhi(mullo(a1,%d),%d);\n' % (y,z,q)
        if y > 0:
            a1lower,a1upper = a1lower*y,a1upper*y
            a1lower,a1upper = a1lower-(q<<15),a1upper+(q<<15)-1
        else:
            # NOTE(review): the first line already swaps the ends for a
            # non-positive y; the second line then reads a1upper for the new
            # lower bound and a1lower for the new upper bound, swapping them
            # a second time.  inner()/inner16() do not do this.  Looks
            # unintended -- confirm before changing, since it alters the
            # bounds printed in the generated comments.
            a1lower,a1upper = a1upper*y,a1lower*y
            a1lower,a1upper = a1upper-(q<<15),a1lower+(q<<15)-1
        a1lower >>= 16
        a1upper >>= 16
    a1upper += 255
    stanza += ' a1 += *--s; /* %d...%d */\n' % (a1lower,a1upper)
    m1 = (m1+255)>>8
# Bring a1 into 0...q-1 according to the tracked bounds.
while a1upper >= q:
    a1lower,a1upper = a1lower-q,a1upper-q
    stanza += ' a1 -= %d; /* %d...%d */\n' % (q,a1lower,a1upper)
while a1lower < 0:
    a1lower,a1upper = min(0,a1lower+q),max(a1upper,q-1)
    stanza += ' a1 += %d&crypto_int16_negative_mask(a1); /* %d...%d */\n' % (q,a1lower,a1upper)
stanza += ' %s\n' % poke(todo,0,'a1')
stanzas += [stanza]

# Emit the stanzas last-built first, each preceded by a separator line,
# then close crypto_decode().
for stanza in reversed(stanzas):
    print(' ')
    sys.stdout.write(stanza)
print('}')