#!/usr/bin/env python3

# Parameter sets to generate.  Each row is
#   (p, q, secret-key bytes, public-key bytes, ciphertext bytes, w)
# where w is passed as the W macro to core/wforcesntrupP.
systems = (
  # p     q     sk    pk    ct    w
  (653,  4621, 1518,  994,  897, 288),
  (761,  4591, 1763, 1158, 1039, 286),
  (857,  5167, 1999, 1322, 1184, 322),
  (953,  6343, 2254, 1505, 1349, 396),
  (1013, 7177, 2417, 1623, 1455, 448),
  (1277, 7879, 3059, 2067, 1847, 492),
)

# Accumulated api #define text, keyed by subroutine family
# ('verify', 'decode', 'encode', 'core', 'kem').
apiall = {}

# Output root directory for the generated KEM trees.
topdir = 'crypto_kem'

import shutil
import os
import re
from math import floor,ceil
import subprocess

def readfile(fn):
  with open(fn) as f:
    return f.read()

def writefile(fn,x):
  with open(fn,'w') as f:
    f.write(x)

def writefiletime(fn,x,t):
  writefile(fn,x)
  os.utime(fn,(t,t))

build_subroutine_done = set() # output directories (op) already generated during this run

def build_subroutine(sub,macros):
  """Generate the output tree for the subroutine template src/<sub>.

  sub is a path below src/ such as 'verify/BYTES' or 'core/multsntrupP'.
  macros maps placeholder names to values.  Each macro name occurring in
  'crypto_'+sub is replaced by its value to form the output directory op
  (e.g. 'verify/BYTES' with BYTES=1039 becomes 'crypto_verify/1039').
  If op was already built during this run, nothing is done.

  Otherwise op is recreated from scratch and, for every implementation
  directory of src/<sub> except 'supercop':
    - api.h is written (CRYPTO_* #defines for verify/, decode/, encode/
      and core/ subroutines; empty otherwise);
    - each file is copied with every '{NAME}' replaced by macros['NAME'],
      skipping size-specific multiplication files outside their P range;
    - 'make' is run in the copy if it contains a Makefile.
  If substituting the macros into the source path names an existing
  directory (e.g. src/verify/1039), its implementations are copied as well,
  verbatim (no placeholder substitution, no size filtering, no make).

  Output modification times are set from the sources.  The api #defines,
  with CRYPTO_ replaced by op (slashes turned into underscores), are
  appended to apiall[family], where family is the first path component
  of sub.
  """
  src = 'src/%s'%sub
  srctime = os.stat(src).st_mtime

  family = sub.split('/')[0]

  op = 'crypto_%s'%sub
  for m in macros:
    op = op.replace(m,str(macros[m]))

  if op in build_subroutine_done: return
  print('building %s'%op)

  shutil.rmtree(op,ignore_errors=True)
  os.makedirs(op)

  api = ''
  if sub.startswith('verify/'):
    api = '#define CRYPTO_BYTES %s\n'%macros['BYTES']
  if sub.startswith('decode/') or sub.startswith('encode/'):
    api = '#define CRYPTO_STRBYTES %s\n'%macros['STRBYTES']
    api += '#define CRYPTO_ITEMS %s\n'%macros['ITEMS']
    api += '#define CRYPTO_ITEMBYTES %s\n'%macros['ITEMBYTES']
  if sub.startswith('core/'):
    api = '#define CRYPTO_OUTPUTBYTES %s\n'%macros['OUTPUTBYTES']
    api += '#define CRYPTO_INPUTBYTES %s\n'%macros['INPUTBYTES']
    api += '#define CRYPTO_KEYBYTES %s\n'%macros['KEYBYTES']
    api += '#define CRYPTO_CONSTBYTES %s\n'%macros['CONSTBYTES']

  if family not in apiall: apiall[family] = ''
  apiall[family] += api.replace('CRYPTO_',op.replace('/','_'))

  for impl in sorted(os.listdir(src)):
    if impl == 'supercop': continue
    srci = '%s/%s'%(src,impl)
    srcitime = os.stat(srci).st_mtime
    opi = '%s/%s'%(op,impl)

    os.makedirs(opi)
    writefiletime('%s/api.h'%opi,api,srctime)

    for fn in sorted(os.listdir(srci)):
      # Size-specific multiplication code is only copied for the P range it supports.
      if fn == 'mult768.c':
        if macros['P'] > 768: continue
      if fn in ('mult1024.c','precomp7681.inc','precomp10753.inc'):
        if macros['P'] < 768: continue
        if macros['P'] > 1024: continue
      if fn == 'mult1280.c':
        if macros['P'] < 1024: continue
        if macros['P'] > 1280: continue

      srcif = '%s/%s'%(srci,fn)
      opif = '%s/%s'%(opi,fn)
      x = readfile(srcif)
      # Placeholders are literal text, so plain string replacement is used.
      for m in macros:
        x = x.replace('{%s}'%m,'%s'%macros[m])
      writefiletime(opif,x,os.stat(srcif).st_mtime)
      if os.stat(srcif).st_mode&1: os.chmod(opif,0o755)  # keep source's executable bit

    if os.path.exists('%s/Makefile'%opi):
      subprocess.check_call('cd %s; make'%opi,shell=True)

    os.utime(opi,(srcitime,srcitime))

  # A source directory named after the concrete parameters (e.g. src/verify/1039)
  # contributes extra implementations, copied as-is.
  changedsrc = False
  for m in macros:
    newsrc = src.replace(m,'%s'%macros[m])
    if newsrc != src:
      changedsrc = True
      src = newsrc
  if changedsrc and os.path.exists(src):
    srctime = os.stat(src).st_mtime
    for impl in sorted(os.listdir(src)):
      if impl == 'supercop': continue
      srci = '%s/%s'%(src,impl)
      srcitime = os.stat(srci).st_mtime
      opi = '%s/%s'%(op,impl)
      # NOTE: raises FileExistsError if impl also exists in the generic template directory.
      os.makedirs(opi)
      writefiletime('%s/api.h'%opi,api,srctime)
      for fn in sorted(os.listdir(srci)):
        srcif = '%s/%s'%(srci,fn)
        opif = '%s/%s'%(opi,fn)
        x = readfile(srcif)
        writefiletime(opif,x,os.stat(srcif).st_mtime)
        if os.stat(srcif).st_mode&1: os.chmod(opif,0o755)
      os.utime(opi,(srcitime,srcitime))

  os.utime(op,(srctime,srctime))

  build_subroutine_done.add(op)

