#!/usr/bin/env python3

import random

def save(fn,x):
  '''Write string x to file fn, but only when the contents would change.

  An identical file is left untouched, so it is not rewritten. A missing
  file is treated as "different" and created; the original version raised
  FileNotFoundError while trying to read it.
  '''
  try:
    with open(fn) as f:
      if f.read() == x: return
  except FileNotFoundError:
    pass
  with open(fn,'w') as f:
    f.write(x)

# ----- Python versions of the subroutines

def littleendian(x,bytes):
  '''Return the lowest `bytes` bytes of integer x, least significant first.

  Negative x comes out as two's complement, because Python's >> is an
  arithmetic shift.'''
  out = []
  for shift in range(0,8*bytes,8):
    out.append((x >> shift) & 0xff)
  return out

def int16(x):
  '''Wrap integer x into the signed 16-bit range [-32768, 32767].'''
  return (x + 32768) % 65536 - 32768

def freeze(x,q):
  '''Reduce x mod q to the representative r with 2*r < q and 2*r >= -q.

  For odd q this is [-(q-1)/2, (q-1)/2]; for q = 256 it is [-128, 127].'''
  residue = x % q
  return residue - q if residue + residue >= q else residue

# Threshold used by Encode/Decode: while a combined modulus is at least this
# large, its low bytes are emitted (or consumed) before recursing.
limit = 1 << 14

def Encode(R,M):
  '''Mixed-radix encode the values R (with per-entry moduli M) into bytes.

  Adjacent entries are merged pairwise as r = R[i] + M[i]*R[i+1], with
  modulus m = M[i]*M[i+1]. While m >= limit, the low byte of r is emitted
  and m shrinks to ceil(m/256). The shrunken pairs, plus the unpaired last
  entry when len(M) is odd, are then encoded recursively. With a single
  entry, bytes are emitted until its modulus is at most 1.
  Presumably expects 0 <= R[i] < M[i]; this is not checked here.

  Returns a list of ints in [0,256).
  '''
  if len(M) == 0: return []
  S = []
  if len(M) == 1:
    r,m = R[0],M[0]
    # emit just enough bytes to cover the remaining modulus
    while m > 1:
      S += [r%256]
      r,m = r//256,(m+255)//256
    return S
  R2,M2 = [],[]
  for i in range(0,len(M)-1,2):
    m,r = M[i]*M[i+1],R[i]+M[i]*R[i+1]
    # peel off low bytes until the combined modulus drops below limit
    while m >= limit:
      S += [r%256]
      r,m = r//256,(m+255)//256
    R2 += [r]
    M2 += [m]
  if len(M)&1:
    # odd count: the last entry passes unchanged to the next level
    R2 += [R[-1]]
    M2 += [M[-1]]
  # this level's bytes come first, followed by the recursive encoding
  return S+Encode(R2,M2)