# ----- sizes of encodings

limit = 16384 # 2**14: merged moduli at or above this emit low bytes

def Encode(R,M):
  """Encode the integers R (R[i] is expected to lie in range(M[i])) as byte values.

  Adjacent pairs are merged into a single value modulo M[i]*M[i+1]; while a
  merged modulus is at least `limit`, its low byte is emitted and the modulus
  shrinks by a factor of 256 (rounded up).  The merged list (with any odd
  trailing element carried over) is processed the same way until one value
  remains, whose bytes are emitted while its modulus exceeds 1.
  Returns a list of ints in range(256); an empty M gives [].
  """
  out = []
  while M:
    if len(M) == 1:
      value,modulus = R[0],M[0]
      while modulus > 1:
        out.append(value%256)
        value //= 256
        modulus = (modulus+255)//256
      return out
    nextR,nextM = [],[]
    for i in range(0,len(M)-1,2):
      modulus = M[i]*M[i+1]
      value = R[i]+M[i]*R[i+1]
      while modulus >= limit:
        out.append(value%256)
        value //= 256
        modulus = (modulus+255)//256
      nextR.append(value)
      nextM.append(modulus)
    if len(M)%2 == 1:
      nextR.append(R[-1])
      nextM.append(M[-1])
    R,M = nextR,nextM
  return out

def encodebytes(M):
  """Return how many bytes Encode produces for the moduli M.

  The byte count depends only on M, so all-zero values are encoded.
  """
  zeros = [0]*len(M)
  return len(Encode(zeros,M))

# ----- main loop