def Decode(S,M):
  '''Counterpart of Encode: recover values R, with 0 <= R[i] < M[i], from bytes S.

  Mirrors Encode's pairing. For each adjacent pair it consumes the same
  number of low bytes Encode would emit; that count depends only on the
  moduli. It then recursively decodes the shrunken moduli from the remaining
  bytes and splits each recombined value back into its two components.
  With one modulus, all of S is read as a little-endian integer and reduced
  mod M[0].
  '''
  if len(M) == 0: return []
  if len(M) == 1: return [sum(S[i]*256**i for i in range(len(S)))%M[0]]
  k = 0  # number of bytes of S consumed at this level
  bottom,M2 = [],[]
  for i in range(0,len(M)-1,2):
    m,r,t = M[i]*M[i+1],0,1
    # gather this pair's low bytes: r = partial value, t = 256**(bytes read)
    while m >= limit:
      r,t,k,m = r+S[k]*t,t*256,k+1,(m+255)//256
    bottom += [(r,t)]
    M2 += [m]
  if len(M)&1:
    M2 += [M[-1]]
  R2 = Decode(S[k:],M2)
  R = []
  for i in range(0,len(M)-1,2):
    r,t = bottom[i//2]
    # recombine the high part (from recursion) with the low bytes, then split the pair
    r += t*R2[i//2]
    R += [r%M[i]]
    R += [(r//M[i])%M[i+1]]
  if len(M)&1:
    R += [R2[-1]]
  return R

def decode_int16(s):
  '''Decode one unsigned 16-bit little-endian integer from 2 bytes.

  Returns {'x': [value], 's': the input bytes as a list}.'''
  s = list(s)
  assert len(s) == 2
  assert all(0 <= si < 256 for si in s)
  lo,hi = s
  return {'x':[lo + 256*hi],'s':s}

def decode_pxint16(p,s):
  '''Decode p unsigned 16-bit little-endian integers from 2*p bytes.

  Returns {'x': the p integers, 's': the input bytes as a list}.'''
  s = list(s)
  assert len(s) == 2*p
  assert all(0 <= si < 256 for si in s)
  x = [lo + 256*hi for lo,hi in zip(s[0::2],s[1::2])]
  return {'x':x,'s':s}

def decode_pxint32(p,s):
  '''Decode p unsigned 32-bit little-endian integers from 4*p bytes.

  Returns {'x': the p integers, 's': the input bytes as a list}.'''
  s = list(s)
  assert len(s) == 4*p
  assert all(0 <= si < 256 for si in s)
  x = []
  for start in range(0,4*p,4):
    word = s[start:start+4]
    x.append(sum(byte*256**k for k,byte in enumerate(word)))
  return {'x':x,'s':s}

def decode_px3(p,s):
  '''Unpack p coefficients stored 2 bits each, four per byte, lowest bits first.

  Each 2-bit field value f is returned as f-1, giving a value in {-1,0,1,2}.
  s must hold exactly ceil(p/4) bytes.
  Returns {'x': the p coefficients, 's': the input bytes as a list}.
  '''
  s = list(s)
  # integer division: (p+3)/4 is a float that only equals len(s) when p%4 == 1
  assert len(s) == (p+3)//4
  assert all(si >= 0 for si in s)
  assert all(si < 256 for si in s)
  x = [((s[i//4]>>(2*(i%4)))&3)-1 for i in range(p)]
  return {'x':x,'s':s}

def decode_pxr(p,r,s):
  '''Decode p rounded coefficients (each with modulus r) from bytes s.

  Each decoded digit d in [0,r) becomes 3*d - (3*(r-1))//2.
  Returns {'x': the p coefficients, 's': the input bytes as a list}.'''
  s = list(s)
  assert all(0 <= si < 256 for si in s)
  offset = 3*(r-1)//2
  digits = Decode(s,p*[r])
  return {'x':[3*d - offset for d in digits],'s':s}

def decode_pxq(p,q,s):
  '''Decode p coefficients (each with modulus q) from bytes s.

  Each decoded digit d in [0,q) is shifted down by (q-1)//2.
  Returns {'x': the p coefficients, 's': the input bytes as a list}.'''
  s = list(s)
  assert all(0 <= si < 256 for si in s)
  shift = (q-1)//2
  digits = Decode(s,p*[q])
  return {'x':[d - shift for d in digits],'s':s}

def encode_int16(x):
  '''Encode one unsigned 16-bit integer as 2 little-endian bytes.

  Returns {'s': the 2 bytes, 'x': the input as a list}.'''
  x = list(x)
  assert len(x) == 1
  assert all(0 <= xi < 65536 for xi in x)
  value = x[0]
  return {'s':[value & 255,value >> 8],'x':x}

def encode_pxint16(p,x):
  '''Encode p unsigned 16-bit integers as 2*p little-endian bytes.

  Returns {'s': the bytes, 'x': the input as a list}.'''
  x = list(x)
  assert len(x) == p
  assert all(0 <= xi < 65536 for xi in x)
  s = [byte for xi in x for byte in (xi & 255,xi >> 8)]
  return {'s':s,'x':x}

def encode_px3(p,x):
  '''Pack p byte values four per byte, 2 bits each, lowest bits first.

  Each x[i] contributes x[i]+1 to its 2-bit field. The sum is taken modulo
  256, so byte 255 yields field 0 (as -1 would). The final byte is padded
  with -1 entries (field 0).
  Returns {'s': ceil(p/4) bytes, 'x': the input as a list}.
  '''
  x = list(x)
  assert len(x) == p
  assert all(xi >= 0 for xi in x)
  assert all(xi < 256 for xi in x)
  xpad = list(x)
  while len(xpad)%4: xpad += [-1]
  s = [(xpad[i]+1)+4*(xpad[i+1]+1)+16*(xpad[i+2]+1)+64*(xpad[i+3]+1) for i in range(0,len(xpad),4)]
  s = [si%256 for si in s]
  # integer division: (p+3)/4 is a float that only equals len(s) when p%4 == 1
  assert len(s) == (p+3)//4
  return {'s':s,'x':x}

def encode_pxfreeze3(p,x):
  '''Map p int16 values to bytes holding x - 3*((10923*x+16384)>>15), mod 256.

  10923/2**15 is about 1/3, so this subtracts a nearby multiple of 3.
  Returns {'s': the p bytes, 'x': the input as a list}.'''
  x = list(x)
  assert len(x) == p
  assert all(-32768 <= xi < 32768 for xi in x)
  s = []
  for xi in x:
    nearby_multiple = 3*((10923*xi + 16384) >> 15)
    s.append((xi - nearby_multiple) % 256)
  return {'s':s,'x':x}

def encode_pxr(p,r,x):
  '''Encode p int16 values as rounded digits with modulus r.

  Each x becomes ((((x + (3*(r-1))//2) & 16383) * 10923) >> 15), which is
  roughly (x+offset)/3 on 14 bits. The digits are then passed to Encode.
  Returns {'s': the bytes, 'x': the input as a list}.'''
  x = list(x)
  assert len(x) == p
  assert all(-32768 <= xi < 32768 for xi in x)
  offset = 3*(r-1)//2
  digits = [((((xi + offset) & 16383) * 10923) >> 15) for xi in x]
  return {'s':Encode(digits,p*[r]),'x':x}

def encode_pxrround(p,r,x):
  '''Like encode_pxr, but first moves each x to a nearby multiple of 3.

  The rounding is 3*((10923*x+16384)>>15), where 10923/2**15 is about 1/3.
  Returns {'s': the bytes, 'x': the original input as a list}.'''
  x = list(x)
  assert len(x) == p
  assert all(-32768 <= xi < 32768 for xi in x)
  offset = 3*(r-1)//2
  digits = []
  for xi in x:
    rounded = 3*((10923*xi + 16384) >> 15)
    digits.append((((rounded + offset) & 16383) * 10923) >> 15)
  return {'s':Encode(digits,p*[r]),'x':x}

def encode_pxq(p,q,x):
  '''Encode p int16 values as digits with modulus q.

  Each x is shifted up by (q-1)//2, masked to 14 bits, and passed to Encode.
  Returns {'s': the bytes, 'x': the input as a list}.'''
  x = list(x)
  assert len(x) == p
  assert all(-32768 <= xi < 32768 for xi in x)
  shift = (q-1)//2
  digits = [(xi + shift) & 16383 for xi in x]
  return {'s':Encode(digits,p*[q]),'x':x}

def mult(p,f,g):
  '''Multiply f and g (p coefficients each, lowest degree first) in Z[x]/(x^p-x-1).

  Coefficients are not reduced modulo anything; callers apply their own %.'''
  f = list(f)
  assert len(f) == p
  g = list(g)
  assert len(g) == p
  product = [0]*(2*p-1)
  for i,fi in enumerate(f):
    for j,gj in enumerate(g):
      product[i+j] += fi*gj
  # fold high terms down using x^k = x^(k-p+1) + x^(k-p), highest degree first
  for k in range(2*p-2,p-1,-1):
    product[k-p] += product[k]
    product[k-p+1] += product[k]
    product[k] = 0
  return product[:p]

def recip(q,R0,R1):
  '''Return the inverse of R1 modulo R0 with coefficients mod q, or None.

  R0 and R1 are coefficient lists, lowest degree first, with
  len(R1) == d > 0 and len(R0) == d+1. Both are reversed (highest degree
  first) and run through 2d-1 iterations of a divstep-style loop: a counter
  delta, a conditional swap of (f,g) and (v,r) when delta > 0 and g[0] != 0,
  then elimination of g's leading coefficient and a shift.
  Returns None if delta ends nonzero (treated as "no inverse"). Otherwise it
  returns a list of d coefficients in [0,q), lowest degree first.
  The callers in this file verify the result with mult() when invcheck is set.
  pow(f[0],q-2,q) is used as the inverse of f[0], which is only valid when
  q is prime.
  '''
  R0 = list(R0)
  R1 = list(R1)
  d = len(R1)
  assert d > 0
  assert len(R0) == d+1
  # work highest-degree-first; g is padded to the same length as f
  f = [fi%q for fi in reversed(R0)]
  g = [gi%q for gi in reversed(R1)]+[0]
  v = [0]
  r = [1]
  n = 2*d-1
  delta = 1

  while n > 0:
    v = [0]+v
    r = r+[0]
    assert f[0] != 0
    f = f[:n]
    # conditional swap of the two sequences and their cofactors
    if delta > 0 and g[0] != 0: delta,f,g,v,r = -delta,g,f,r,v
    f0,g0 = f[0],g[0]
    delta += 1
    # cancel g's leading coefficient against f's, then drop it (shift by one)
    g = [(f0*gi-g0*fi)%q for fi,gi in zip(f,g)]
    g = g[1:]+[0]
    v = v[:d]
    r = r[:d]
    r = [(f0*ri-g0*vi)%q for ri,vi in zip(r,v)]
    n -= 1
    g = g[:n]

  # nonzero delta is treated as "not invertible": return None
  if delta != 0: return

  # scale by f[0]^-1 (Fermat inverse, q assumed prime) and go back to low-first order
  f0inv = pow(f[0],q-2,q)
  v = v[:d]
  v = [(f0inv*vi)%q for vi in v]
  v.reverse()
  return v

# When True, core_inv3sntrup/core_invsntrup multiply the computed inverse
# back by the input (via mult) and assert that the product is 1.
invcheck = True

def core_inv3sntrup(p,f):
  '''Reference for core inv3sntrup: invert f in (Z/3)[x]/(x^p-x-1).

  f: p input bytes; byte b is read as the ternary value (0,1,0,-1)[b&3].
  Returns {'h':..., 'n':f, 'k':[], 'c':[]}. On success, 'h' holds the p
  inverse coefficients as freeze(.,3) % 256 (so -1 becomes 255) followed by
  a 0 byte. When recip reports no inverse, 'h' is p zero bytes followed
  by 255.
  '''
  f = list(f)
  assert len(f) == p
  assert all(0 <= fi < 256 for fi in f)
  F = [(0,1,0,-1)[fi & 3] for fi in f]
  modulus = [-1,-1] + [0]*(p-2) + [1]
  v = recip(3,modulus,F)
  if v is None:
    h = [0]*p + [255]
  else:
    if invcheck:
      product = [c % 3 for c in mult(p,v,F)]
      assert product == [1] + [0]*(p-1)
    h = [freeze(vi,3) % 256 for vi in v] + [0]
  return {'h':h,'n':f,'k':[],'c':[]}

def core_invsntrup(p,q,f):
  '''Reference for core invsntrup: invert 3*f in (Z/q)[x]/(x^p-x-1).

  f: p input bytes, each read as a signed byte (freeze(.,256)) and scaled by 3.
  Returns {'h':..., 'n':f, 'k':[], 'c':[]}. On success, 'h' holds each
  inverse coefficient as freeze(.,q) in 2 little-endian bytes, followed by
  a 0 byte. When recip reports no inverse, 'h' is 2*p zero bytes followed
  by 255.
  '''
  f = list(f)
  assert len(f) == p
  assert all(0 <= fi < 256 for fi in f)
  F = [3*freeze(fi,256) for fi in f]
  modulus = [-1,-1] + [0]*(p-2) + [1]
  v = recip(q,modulus,F)
  if v is None:
    result = [0]*(2*p) + [255]
  else:
    if invcheck:
      product = [c % q for c in mult(p,v,F)]
      assert product == [1] + [0]*(p-1)
    result = [byte for vi in v for byte in littleendian(freeze(vi,q),2)] + [0]
  assert len(result) == 2*p+1
  return {'h':result,'n':f,'k':[],'c':[]}

def core_mult3sntrup(p,f,g):
  '''Reference for core mult3sntrup: multiply two ternary polynomials mod 3.

  f, g: p bytes each, with byte b read as (0,1,0,-1)[b&3]. Returns 'h' as
  the product coefficients freeze(.,3) % 256. 'n' and 'k' are the converted
  ternary lists (values -1/0/1), not the raw input bytes.
  '''
  def to_ternary(seq):
    # validate a length-p byte list and map each byte to -1/0/1
    seq = list(seq)
    assert len(seq) == p
    assert all(0 <= b < 256 for b in seq)
    return [(0,1,0,-1)[b & 3] for b in seq]
  f = to_ternary(f)
  g = to_ternary(g)
  h = [freeze(c,3) % 256 for c in mult(p,f,g)]
  return {'h':h,'n':f,'k':g,'c':[]}

def core_multsntrup(p,q,f,g):
  '''Reference for core multsntrup: multiply int16 polynomial f by ternary g mod q.

  f: p int16 values; g: p bytes, each read as (0,1,0,-1)[b&3].
  Returns 'h' = product coefficients freeze(.,q), 'n' = f, each packed as
  2 little-endian bytes per value, and 'k' = the converted ternary g.
  '''
  f = list(f)
  assert len(f) == p
  assert all(-32768 <= fi < 32768 for fi in f)
  g = list(g)
  assert len(g) == p
  assert all(0 <= gi < 256 for gi in g)
  g = [(0,1,0,-1)[gi & 3] for gi in g]
  product = [freeze(c,q) for c in mult(p,f,g)]
  def pack16(values):
    # two little-endian bytes per value
    return [byte for v in values for byte in littleendian(v,2)]
  return {'h':pack16(product),'n':pack16(f),'k':g,'c':[]}

def core_scale3sntrup(p,q,f):
  '''Reference for core scale3sntrup: multiply each int16 coefficient by 3 mod q.

  Every step uses int16 wraparound: 3*x, then subtract (q+1)//2, then add q
  up to twice while the value is negative, then subtract (q-1)//2.
  Returns 'h' (results) and 'n' (inputs), each as 2 little-endian bytes per
  value.
  '''
  f = list(f)
  assert len(f) == p
  assert all(-32768 <= fi < 32768 for fi in f)

  half_up = (q+1)//2
  half_down = (q-1)//2
  h = []
  for coeff in f:
    t = int16(int16(3*coeff) - half_up)
    for _ in range(2):
      if t < 0: t = int16(t+q)
    h.append(int16(t - half_down))

  encodef = [byte for fi in f for byte in littleendian(fi,2)]
  encodeh = [byte for hi in h for byte in littleendian(hi,2)]
  return {'h':encodeh,'n':encodef,'k':[],'c':[]}

def core_weightsntrup(p,x):
  '''Reference for core weightsntrup: count the inputs with the low bit set.

  Returns 'h' = that count as 2 little-endian bytes, and 'n' = the input.'''
  x = list(x)
  assert len(x) == p
  assert all(0 <= xi < 256 for xi in x)
  weight = sum(1 for xi in x if xi & 1)
  assert 0 <= weight <= p
  return {'h':littleendian(weight,2),'n':x,'k':[],'c':[]}

def core_wforcesntrup(p,w,x):
  '''Reference for core wforcesntrup: keep x if exactly w entries are odd.

  Otherwise the output is w ones followed by p-w zeros.
  Returns 'h' = the chosen list and 'n' = the input.'''
  x = list(x)
  assert len(x) == p
  assert all(0 <= xi < 256 for xi in x)
  weight = sum(1 for xi in x if xi & 1)
  assert 0 <= weight <= p
  if weight == w:
    z = x
  else:
    z = [1]*w + [0]*(p-w)
  return {'h':z,'n':x,'k':[],'c':[]}

# ----- precomputed test vectors

# Reference test vectors filled in by precompute():
# precomputed[op,name] -> list of dicts (one per test input) returned by the
# Python reference functions above.
precomputed = {}
# precomputedxtype[op,name] -> C type name string such as 'int16_t', set only
# for some entries (presumably the element type of the 'x' data -- confirm).
precomputedxtype = {}

def precompute():
  '''Fill precomputed/precomputedxtype with reference test vectors.

  Each entry precomputed[op,name] is a list of dicts returned by the Python
  reference functions above, one per test input. Before each entry's inputs
  are generated, `random` is reseeded with a fixed string, so the vectors
  are the same on every run. For some entries, precomputedxtype[op,name] is
  set to a C type name string.
  '''
  global precomputed  # not required: the dicts are only mutated, never rebound

  systems = ( # p,q,sk,pk,ct,w,ninv
    (653,4621,1518,994,897,288,[2,2,0,1]),
    (761,4591,1763,1158,1039,286,[2,2,0,2,2,2,1,2,1,2,0,2,1,1,1,1,2,0,1,1]),
    (857,5167,1999,1322,1184,322,[1,0,0,1,2,0,0,0,1,1,1,2,0,2,1,2,2,0,0,1,1,0,1,2,0,1]),
    (953,6343,2254,1505,1349,396,[1,0,2,1,0,1]),
    (1013,7177,2417,1623,1455,448,[2,1,0,0,1]),
    (1277,7879,3059,2067,1847,492,[2,2,0,1]),
  )
  # NOTE(review): sk/pk/ct look like key/ciphertext byte sizes. Only pk and
  # ct are used below (as decode input lengths); sk is unused here -- confirm.

  random.seed('decode_int16')
  inputs = [[random.randrange(256) for j in range(2)] for loop in range(16)]
  precomputed['decode','int16'] = [decode_int16(s) for s in inputs]
  precomputedxtype['decode','int16'] = 'uint16_t'

  random.seed('encode_int16')
  inputs = [[random.randrange(65536)] for loop in range(16)]
  precomputed['encode','int16'] = [encode_int16(x) for x in inputs]
  precomputedxtype['encode','int16'] = 'uint16_t'

  numtests = 4

  for p,q,sk,pk,ct,w,ninv in systems:
    # modulus used by the rounded ('x{r}') codecs; requires q == 3r-2
    r = (q+2)//3
    assert q == 3*r-2

    random.seed(f'decode_{p}xint16')
    inputs = [[random.randrange(256) for j in range(2*p)] for loop in range(numtests)]
    precomputed['decode',f'{p}xint16'] = [decode_pxint16(p,s) for s in inputs]
    precomputedxtype['decode',f'{p}xint16'] = 'uint16_t'

    random.seed(f'decode_{p}xint32')
    inputs = [[random.randrange(256) for j in range(4*p)] for loop in range(numtests)]
    precomputed['decode',f'{p}xint32'] = [decode_pxint32(p,s) for s in inputs]
    precomputedxtype['decode',f'{p}xint32'] = 'uint32_t'

    random.seed(f'decode_{p}x3')
    inputs = [[random.randrange(256) for j in range((p+3)//4)] for loop in range(numtests)]
    precomputed['decode',f'{p}x3'] = [decode_px3(p,s) for s in inputs]

    # input length ct-32 (presumably ciphertext size minus a 32-byte hash -- confirm)
    random.seed(f'decode_{p}x{r}')
    inputs = [[random.randrange(256) for j in range(ct-32)] for loop in range(numtests)]
    precomputed['decode',f'{p}x{r}'] = [decode_pxr(p,r,s) for s in inputs]
    precomputedxtype['decode',f'{p}x{r}'] = 'int16_t'

    random.seed(f'decode_{p}x{q}')
    inputs = [[random.randrange(256) for j in range(pk)] for loop in range(numtests)]
    precomputed['decode',f'{p}x{q}'] = [decode_pxq(p,q,s) for s in inputs]
    precomputedxtype['decode',f'{p}x{q}'] = 'int16_t'

    random.seed(f'encode_{p}xint16')
    inputs = [[random.randrange(65536) for j in range(p)] for loop in range(numtests)]
    precomputed['encode',f'{p}xint16'] = [encode_pxint16(p,x) for x in inputs]
    precomputedxtype['encode',f'{p}xint16'] = 'uint16_t'

    random.seed(f'encode_{p}x3')
    inputs = [[random.randrange(256) for j in range(p)] for loop in range(numtests)]
    precomputed['encode',f'{p}x3'] = [encode_px3(p,x) for x in inputs]

    random.seed(f'encode_{p}xfreeze3')
    inputs = [[random.randrange(-32768,32768) for j in range(p)] for loop in range(numtests)]
    precomputed['encode',f'{p}xfreeze3'] = [encode_pxfreeze3(p,x) for x in inputs]
    precomputedxtype['encode',f'{p}xfreeze3'] = 'int16_t'

    random.seed(f'encode_{p}x{r}')
    inputs = [[random.randrange(-32768,32768) for j in range(p)] for loop in range(numtests)]
    precomputed['encode',f'{p}x{r}'] = [encode_pxr(p,r,x) for x in inputs]
    precomputedxtype['encode',f'{p}x{r}'] = 'int16_t'

    random.seed(f'encode_{p}x{r}round')
    inputs = [[random.randrange(-32768,32768) for j in range(p)] for loop in range(numtests)]
    precomputed['encode',f'{p}x{r}round'] = [encode_pxrround(p,r,x) for x in inputs]
    precomputedxtype['encode',f'{p}x{r}round'] = 'int16_t'

    random.seed(f'encode_{p}x{q}')
    inputs = [[random.randrange(-32768,32768) for j in range(p)] for loop in range(numtests)]
    precomputed['encode',f'{p}x{q}'] = [encode_pxq(p,q,x) for x in inputs]
    precomputedxtype['encode',f'{p}x{q}'] = 'int16_t'

    random.seed(f'core_inv3sntrup{p}')
    inputs = [[random.randrange(256) for j in range(p)] for loop in range(numtests)]
    # extra input built from ninv (coefficients 0,1,2 -> bytes 0,1,255, zero-padded);
    # presumably a known non-invertible case -- confirm
    inputs += [[freeze(ninvi,3)%256 for ninvi in ninv]+[0]*(p-len(ninv))]
    precomputed['core',f'inv3sntrup{p}'] = [core_inv3sntrup(p,x) for x in inputs]

    random.seed(f'core_invsntrup{p}')
    inputs = [[random.randrange(256) for j in range(p)] for loop in range(numtests)]
    precomputed['core',f'invsntrup{p}'] = [core_invsntrup(p,q,x) for x in inputs]

    random.seed(f'core_mult3sntrup{p}')
    inputs = [([random.randrange(256) for j in range(p)],[random.randrange(256) for j in range(p)]) for loop in range(numtests)]
    precomputed['core',f'mult3sntrup{p}'] = [core_mult3sntrup(p,x,y) for (x,y) in inputs]

    random.seed(f'core_multsntrup{p}')
    inputs = [([random.randrange(-32768,32768) for j in range(p)],[random.randrange(256) for j in range(p)]) for loop in range(numtests)]
    precomputed['core',f'multsntrup{p}'] = [core_multsntrup(p,q,x,y) for (x,y) in inputs]

    random.seed(f'core_scale3sntrup{p}')
    inputs = [[random.randrange(-32768,32768) for j in range(p)] for loop in range(numtests)]
    precomputed['core',f'scale3sntrup{p}'] = [core_scale3sntrup(p,q,x) for x in inputs]

    random.seed(f'core_weightsntrup{p}')
    inputs = [[random.randrange(256) for j in range(p)] for loop in range(numtests)]
    # edge cases: all-zero and all-255 inputs
    inputs += [[0]*p]
    inputs += [[255]*p]
    # vectors with exactly w-1, w, w+1 nonzero entries, each entry 1 or 255 (both odd)
    for weight in w-1,w,w+1:
      v = [0]*p
      while sum(map(bool,v)) < weight: v[random.randrange(p)] = random.randrange(1,256,254)
      inputs += [v]
    # weightsntrup and wforcesntrup share the same inputs
    precomputed['core',f'weightsntrup{p}'] = [core_weightsntrup(p,x) for x in inputs]
    precomputed['core',f'wforcesntrup{p}'] = [core_wforcesntrup(p,w,x) for x in inputs]

# Build the reference vectors at import time (presumably consumed by the
# test-generation code later in this file).
precompute()

# ----- generating test program

# H accumulates the text of the generated ntruprime_test.h header, one string
# per piece. This first piece holds the name-prefixing #defines and the extern
# declarations of the shared helpers. The closing #endif is not in this piece
# (presumably appended later in the file).
H = ['''\
#ifndef ntruprime_test_h
#define ntruprime_test_h

#define aligned ntruprime_test_aligned
#define callocplus ntruprime_test_callocplus
#define checksum ntruprime_test_checksum
#define checksum_clear ntruprime_test_checksum_clear
#define checksum_expected ntruprime_test_checksum_expected
#define double_canary ntruprime_test_double_canary
#define endianness ntruprime_test_endianness
#define forked ntruprime_test_forked
#define input_compare ntruprime_test_input_compare
#define input_prepare ntruprime_test_input_prepare
#define myrandom ntruprime_test_myrandom
#define ok ntruprime_test_ok
#define output_compare ntruprime_test_output_compare
#define output_prepare ntruprime_test_output_prepare
#define public ntruprime_test_public
#define secret ntruprime_test_secret
#define targeti ntruprime_test_targeti
#define targetn ntruprime_test_targetn
#define targetoffset ntruprime_test_targetoffset
#define targeto ntruprime_test_targeto
#define targetp ntruprime_test_targetp
#define valgrind ntruprime_test_valgrind

extern const char *targeto;
extern const char *targetp;
extern const char *targeti;
extern const char *targetn;
extern const char *targetoffset;
extern int ok;
extern int valgrind;

extern unsigned long long myrandom(void);
extern void forked(void (*)(long long),long long);
extern void *aligned(void *,long long);
extern void *callocplus(long long);
extern void secret(void *,long long);
extern void public(void *,long long);
extern void double_canary(unsigned char *,unsigned char *,unsigned long long);
extern void input_prepare(unsigned char *,unsigned char *,unsigned long long);
extern void output_prepare(unsigned char *,unsigned char *,unsigned long long);
extern void input_compare(const unsigned char *,const unsigned char *,unsigned long long,const char *);
extern void output_compare(const unsigned char *,const unsigned char *,unsigned long long,const char *);
extern void checksum_expected(const char *);
extern void checksum(const unsigned char *,unsigned long long);
extern void checksum_clear(void);
extern void endianness(unsigned char *,unsigned long long,unsigned long long);

''']

# Z accumulates the C source of the generated test program, one string per
# piece. This first piece holds the shared helpers: valgrind hooks, the
# salsa20-based test-vector generator and canaries, knownrandombytes (chacha20),
# checksums, callocplus/aligned, forked and endianness. Later code appends
# more pieces. The fatal fork/wait messages in forked() end with \n so that
# stderr output is line-terminated.
Z = [r'''/* WARNING: auto-generated (by autogen/test); do not edit */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <time.h>
#include <assert.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <fcntl.h>
#include <sys/resource.h>
#include "crypto_uint8.h"
#include "crypto_uint32.h"
#include "crypto_uint64.h"
#include "crypto_declassify.h"
#include <ntruprime.h> /* -lntruprime */
#include <randombytes.h>
#include "ntruprime_test.h"

const char *targeto = 0;
const char *targetp = 0;
const char *targeti = 0;
const char *targetn = 0;
const char *targetoffset = 0;

int ok = 1;

#define fail ((ok = 0),printf)

/* ----- valgrind support */

int valgrind = 0;
static unsigned char valgrind_undefined_byte = 0;
static char *volatile valgrind_pointer = 0;

static char *valgrind_malloc_1(void)
{
  char *x = malloc(1);
  if (!x) abort();
  *(char **volatile) &valgrind_pointer = x;
  return valgrind_pointer;
}

static void valgrind_init(void)
{
  char *e = getenv("valgrind_multiplier");
  char *x;
  if (!e) return;
  x = valgrind_malloc_1();
  valgrind_undefined_byte = x[0]+1;
  valgrind_undefined_byte *= atoi(e);
  valgrind_undefined_byte ^= x[0]+1;
  free(x);
  valgrind = 1;
}

void secret(void *xvoid,long long xlen)
{
  unsigned char *x = xvoid;
  while (xlen > 0) {
    *x ^= valgrind_undefined_byte;
    ++x;
    --xlen;
  }
}

void public(void *x,long long xlen)
{
  crypto_declassify(x,xlen);
}

/* ----- rng and hash, from supercop/try-anything.c */

typedef crypto_uint8 u8;
typedef crypto_uint32 u32;
typedef crypto_uint64 u64;

#define FOR(i,n) for (i = 0;i < n;++i)

static u32 L32(u32 x,int c) { return (x << c) | ((x&0xffffffff) >> (32 - c)); }

static u32 ld32(const u8 *x)
{
  u32 u = x[3];
  u = (u<<8)|x[2];
  u = (u<<8)|x[1];
  return (u<<8)|x[0];
}

static void st32(u8 *x,u32 u)
{
  int i;
  FOR(i,4) { x[i] = u; u >>= 8; }
}

static const u8 sigma[17] = "expand 32-byte k";

static void core_salsa(u8 *out,const u8 *in,const u8 *k)
{
  u32 w[16],x[16],y[16],t[4];
  int i,j,m;

  FOR(i,4) {
    x[5*i] = ld32(sigma+4*i);
    x[1+i] = ld32(k+4*i);
    x[6+i] = ld32(in+4*i);
    x[11+i] = ld32(k+16+4*i);
  }

  FOR(i,16) y[i] = x[i];

  FOR(i,20) {
    FOR(j,4) {
      FOR(m,4) t[m] = x[(5*j+4*m)%16];
      t[1] ^= L32(t[0]+t[3], 7);
      t[2] ^= L32(t[1]+t[0], 9);
      t[3] ^= L32(t[2]+t[1],13);
      t[0] ^= L32(t[3]+t[2],18);
      FOR(m,4) w[4*j+(j+m)%4] = t[m];
    }
    FOR(m,16) x[m] = w[m];
  }

  FOR(i,16) st32(out + 4 * i,x[i] + y[i]);
}

static void salsa20(u8 *c,u64 b,const u8 *n,const u8 *k)
{
  u8 z[16],x[64];
  u32 u,i;
  if (!b) return;
  FOR(i,16) z[i] = 0;
  FOR(i,8) z[i] = n[i];
  while (b >= 64) {
    core_salsa(x,z,k);
    FOR(i,64) c[i] = x[i];
    u = 1;
    for (i = 8;i < 16;++i) {
      u += (u32) z[i];
      z[i] = u;
      u >>= 8;
    }
    b -= 64;
    c += 64;
  }
  if (b) {
    core_salsa(x,z,k);
    FOR(i,b) c[i] = x[i];
  }
}

static void increment(u8 *n)
{
  if (!++n[0])
    if (!++n[1])
      if (!++n[2])
        if (!++n[3])
          if (!++n[4])
            if (!++n[5])
              if (!++n[6])
                if (!++n[7])
                  ;
}

static unsigned char testvector_n[8];

static void testvector_clear(void)
{
  memset(testvector_n,0,sizeof testvector_n);
}

static void testvector(unsigned char *x,unsigned long long xlen)
{
  const static unsigned char testvector_k[33] = "generate inputs for test vectors";
  salsa20(x,xlen,testvector_n,testvector_k);
  increment(testvector_n);
}

unsigned long long myrandom(void)
{
  unsigned char x[8];
  unsigned long long result;
  testvector(x,8);
  result = x[7];
  result = (result<<8)|x[6];
  result = (result<<8)|x[5];
  result = (result<<8)|x[4];
  result = (result<<8)|x[3];
  result = (result<<8)|x[2];
  result = (result<<8)|x[1];
  result = (result<<8)|x[0];
  return result;
}

static unsigned char canary_n[8];

static void canary(unsigned char *x,unsigned long long xlen)
{
  const static unsigned char canary_k[33] = "generate pad to catch overwrites";
  salsa20(x,xlen,canary_n,canary_k);
  increment(canary_n);
}

void double_canary(unsigned char *x2,unsigned char *x,unsigned long long xlen)
{
  if (valgrind) return;
  canary(x - 16,16);
  canary(x + xlen,16);
  memcpy(x2 - 16,x - 16,16);
  memcpy(x2 + xlen,x + xlen,16);
}

void input_prepare(unsigned char *x2,unsigned char *x,unsigned long long xlen)
{
  testvector(x,xlen);
  if (valgrind) {
    memcpy(x2,x,xlen);
    return;
  }
  canary(x - 16,16);
  canary(x + xlen,16);
  memcpy(x2 - 16,x - 16,xlen + 32);
}

void input_compare(const unsigned char *x2,const unsigned char *x,unsigned long long xlen,const char *fun)
{
  if (valgrind) return;
  if (memcmp(x2 - 16,x - 16,xlen + 32)) {
    fail("failure: %s overwrites input\n",fun);
  }
}

void output_prepare(unsigned char *x2,unsigned char *x,unsigned long long xlen)
{
  if (valgrind) {
    memcpy(x2,x,xlen);
    return;
  }
  canary(x - 16,xlen + 32);
  memcpy(x2 - 16,x - 16,xlen + 32);
}

void output_compare(const unsigned char *x2,const unsigned char *x,unsigned long long xlen,const char *fun)
{
  if (valgrind) return;
  if (memcmp(x2 - 16,x - 16,16)) {
    fail("failure: %s writes before output\n",fun);
  }
  if (memcmp(x2 + xlen,x + xlen,16)) {
    fail("failure: %s writes after output\n",fun);
  }
}

/* ----- knownrandombytes */

static const int knownrandombytes_is_only_for_testing_not_for_cryptographic_use = 1;
#define knownrandombytes randombytes

#define QUARTERROUND(a,b,c,d) \
  a += b; d = L32(d^a,16); \
  c += d; b = L32(b^c,12); \
  a += b; d = L32(d^a, 8); \
  c += d; b = L32(b^c, 7);

static void core_chacha(u8 *out,const u8 *in,const u8 *k)
{
  u32 x[16],y[16];
  int i,j;
  FOR(i,4) {
    x[i] = ld32(sigma+4*i);
    x[12+i] = ld32(in+4*i);
  }
  FOR(i,8) x[4+i] = ld32(k+4*i);
  FOR(i,16) y[i] = x[i];
  FOR(i,10) {
    FOR(j,4) { QUARTERROUND(x[j],x[j+4],x[j+8],x[j+12]) }
    FOR(j,4) { QUARTERROUND(x[j],x[((j+1)&3)+4],x[((j+2)&3)+8],x[((j+3)&3)+12]) }
  }
  FOR(i,16) st32(out+4*i,x[i]+y[i]);
}

static void chacha20(u8 *c,u64 b,const u8 *n,const u8 *k)
{
  u8 z[16],x[64];
  u32 u,i;
  if (!b) return;
  FOR(i,16) z[i] = 0;
  FOR(i,8) z[i+8] = n[i];
  while (b >= 64) {
    core_chacha(x,z,k);
    FOR(i,64) c[i] = x[i];
    u = 1;
    FOR(i,8) {
      u += (u32) z[i];
      z[i] = u;
      u >>= 8;
    }
    b -= 64;
    c += 64;
  }
  if (b) {
    core_chacha(x,z,k);
    FOR(i,b) c[i] = x[i];
  }
}

#define crypto_rng_OUTPUTBYTES 736

static int crypto_rng(
        unsigned char *r, /* random output */
        unsigned char *n, /* new key */
  const unsigned char *g  /* old key */
)
{
  static const unsigned char nonce[8] = {0};
  unsigned char x[32+crypto_rng_OUTPUTBYTES];
  chacha20(x,sizeof x,nonce,g);
  memcpy(n,x,32);
  memcpy(r,x+32,crypto_rng_OUTPUTBYTES);
  return 0;
}

static unsigned char knownrandombytes_g[32];
static unsigned char knownrandombytes_r[crypto_rng_OUTPUTBYTES];
static unsigned long long knownrandombytes_pos = crypto_rng_OUTPUTBYTES;

static void knownrandombytes_clear(void)
{
  memset(knownrandombytes_g,0,sizeof knownrandombytes_g);
  memset(knownrandombytes_r,0,sizeof knownrandombytes_r);
  knownrandombytes_pos = crypto_rng_OUTPUTBYTES;
}

void knownrandombytes_main(void *xvoid,long long xlen)
{
  unsigned char *x = xvoid;
  assert(knownrandombytes_is_only_for_testing_not_for_cryptographic_use);

  while (xlen > 0) {
    if (knownrandombytes_pos == crypto_rng_OUTPUTBYTES) {
      crypto_rng(knownrandombytes_r,knownrandombytes_g,knownrandombytes_g);
      knownrandombytes_pos = 0;
    }
    *x++ = knownrandombytes_r[knownrandombytes_pos];
    xlen -= 1;
    knownrandombytes_r[knownrandombytes_pos++] = 0;
  }
}

void knownrandombytes(void *xvoid,long long xlen)
{
  knownrandombytes_main(xvoid,xlen);
  secret(xvoid,xlen);
}

/* ----- checksums */

static unsigned char checksum_state[64];
static char checksum_hex[65];

void checksum_expected(const char *expected)
{
  long long i;
  for (i = 0;i < 32;++i) {
    checksum_hex[2 * i] = "0123456789abcdef"[15 & (checksum_state[i] >> 4)];
    checksum_hex[2 * i + 1] = "0123456789abcdef"[15 & checksum_state[i]];
  }
  checksum_hex[2 * i] = 0;

  if (strcmp(checksum_hex,expected))
    fail("failure: checksum mismatch: %s expected %s\n",checksum_hex,expected);
}

void checksum_clear(void)
{
  memset(checksum_state,0,sizeof checksum_state);
  knownrandombytes_clear();
  testvector_clear();
  /* not necessary to clear canary */
}

void checksum(const unsigned char *x,unsigned long long xlen)
{
  u8 block[16];
  int i;
  while (xlen >= 16) {
    core_salsa(checksum_state,x,checksum_state);
    x += 16;
    xlen -= 16;
  }
  FOR(i,16) block[i] = 0;
  FOR(i,xlen) block[i] = x[i];
  block[xlen] = 1;
  checksum_state[0] ^= 1;
  core_salsa(checksum_state,block,checksum_state);
}

#include "limits.inc"

void *callocplus(long long len)
{
  if (valgrind) {
    unsigned char *x = malloc(len);
    if (!x) abort();
    return x;
  } else {
    unsigned char *x = calloc(1,len + 256);
    long long i;
    if (!x) abort();
    for (i = 0;i < len + 256;++i) x[i] = random();
    return x;
  }
}

void *aligned(void *x,long long len)
{
  if (valgrind)
    return x;
  else {
    long long i;
    unsigned char *y = x;
    y += 64;
    y += 63 & (-(unsigned long) y);
    for (i = 0;i < len;++i) y[i] = 0;
    return y;
  }
}

/* ----- catching SIGILL, SIGBUS, SIGSEGV, etc. */

void forked(void (*test)(long long),long long impl)
{
  if (valgrind) {
    test(impl);
    return;
  }
  fflush(stdout);
  pid_t child = fork();
  int childstatus = -1;
  if (child == -1) {
    fprintf(stderr,"fatal: fork failed: %s\n",strerror(errno));
    exit(111);
  }
  if (child == 0) {
    ok = 1;
    limits();
    test(impl);
    if (!ok) exit(100);
    exit(0);
  }
  if (waitpid(child,&childstatus,0) != child) {
    fprintf(stderr,"fatal: wait failed: %s\n",strerror(errno));
    exit(111);
  }
  if (childstatus)
    fail("failure: process failed, status %d\n",childstatus);
  fflush(stdout);
}

/* ----- endianness */

/* on big-endian machines, flip into little-endian */
/* other types of endianness are not supported */
void endianness(unsigned char *e,unsigned long long words,unsigned long long bytesperword)
{
  long long i = 1;

  if (1 == *(unsigned char *) &i) return;

  while (words > 0) {
    for (i = 0;2 * i < bytesperword;++i) {
      long long j = bytesperword - 1 - i;
      unsigned char ei = e[i];
      e[i] = e[j];
      e[j] = ei;
    }
    e += bytesperword;
    words -= 1;
  }
}
''']

# ==========

# Tables filled from the 'api' file by the parsing loop that follows.
checksums = {}      # (o,p) -> (line[1],line[2]) from the listing line; presumably expected checksums
operations = []     # operation names in first-seen order
primitives = {}     # o -> list of primitive names for that operation
sizes = {}          # (o,p) -> concatenated '#define crypto_o_p_NAME ...' lines
exports = {}        # o -> list of exported names without the 'crypto_' prefix
prototypes = {}     # o -> list of (return type, name split on '_', argument string)
nooverlap = set()   # (o,p) pairs whose listing line carries the 'nooverlap' option

# Parse the 'api' file into the tables above. Three kinds of line matter:
#   crypto_<o>/<p> [c1 c2] [options...]  -> operation/primitive listing
#   #define crypto_<o>_<p>_<NAME> ...     -> collected into sizes[o,p]
#   <rettype> crypto_<o>...(<args>);      -> exports/prototypes for <o>
# Other lines are ignored.
with open('api') as f:
  for line in f:
    line = line.strip()
    if line.startswith('crypto_'):
      line = line.split()
      x = line[0].split('/')
      assert len(x) == 2
      o = x[0].split('_')[1]
      if o not in operations: operations += [o]
      p = x[1]
      if o not in primitives: primitives[o] = []
      primitives[o] += [p]
      if len(line) >= 3:
        checksums[o,p] = line[1],line[2]
      for option in line[3:]:
        if option == 'nooverlap':
          nooverlap.add((o,p))
      continue
    if line.startswith('#define '):
      x = line.split(' ')
      # macro name must split into exactly: crypto, operation, primitive, NAME
      x = x[1].split('_')
      assert len(x) == 4
      assert x[0] == 'crypto'
      o = x[1]
      p = x[2]
      if (o,p) not in sizes: sizes[o,p] = ''
      sizes[o,p] += line+'\n'
      continue
    if line.endswith(');'):
      # prototype: exactly "<rettype> <name>(<args>);" with a single '('
      fun,args = line[:-2].split('(')
      rettype,fun = fun.split()
      fun = fun.split('_')
      o = fun[1]
      assert fun[0] == 'crypto'
      if o not in exports: exports[o] = []
      exports[o] += ['_'.join(fun[1:])]
      if o not in prototypes: prototypes[o] = []
      prototypes[o] += [(rettype,fun,args)]

# ========== verify

# C section header for the verify tests, which are appended right after this.
Z += [r'''
/* ----- verify, derived from supercop/crypto_verify/try.c */
''']

# One C test per verify primitive. Every occurrence of the token BYTES in the
# template, including inside identifiers, is replaced with the primitive name p.
# The test compares crypto_verify's result against memcmp on random inputs,
# equal inputs and randomly perturbed inputs. It runs for every
# implementation (impl -1 = the selected one) and for buffer offsets 0 and 1
# (offset 1 is skipped under valgrind).
for p in primitives['verify']:
  Z += [r'''
static int (*crypto_verify_BYTES)(const unsigned char *,const unsigned char *);

static unsigned char *test_verify_BYTES_x;
static unsigned char *test_verify_BYTES_y;

static void test_verify_BYTES_check(void)
{
  unsigned char *x = test_verify_BYTES_x;
  unsigned char *y = test_verify_BYTES_y;
  int r;

  secret(x,BYTES);
  secret(y,BYTES);
  r = crypto_verify_BYTES(x,y);
  public(x,BYTES);
  public(y,BYTES);
  public(&r,sizeof r);

  if (r == 0) {
    if (memcmp(x,y,BYTES))
      fail("failure: different strings pass verify\n");
  } else if (r == -1) {
    if (!memcmp(x,y,BYTES))
      fail("failure: equal strings fail verify\n");
  } else {
    fail("failure: weird return value\n");
  }
}

void test_verify_BYTES_impl(long long impl)
{
  unsigned char *x = test_verify_BYTES_x;
  unsigned char *y = test_verify_BYTES_y;

  if (targeti && strcmp(targeti,".") && strcmp(targeti,ntruprime_dispatch_verify_BYTES_implementation(impl))) return;
  if (targetn && atol(targetn) != impl) return;
  if (impl >= 0) {
    crypto_verify_BYTES = ntruprime_dispatch_verify_BYTES(impl);
    printf("verify_BYTES %lld implementation %s compiler %s\n",impl,ntruprime_dispatch_verify_BYTES_implementation(impl),ntruprime_dispatch_verify_BYTES_compiler(impl));
  } else {
    crypto_verify_BYTES = ntruprime_verify_BYTES;
    printf("verify_BYTES selected implementation %s compiler %s\n",ntruprime_verify_BYTES_implementation(),ntruprime_verify_BYTES_compiler());
  }

  randombytes(x,BYTES);
  randombytes(y,BYTES);
  test_verify_BYTES_check();
  memcpy(y,x,BYTES);
  test_verify_BYTES_check();
  y[myrandom() % BYTES] = myrandom();
  test_verify_BYTES_check();
  y[myrandom() % BYTES] = myrandom();
  test_verify_BYTES_check();
  y[myrandom() % BYTES] = myrandom();
  test_verify_BYTES_check();
}

static void test_verify_BYTES(void)
{
  if (targeto && strcmp(targeto,"verify")) return;
  if (targetp && strcmp(targetp,"BYTES")) return;

  test_verify_BYTES_x = callocplus(BYTES);
  test_verify_BYTES_y = callocplus(BYTES);

  for (long long offset = 0;offset < 2;++offset) {
    if (targetoffset && atol(targetoffset) != offset) continue;
    if (offset && valgrind) break;
    printf("verify_BYTES offset %lld\n",offset);
    for (long long impl = -1;impl < ntruprime_numimpl_verify_BYTES();++impl)
      forked(test_verify_BYTES_impl,impl);
    ++test_verify_BYTES_x;
    ++test_verify_BYTES_y;
  }
}
'''.replace('BYTES',p)]

# ==========

# Table driving the per-(operation, primitive) C test generator below.
# Each entry is a 4-tuple (o, vars, howmuch, tests):
#   o       : operation name; used as a key into primitives[], prototypes[],
#             checksums[], sizes[], precomputed[].
#   vars    : one (v, initsize, allocsize) triple per test buffer v.
#             initsize is the C expression used to initialise vlen (for 'x'
#             it initialises the item count xwords instead); None means the
#             length is assigned at run time (e.g. mlen/xwords drawn from
#             myrandom() when a 'maxtest' parameter exists).  allocsize is
#             the C expression for the buffer allocation (for sort it is
#             further multiplied by ntruprime_sort_<p>_BYTES).
#   howmuch : (name, small, big) loop parameters; the generated code uses
#             the small value on the first checksum pass and the big value
#             on the second (checksumbig ? big : small).
#   tests   : one (suffix, outputs, inouts, inputs) tuple per function
#             crypto_<o><suffix> to exercise; the function is called with
#             its arguments in the order outputs + inouts + inputs.
#             Single-letter names are buffers; names such as 'mlen' or
#             'xwords' are passed as length/count arguments.
todo = (
  ('hashblocks',(
    ('h','crypto_hashblocks_STATEBYTES','crypto_hashblocks_STATEBYTES'),
    ('m',None,'4096'),
  ),(
    ('loops','4096','32768'),
    ('maxtest','128','4096'),
  ),(
    ('',(),('h',),('m','mlen')),
  )),
  ('hash',(
    ('h','crypto_hash_BYTES','crypto_hash_BYTES'),
    ('m',None,'4096+crypto_hash_BYTES'),
  ),(
    ('loops','64','512'),
    ('maxtest','128','4096'),
  ),(
    ('',('h',),(),('m','mlen')),
  )),
  ('decode',(
    ('x','crypto_decode_ITEMS','crypto_decode_ITEMS*crypto_decode_ITEMBYTES'),
    ('s','crypto_decode_STRBYTES','crypto_decode_STRBYTES'),
  ),(
    ('loops','1024','4096'),
  ),(
    ('',('x',),(),('s',)),
  )),
  ('encode',(
    ('s','crypto_encode_STRBYTES','crypto_encode_STRBYTES'),
    ('x','crypto_encode_ITEMS','crypto_encode_ITEMS*crypto_encode_ITEMBYTES'),
  ),(
    ('loops','1024','4096'),
  ),(
    ('',('s',),(),('x',)),
  )),
  ('sort',(
    ('x',None,'4096'),
  ),(
    ('loops','1024','4096'),
    ('maxtest','128','4096'),
  ),(
    ('',(),('x',),('xwords',)),
  )),
  ('core',(
    ('h','crypto_core_OUTPUTBYTES','crypto_core_OUTPUTBYTES'),
    ('n','crypto_core_INPUTBYTES','crypto_core_INPUTBYTES'),
    ('k','crypto_core_KEYBYTES','crypto_core_KEYBYTES'),
    ('c','crypto_core_CONSTBYTES','crypto_core_CONSTBYTES'),
  ),(
    ('loops','512','4096'),
  ),(
    ('',('h',),(),('n','k','c')),
  )),
  ('kem',(
    ('p','crypto_kem_PUBLICKEYBYTES','crypto_kem_PUBLICKEYBYTES'),
    ('s','crypto_kem_SECRETKEYBYTES','crypto_kem_SECRETKEYBYTES'),
    ('k','crypto_kem_BYTES','crypto_kem_BYTES'),
    ('c','crypto_kem_CIPHERTEXTBYTES','crypto_kem_CIPHERTEXTBYTES'),
    ('t','crypto_kem_BYTES','crypto_kem_BYTES'),
  ),(
    ('loops','8','64'),
  ),(
    ('_keypair',('p','s'),(),()),
    ('_enc',('c','k'),(),('p',)),
    ('_dec',('t',),(),('c','s')),
  )),
)

# Generate one C source file, command/ntruprime-test_<o>_<p>.c, for every
# operation o in the todo table and every primitive p in primitives[o].
# Each file contains:
#   - a static test_<o>_<p>_impl(impl) that, for one implementation, runs
#     randomized tests in two checksummed passes and then (if available)
#     checks test vectors precomputed by the Python code;
#   - a non-static driver test_<o>_<p>() that allocates the buffers and calls
#     test_<o>_<p>_impl via forked() for every implementation and offset.
for t in todo:
  o,vars,howmuch,tests = t

  for p in primitives[o]:
    X = []

    # ----- file prologue: includes; fail() clears ok and then prints
    X += [f'/* ----- {o}/{p}, derived from supercop/crypto_{o}/try.c */\n']
    X += ['\n']
    X += ['#include <stdio.h>\n']
    X += ['#include <stdlib.h>\n']
    X += ['#include <string.h>\n']
    X += ['#include <stdint.h>\n']
    X += ['#include <ntruprime.h>\n']
    X += ['#include "ntruprime_test.h"\n']
    X += ['\n']
    X += ['#define fail ((ok = 0),printf)\n']

    # expected checksums for the small (index 0) and big (index 1) passes
    X += ['static const char *%s_%s_checksums[] = {\n' % (o,p)]
    X += ['  "%s",\n' % checksums[o,p][0]]
    X += ['  "%s",\n' % checksums[o,p][1]]
    X += ['} ;\n']
    X += ['\n']

    # one function pointer per prototype of operation o; test_<o>_<p>_impl
    # points them at the implementation under test
    for rettype,fun,args in prototypes[o]:
      X += ['static %s (*%s)(%s);\n' % (rettype,'_'.join(fun),args)]

    # alias the generic size names crypto_<o>_NAME (derived from the second
    # token of each sizes[o,p] line by dropping the _<p> part) onto the
    # ntruprime_-prefixed names; they are #undef'd at the end of the file
    if (o,p) in sizes:
      for line in sizes[o,p].splitlines():
        psize = line.split()[1]
        size1 = psize.replace('crypto_%s_%s_'%(o,p),'crypto_%s_'%o)
        size2 = psize.replace('crypto_','ntruprime_')
        X += ['#define %s %s\n' % (size1,size2)]
      X += ['\n']

    # per-buffer globals: storage_* holds the allocation (freed by the
    # driver), test_* the pointer actually used (advanced per offset);
    # the *2 buffers are the comparison copies
    for v,initsize,allocsize in vars:
      X += ['static void *storage_%s_%s_%s;\n' % (o,p,v)]
      X += ['static unsigned char *test_%s_%s_%s;\n' % (o,p,v)]
    for v,initsize,allocsize in vars:
      X += ['static void *storage_%s_%s_%s2;\n' % (o,p,v)]
      X += ['static unsigned char *test_%s_%s_%s2;\n' % (o,p,v)]
    X += ['\n']

    # test vectors computed by the Python code; for x the element type is
    # taken from precomputedxtype and the inner dimension is the item count
    # (allocsize with its *crypto_<o>_ITEMBYTES factor removed)
    if (o,p) in precomputed:
      xtype = precomputedxtype.get((o,p),'unsigned char')
      X += ['#define precomputed_%s_%s_NUM %d\n' % (o,p,len(precomputed[o,p]))]
      X += ['\n']
      for v,initsize,allocsize in vars:
        vtype = 'unsigned char'
        if v == 'x':
          vtype = xtype
          allocsize = allocsize.split('*')
          assert allocsize[1] == f'crypto_{o}_ITEMBYTES'
          allocsize = allocsize[0]
        X += ['static const %s precomputed_%s_%s_%s[precomputed_%s_%s_NUM][%s] = {\n' % (vtype,o,p,v,o,p,allocsize)]
        for precomp in precomputed[o,p]:
          X += ['  {%s},\n' % ','.join(str(c) for c in precomp[v])]
        X += ['} ;\n']
        X += ['\n']

    # ----- per-implementation test function
    X += ['static void test_%s_%s_impl(long long impl)\n' % (o,p)]
    X += ['{\n']
    for v,initsize,allocsize in vars:
      X += ['  unsigned char *%s = test_%s_%s_%s;\n' % (v,o,p,v)]
    for v,initsize,allocsize in vars:
      X += ['  unsigned char *%s2 = test_%s_%s_%s2;\n' % (v,o,p,v)]
    # length variables: fixed when initsize is given (for x, initsize is the
    # word count xwords and xlen is computed later), otherwise assigned in the
    # test loop
    for v,initsize,allocsize in vars:
      if initsize is None:
        X += ['  long long %slen;\n' % v]
        if v == 'x':
          X += ['  long long xwords;\n']
      else:
        if v == 'x':
          X += ['  long long xwords = %s;\n' % (initsize)]
          X += ['  long long xlen;\n']
        else:
          X += ['  long long %slen = %s;\n' % (v,initsize)]
    X += ['\n']

    # honour the implementation-name and implementation-number filters
    X += ['  if (targeti && strcmp(targeti,".") && strcmp(targeti,ntruprime_dispatch_%s_%s_implementation(impl))) return;\n' % (o,p)]
    X += ['  if (targetn && atol(targetn) != impl) return;\n'] # XXX: atoll is slightly unportable

    # impl >= 0 selects a dispatched implementation; a negative impl (the
    # driver starts at -1) selects the library's default choice
    X += ['  if (impl >= 0) {\n']
    for rettype,fun,args in prototypes[o]:
      f2 = ['ntruprime','dispatch',o,p]+fun[2:]
      X += ['    %s = %s(impl);\n' % ('_'.join(fun),'_'.join(f2))]
    X += ['    printf("%s_%s %%lld implementation %%s compiler %%s\\n",impl,ntruprime_dispatch_%s_%s_implementation(impl),ntruprime_dispatch_%s_%s_compiler(impl));\n' % (o,p,o,p,o,p)]
    X += ['  } else {\n']
    for rettype,fun,args in prototypes[o]:
      f2 = ['ntruprime',o,p]+fun[2:]
      X += ['    %s = %s;\n' % ('_'.join(fun),'_'.join(f2))]
    X += ['    printf("%s_%s selected implementation %%s compiler %%s\\n",ntruprime_%s_%s_implementation(),ntruprime_%s_%s_compiler());\n' % (o,p,o,p,o,p)]
    X += ['  }\n']

    # ----- randomized tests: pass 0 uses the small howmuch values, pass 1 the
    # big ones; each pass ends with checksum_expected() on its checksum
    X += ['  for (long long checksumbig = 0;checksumbig < 2;++checksumbig) {\n']

    maxtestdefined = False
    for v,small,big in howmuch:
      X += ['    long long %s = checksumbig ? %s : %s;\n' % (v,big,small)]
      if v == 'maxtest': maxtestdefined = True
    X += ['\n']
    X += ['    checksum_clear();\n']
    X += ['\n']
    X += ['    for (long long loop = 0;loop < loops;++loop) {\n']

    # declare `result` only if some tested function returns a value
    wantresult = False
    for f,output,inout,input in tests:
      cof = 'crypto_'+o+f
      for rettype,fun,args in prototypes[o]:
        if cof == '_'.join(fun):
          if rettype != 'void':
            wantresult = True
    if wantresult:
      X += ['      int result;\n']

    itembytes = f'crypto_{o}_BYTES' if o == 'sort' else f'crypto_{o}_ITEMBYTES'

    # variable-length arguments get a fresh random length in [0,maxtest]
    if maxtestdefined and any('mlen' in input for f,output,inout,input in tests):
      X += ['      mlen = myrandom() % (maxtest + 1);\n']
    if maxtestdefined and any('hlen' in input for f,output,inout,input in tests):
      X += ['      hlen = myrandom() % (maxtest + 1);\n']
    if maxtestdefined and any('xwords' in input for f,output,inout,input in tests):
      X += ['      xwords = myrandom() % (maxtest + 1);\n']
    if any('x' in output+inout+input for f,output,inout,input in tests):
      X += [f'      xlen = xwords*{itembytes};\n']
    X += ['\n']

    # buffers that already hold data from an earlier test in this loop
    # iteration (e.g. kem keys); those are reused rather than re-randomized
    initialized = set()
    for f,output,inout,input in tests:
      cof = 'crypto_'+o+f

      cofrettype = None
      for rettype,fun,args in prototypes[o]:
        if cof == '_'.join(fun):
          cofrettype = rettype

      # value the function must return (hashblocks returns the leftover
      # message length mlen % BLOCKBYTES; everything else returns 0)
      expected = '0'
      unexpected = 'nonzero'
      if cof == 'crypto_hashblocks':
        expected = 'mlen % crypto_hashblocks_BLOCKBYTES'
        unexpected = 'unexpected value'

      for v in output:
        if len(v) == 1:
          X += ['      output_prepare(%s2,%s,%slen);\n' % (v,v,v)]
          # v now has CDE where C is canary, D is canary, E is canary
          # v2 now has same CDE
          # D is at start of v with specified length
          # C is 16 bytes before beginning
          # E is 16 bytes past end
      for v in input+inout:
        if len(v) == 1:
          if v in initialized:
            X += ['      memcpy(%s2,%s,%slen);\n' % (v,v,v)]
            X += ['      double_canary(%s2,%s,%slen);\n' % (v,v,v)]
          else:
            X += ['      input_prepare(%s2,%s,%slen);\n' % (v,v,v)]
            # v now has CTE where C is canary, T is test data, E is canary
            # v2 has same CTE
            initialized.add(v)

      # NOTE(review): this check and the other `if 'x' in v:` checks below
      # read `v` as left over from the preceding for-loop (its last element,
      # or an earlier loop's value when that loop was empty), and `in` is a
      # substring test here ('x' in 'xwords' is true).  With the current todo
      # table, endianness() is emitted before the call for encode and sort,
      # and after the call for decode and sort, but not after the call for
      # encode, whose input x then goes to input_compare() against x2.
      # Confirm that this asymmetry is intended.
      if 'x' in v:
        X += [f'      endianness(x,xwords,{itembytes});\n']

      for v in input+inout:
        if len(v) == 1:
          X += ['      secret(%s,%slen);\n' % (v,v)]

      args = ','.join(output+inout+input)
      if cofrettype == 'void':
        X += ['      %s(%s);\n' % (cof,args)]
      else:
        X += ['      result = %s(%s);\n' % (cof,args)]
        X += ['      public(&result,sizeof result);\n']
        X += ['      if (result != %s) fail("failure: %s returns %s\\n");\n' % (expected,cof,unexpected)]

      for v in input+inout+output:
        if len(v) == 1:
          X += ['      public(%s,%slen);\n' % (v,v)]

      if 'x' in v:
        X += [f'      endianness(x,xwords,{itembytes});\n']

      # decapsulation must reproduce the session key from encapsulation
      if cof == 'crypto_kem_dec':
        X += ['      if (memcmp(t,k,klen) != 0) fail("failure: %s does not match k\\n");\n' % cof]

      for v in output+inout:
        if len(v) == 1:
          X += ['      checksum(%s,%slen);\n' % (v,v)]
          # output v,v2 now has COE,CDE where O is output; checksum O
          initialized.add(v)
      for v in output+inout:
        if len(v) == 1:
          if cof == 'crypto_sign_open' and v == 't':
            X += ['      output_compare(%s2,%s,%slen,"%s");\n' % (v,v,'c',cof)]
          else:
            X += ['      output_compare(%s2,%s,%slen,"%s");\n' % (v,v,v,cof)]
            # output_compare checks COE,CDE for equal C, equal E
      for v in input:
        if len(v) == 1:
          X += ['      input_compare(%s2,%s,%slen,"%s");\n' % (v,v,v,cof)]
          # input_compare checks CTE,CTE for equal C, equal T, equal E

      # functions without inputs (keypair) and kem_enc are not expected to be
      # deterministic, so they skip the repeat-and-compare test
      deterministic = True
      if inout+input == (): deterministic = False
      if cof == 'crypto_kem_enc': deterministic = False

      if deterministic:
        X += ['\n']
        for v in output+inout+input:
          if len(v) == 1:
            X += ['      double_canary(%s2,%s,%slen);\n' % (v,v,v)]
            # old output v,v2: COE,CDE; new v,v2: FOG,FDG where F,G are new canaries
            # old inout v,v2: COE,CTE; new v,v2: FOG,FTG
            # old input v,v2: CTE,CTE; new v,v2: FTG,FTG

        if 'x' in v:
          X += [f'      endianness(x2,xwords,{itembytes});\n']

        for v in input+inout:
          if len(v) == 1:
            X += ['      secret(%s2,%slen);\n' % (v,v)]

        # rerun on the *2 buffers; length/count arguments are passed unchanged
        args = ','.join([v if v.endswith('words') or v.endswith('len') else v+'2' for v in output+inout+input])
        if cofrettype == 'void':
          X += ['      %s(%s);\n' % (cof,args)]
        else:
          X += ['      result = %s(%s);\n' % (cof,args)]
          X += ['      public(&result,sizeof result);\n']
          X += ['      if (result != %s) fail("failure: %s returns %s\\n");\n' % (expected,cof,unexpected)]

        for v in input+inout+output:
          if len(v) == 1:
            X += ['      public(%s2,%slen);\n' % (v,v)]

        if 'x' in v:
          X += [f'      endianness(x2,xwords,{itembytes});\n']

        for w in output + inout:
          if len(w) == 1:
            # w,w2: COE,COE; goal now is to compare O
            X += ['      if (memcmp(%s2,%s,%slen) != 0) fail("failure: %s is nondeterministic\\n");\n' % (w,w,w,cof)]

      overlap = deterministic
      if inout != (): overlap = False
      if (o,p) in nooverlap: overlap = False

      # XXX: overlap test assumes that inputs are at least as big as outputs

      if overlap:
        for y in output:
          if len(y) == 1:
            X += ['\n']
            for v in output:
              if len(v) == 1:
                X += ['      double_canary(%s2,%s,%slen);\n' % (v,v,v)]
            for v in input:
              if len(v) == 1:
                X += ['      double_canary(%s2,%s,%slen);\n' % (v,v,v)]
            for x in input:
              if len(x) == 1:
                # try writing to x2 instead of y, while reading x2
                args = ','.join([x+'2' if v==y else v for v in output] + [x+'2' if v==x else v for v in input])

                for v in input+inout:
                  v2 = x+'2' if v==x else v
                  if len(v) == 1:
                    X += ['      secret(%s,%slen);\n' % (v2,v)]

                if cofrettype == 'void':
                  X += ['      %s(%s);\n' % (cof,args)]
                else:
                  X += ['      result = %s(%s);\n' % (cof,args)]
                  X += ['      public(&result,sizeof result);\n']
                  X += ['      if (result != %s) fail("failure: %s with %s=%s overlap returns %s\\n");\n' % (expected,cof,x,y,unexpected)]

                for v in output:
                  v2 = x+'2' if v==y else v
                  if len(v) == 1:
                    X += ['      public(%s,%slen);\n' % (v2,v)]
                for v in input:
                  if v == x: continue
                  if len(v) == 1:
                    X += ['      public(%s,%slen);\n' % (v,v)]

                # the overlapped output (now in x2) must equal the normal output y;
                # then restore x2 from x for the next overlap trial
                X += ['      if (memcmp(%s2,%s,%slen) != 0) fail("failure: %s does not handle %s=%s overlap\\n");\n' % (x,y,y,cof,x,y)]
                X += ['      memcpy(%s2,%s,%slen);\n' % (x,x,x)]

      # decapsulate three modified ciphertexts (each time one random byte of
      # c is changed by adding 1..255) and checksum the resulting t
      if cof == 'crypto_kem_dec':
        X += ['\n']
        for tweaks in range(3):
          X += ['      c[myrandom() % clen] += 1 + (myrandom() % 255);\n']
          X += ['      %s(t,c,s);\n' % cof]
          X += ['      checksum(t,tlen);\n']

    X += ['    }\n']
    # extra test for wforce core primitives: for every weight and both
    # directions, set the low bit of exactly `weight` bytes of n (from the
    # start or from the end) and check canaries, inputs and determinism.
    # cof is the last test's function name, left over from the loop above.
    if cof == 'crypto_core' and p.startswith('wforce'):
      X += ['    {\n']
      X += ['      long long weight,i,direction;\n']
      X += ['      for (weight = 0;weight <= nlen;++weight) {\n']
      X += ['        for (direction = 0;direction < 2;++direction) {\n']
      X += ['          output_prepare(h2,h,hlen);\n']
      X += ['          input_prepare(n2,n,nlen);\n']
      X += ['          input_prepare(k2,k,klen);\n']
      X += ['          input_prepare(c2,c,clen);\n']
      X += ['          for (i = 0;i < nlen;++i) {\n']
      X += ['            n[i] &= ~1;\n']
      X += ['            if (direction) {\n']
      X += ['              if (nlen-1-i < weight) n[i] += 1;\n']
      X += ['            } else {\n']
      X += ['              if (i < weight) n[i] += 1;\n']
      X += ['            }\n']
      X += ['            n2[i] = n[i];\n']
      X += ['          }\n']
      X += ['          crypto_core(h,n,k,c);\n']
      X += ['          checksum(h,hlen);\n']
      X += ['          output_compare(h2,h,hlen,"crypto_core");\n']
      X += ['          input_compare(n2,n,nlen,"crypto_core");\n']
      X += ['          input_compare(k2,k,klen,"crypto_core");\n']
      X += ['          input_compare(c2,c,clen,"crypto_core");\n']
      X += ['          double_canary(h2,h,hlen);\n']
      X += ['          double_canary(n2,n,nlen);\n']
      X += ['          double_canary(k2,k,klen);\n']
      X += ['          double_canary(c2,c,clen);\n']
      X += ['          crypto_core(h2,n2,k2,c2);\n']
      # message terminated with \n like every other fail() message
      X += ['          if (memcmp(h2,h,hlen) != 0) fail("failure: crypto_core is nondeterministic\\n");\n']
      X += ['        }\n']
      X += ['      }\n']
      X += ['    }\n']

    X += ['    checksum_expected(%s_%s_checksums[checksumbig]);\n' % (o,p)]
    X += ['  }\n']

    # ----- test vectors computed by python

    for f,output,inout,input in tests:
      cof = 'crypto_'+o+f
      if (o,p) in precomputed:
        X += ['  for (long long precomp = 0;precomp < precomputed_%s_%s_NUM;++precomp) {\n' % (o,p)]
        for v,initsize,allocsize in vars:
          if v in output:
            X += ['    output_prepare(%s2,%s,%s);\n' % (v,v,allocsize)]
          if v in input+inout:
            X += ['    input_prepare(%s2,%s,%s);\n' % (v,v,allocsize)]
            X += ['    memcpy(%s,precomputed_%s_%s_%s[precomp],%s);\n' % (v,o,p,v,allocsize)]
            X += ['    memcpy(%s2,precomputed_%s_%s_%s[precomp],%s);\n' % (v,o,p,v,allocsize)]

        args = ','.join(output+inout+input)
        X += ['    %s(%s);\n' % (cof,args)]

        # on mismatch, print expected and received bytes in hex
        for v,initsize,allocsize in vars:
          if v in output+inout:
            X += ['    if (memcmp(%s,precomputed_%s_%s_%s[precomp],%s)) {\n' % (v,o,p,v,allocsize)]
            X += ['      fail("failure: %s fails precomputed test vectors\\n");\n' % cof]
            X += ['      printf("expected %s: ");\n' % v]
            X += ['      for (long long pos = 0;pos < %s;++pos) printf("%%02x",((unsigned char *) precomputed_%s_%s_%s[precomp])[pos]);\n' % (allocsize,o,p,v)]
            X += ['      printf("\\n");\n']
            X += ['      printf("received %s: ");\n' % v]
            X += ['      for (long long pos = 0;pos < %s;++pos) printf("%%02x",%s[pos]);\n' % (allocsize,v)]
            X += ['      printf("\\n");\n']
            X += ['    }\n']

        for v,initsize,allocsize in vars:
          if v in output+inout:
            X += ['    output_compare(%s2,%s,%s,"%s");\n' % (v,v,allocsize,cof)]
          if v in input:
            X += ['    input_compare(%s2,%s,%s,"%s");\n' % (v,v,allocsize,cof)]

        X += ['  }\n']

    X += ['}\n']
    X += ['\n']

    # ----- exported driver test_<o>_<p>(): allocate buffers, then run every
    # implementation (via forked) at each pointer offset
    X += ['void test_%s_%s(void)\n' % (o,p)]
    X += ['{\n']
    X += ['  long long maxalloc = 0;\n']
    X += ['  if (targeto && strcmp(targeto,"%s")) return;\n' % o]
    X += ['  if (targetp && strcmp(targetp,"%s")) return;\n' % p]

    # the *2 buffers get maxalloc bytes (the largest primary buffer),
    # presumably because the overlap test writes one buffer's output into
    # another buffer's *2 copy.  cof is the last test's name from the loop above.
    if cof == 'crypto_sort':
      for v,initsize,allocsize in vars:
        X += ['  storage_%s_%s_%s = callocplus(ntruprime_sort_%s_BYTES*%s);\n' % (o,p,v,p,allocsize)]
        X += ['  test_%s_%s_%s = aligned(storage_%s_%s_%s,ntruprime_sort_%s_BYTES*%s);\n' % (o,p,v,o,p,v,p,allocsize)]
        X += [f'  if (ntruprime_sort_{p}_BYTES*{allocsize} > maxalloc) maxalloc = ntruprime_sort_{p}_BYTES*{allocsize};\n']
      for v,initsize,allocsize in vars:
        X += ['  storage_%s_%s_%s2 = callocplus(maxalloc);\n' % (o,p,v)]
        X += ['  test_%s_%s_%s2 = aligned(storage_%s_%s_%s2,ntruprime_sort_%s_BYTES*%s);\n' % (o,p,v,o,p,v,p,allocsize)]
    else:
      for v,initsize,allocsize in vars:
        X += ['  storage_%s_%s_%s = callocplus(%s);\n' % (o,p,v,allocsize)]
        X += ['  test_%s_%s_%s = aligned(storage_%s_%s_%s,%s);\n' % (o,p,v,o,p,v,allocsize)]
        X += [f'  if ({allocsize} > maxalloc) maxalloc = {allocsize};\n']
      for v,initsize,allocsize in vars:
        X += ['  storage_%s_%s_%s2 = callocplus(maxalloc);\n' % (o,p,v)]
        X += ['  test_%s_%s_%s2 = aligned(storage_%s_%s_%s2,%s);\n' % (o,p,v,o,p,v,allocsize)]
    X += ['\n']

    if o in ('encode','decode','sort'): # requires alignment
      X += ['  for (long long offset = 0;offset < 1;++offset) {\n']
    else:
      X += ['  for (long long offset = 0;offset < 2;++offset) {\n']
    X += ['    if (targetoffset && atol(targetoffset) != offset) continue;\n']
    X += ['    if (offset && valgrind) break;\n']
    X += ['    printf("%s_%s offset %%lld\\n",offset);\n' % (o,p)]
    X += ['    for (long long impl = -1;impl < ntruprime_numimpl_%s_%s();++impl)\n' % (o,p)]
    X += ['      forked(test_%s_%s_impl,impl);\n' % (o,p)]
    for v,initsize,allocsize in vars:
      X += ['    ++test_%s_%s_%s;\n' % (o,p,v)]
    for v,initsize,allocsize in vars:
      X += ['    ++test_%s_%s_%s2;\n' % (o,p,v)]

    X += ['  }\n']

    for v,initsize,allocsize in reversed(vars):
      X += ['  free(storage_%s_%s_%s2);\n' % (o,p,v)]
    for v,initsize,allocsize in reversed(vars):
      X += ['  free(storage_%s_%s_%s);\n' % (o,p,v)]

    X += ['}\n']

    # drop the generic size aliases defined near the top of the file
    if (o,p) in sizes:
      for line in sizes[o,p].splitlines():
        psize = line.split()[1]
        size1 = psize.replace('crypto_%s_%s_'%(o,p),'crypto_%s_'%o)
        X += ['#undef %s\n' % size1]
      X += ['\n']

    save(f'command/ntruprime-test_{o}_{p}.c',''.join(X))


# Start main() in the generated test program: report version, architecture,
# CPU and valgrind status, then read up to five optional positional filters
# into targeto (operation), targetp (primitive), targeti (implementation
# name), targetn (implementation number) and targetoffset (buffer offset).
Z.append(r'''/* ----- top level */

#include "print_cpuid.inc"

int main(int argc,char **argv)
{
  valgrind_init();
  if (valgrind) limits();

  setvbuf(stdout,0,_IOLBF,0);
  printf("ntruprime version %s\n",ntruprime_version);
  printf("ntruprime arch %s\n",ntruprime_arch);
  print_cpuid();

  if (valgrind) {
    printf("valgrind %d",(int) valgrind);
    printf(" declassify %d",(int) crypto_declassify_uses_valgrind);
    if (!crypto_declassify_uses_valgrind)
      printf(" (expect false positives)");
    printf("\n");
  }

  if (*argv) ++argv;
  if (*argv) {
    targeto = *argv++;
    if (*argv) {
      targetp = *argv++;
      if (*argv) {
        targeti = *argv++;
        if (*argv) {
          targetn = *argv++;
          if (*argv) {
            targetoffset = *argv++;
          }
        }
      }
    }
  }

''')

# Emit the body of main(): one call per verify primitive, then one call per
# generated (operation, primitive) test.  Each generated test is also
# declared in the shared header, #defined onto an ntruprime_test_ symbol.
#
# Z and H are lists of text chunks, so each chunk is appended as a
# one-element list: `Z += '<text>'` would extend Z character by character.
for p in primitives['verify']:
  Z += ['  test_verify_BYTES();\n'.replace('BYTES',p)]

for t in todo:
  o,vars,howmuch,tests = t
  for p in primitives[o]:
    Z += ['  test_%s_%s();\n' % (o,p)]
    H += [f'#define test_{o}_{p} ntruprime_test_{o}{p}\n']
    H += [f'extern void test_{o}_{p}(void);\n']

# Close main(): exit status 100 if any fail() cleared ok, 0 otherwise.
Z.append(r'''
  if (!ok) {
    printf("some tests failed\n");
    return 100;
  }
  printf("all tests succeeded\n");
  return 0;
}
''')

H.append('#endif\n')

# Write the shared header first, then the main test program; save() leaves
# a file untouched when its contents are already up to date.
for outname,parts in (('command/ntruprime_test.h',H),('command/ntruprime-test.c',Z)):
  save(outname,''.join(parts))