for p,q,sk,pk,ct,w in systems:
  # ----- derived constants substituted into the templates

  q23 = (q+2)//3
  q14 = round(2.0**14/q)
  q15 = floor(2.0**15/q)
  q18 = round(2.0**18/q)
  q27 = round(2.0**27/q)
  q31 = floor(2.0**31/q)
  # inverse of q mod 2^16: the unit group mod 2^16 has exponent 2^14,
  # so q^(2^14-1) = q^-1; then take the signed 16-bit representative
  qinv = pow(q,2**14-1,2**16)
  if qinv >= 2**15: qinv -= 2**16

  # 10753*2^16 mod q, as a centered (signed) representative
  q10753 = (10753<<16)%q
  if q10753 >= q/2: q10753 -= q

  # explicit padded length for some p (substituted as PPADSORT); p itself otherwise
  ppadsort = {761:768,953:960,1013:1024,1277:1280}.get(p,p)
  assert ppadsort >= p

  # smallest value >= p that is 1 mod 16 (PPAD) and 1 mod 64 (PPAD64)
  ppad = p
  while ppad%16 != 1: ppad += 1

  ppad64 = p
  while ppad64%64 != 1: ppad64 += 1

  # ----- parameter checks

  assert p > 0
  assert p%4 == 1
  assert q > 0
  assert q%6 == 1
  assert w > 0
  assert 2*p >= 3*w
  assert q >= 16*w+1
  assert p < 1280
  assert q < 8192

  smallbytes = ceil(p/4.0)
  roundedbytes = encodebytes([q23]*p)

  # the table's byte sizes must match the encodings
  assert pk == encodebytes([q]*p)
  assert sk == 3*smallbytes+pk+32
  assert ct == 32+roundedbytes

  # ----- build relevant subroutines first

  build_subroutine('verify/BYTES',{'BYTES':ct})

  macros = {'STRBYTES':2,'ITEMS':1,'ITEMBYTES':2}
  build_subroutine('decode/int16',macros)
  build_subroutine('encode/int16',macros)

  macros = {'P':p,'STRBYTES':(p+3)//4,'ITEMS':p,'ITEMBYTES':1}
  macros['AVXLOOPS'] = ceil(p/128)
  macros['AVXOVERSHOOT'] = 32*ceil(p/128)-floor(p/4)
  build_subroutine('decode/Px3',macros)
  build_subroutine('encode/Px3',macros)

  macros = {'P':p,'STRBYTES':p*2,'ITEMS':p,'ITEMBYTES':2}
  build_subroutine('decode/Pxint16',macros)
  build_subroutine('encode/Pxint16',macros)

  macros = {'P':p,'STRBYTES':p*4,'ITEMS':p,'ITEMBYTES':4}
  build_subroutine('decode/Pxint32',macros)

  macros = {'P':p,'Q':q,'Q12':(q-1)//2,'STRBYTES':encodebytes([q]*p),'ITEMS':p,'ITEMBYTES':2}
  build_subroutine('decode/PxQ',macros)
  build_subroutine('encode/PxQ',macros)

  macros = {'P':p,'Q':q,'Q12':(q-1)//2,'R':q23,'STRBYTES':encodebytes([q23]*p),'ITEMS':p,'ITEMBYTES':2}
  build_subroutine('decode/PxR',macros)
  build_subroutine('encode/PxR',macros)
  build_subroutine('encode/PxRround',macros)

  macros = {'P':p,'STRBYTES':p,'ITEMS':p,'ITEMBYTES':2}
  build_subroutine('encode/Pxfreeze3',macros)

  macros = {'P':p,'PPAD64':ppad64,'OUTPUTBYTES':p+1,'INPUTBYTES':p,'KEYBYTES':0,'CONSTBYTES':0}
  build_subroutine('core/inv3sntrupP',macros)

  macros = {'P':p,'OUTPUTBYTES':p,'INPUTBYTES':p,'KEYBYTES':p,'CONSTBYTES':0}
  build_subroutine('core/mult3sntrupP',macros)

  macros = {'P':p,'Q':q,'QINV':qinv,
    'Q18':q18,'15Q':q15,'Q27':q27,'Q31':q31,'Q10753':q10753,
    'OUTPUTBYTES':2*p,'INPUTBYTES':2*p,'KEYBYTES':p,'CONSTBYTES':0}
  build_subroutine('core/multsntrupP',macros)

  macros = {'P':p,'PPAD':ppad,'Q':q,'QINV':qinv,
    'Q14':q14,'15Q':q15,'Q18':q18,'Q27':q27,'Q31':q31,
    'OUTPUTBYTES':2*p+1,'INPUTBYTES':p,'KEYBYTES':0,'CONSTBYTES':0}
  build_subroutine('core/invsntrupP',macros)

  macros = {'P':p,'Q':q,'OUTPUTBYTES':2*p,'INPUTBYTES':2*p,'KEYBYTES':0,'CONSTBYTES':0}
  build_subroutine('core/scale3sntrupP',macros)

  macros = {'P':p,'OUTPUTBYTES':2,'INPUTBYTES':p,'KEYBYTES':0,'CONSTBYTES':0}
  # comma-separated 0/1 list of length 32: p%32 ones followed by zeros
  macros['AVXENDINGMASK'] = ','.join(['1']*(p%32)+['0']*(32-(p%32)))
  build_subroutine('core/weightsntrupP',macros)

  macros = {'P':p,'W':w,'OUTPUTBYTES':p,'INPUTBYTES':p,'KEYBYTES':0,'CONSTBYTES':0}
  build_subroutine('core/wforcesntrupP',macros)

  # ----- and now build the kem

  system = f'sntrup{p}'
  src = 'src/kem/sntrupP'
  srctime = os.stat(src).st_mtime

  op = '%s/%s'%(topdir,system)
  print('building %s'%op)

  shutil.rmtree(op,ignore_errors=True)

  os.makedirs(op)

  # name used for the api #define prefix
  opreal = 'crypto_kem/%s'%system

  api  = '#define CRYPTO_SECRETKEYBYTES %s\n'%sk
  api += '#define CRYPTO_PUBLICKEYBYTES %s\n'%pk
  api += '#define CRYPTO_CIPHERTEXTBYTES %s\n'%ct
  api += '#define CRYPTO_BYTES 32\n'

  if 'kem' not in apiall: apiall['kem'] = ''
  apiall['kem'] += api.replace('CRYPTO_',opreal.replace('/','_'))

  # '{NAME}' placeholders filled in the per-system template files below
  kemmacros = {'P':p,'PPADSORT':ppadsort,'Q':q,'R':q23,
    'Q18':q18,'Q27':q27,'Q31':q31,'W':w,'CIPHERTEXTBYTES':ct}

  for impl in sorted(os.listdir(src)):
    srci = '%s/%s'%(src,impl)
    opi = '%s/%s'%(op,impl)

    os.makedirs(opi)

    target = '%s/api.h'%opi
    writefiletime(target,api,os.stat(srci).st_mtime)

    for fn in sorted(os.listdir(srci)):
      srcif = '%s/%s'%(srci,fn)
      opif = '%s/%s'%(opi,fn)

      if (impl,fn) in (('factored','params.h'),('avx','Makefile')):
        x = readfile(srcif)
        for m in kemmacros:
          x = x.replace('{%s}'%m,'%s'%kemmacros[m])
        writefiletime(opif,x,os.stat(srcif).st_mtime)
        continue

      if p != 761 and fn == 'CHANGES': continue # XXX: CHANGES is only shipped with sntrup761

      x = readfile(srcif)

      writefiletime(opif,x,os.stat(srcif).st_mtime)
      if os.stat(srcif).st_mode&1: os.chmod(opif,0o755)  # keep source's executable bit

  for impl in sorted(os.listdir(src)):
    opi = '%s/%s'%(op,impl)
    if os.path.exists('%s/Makefile'%opi):
      subprocess.check_call('cd %s; make'%opi,shell=True)

  # restore directory times after make has run
  os.utime(op,(srctime,srctime))

  for impl in sorted(os.listdir(src)):
    srci = '%s/%s'%(src,impl)
    srcitime = os.stat(srci).st_mtime
    opi = '%s/%s'%(op,impl)
    os.utime(opi,(srcitime,srcitime))

# Write every collected api #define group, one subroutine family after another,
# each group followed by a blank line.
with open('api','w') as apifile:
  for family in ('verify','decode','encode','core','kem'):
    apifile.write(apiall[family]+'\n')