#!/usr/bin/env python3
import matplotlib.pyplot as plt
from sys import argv
import h5py
import numpy as np
from AES_Sbox import AES_Sbox

#-----------------------------------------------------------
import ascad_best_mlp
import ascad_best_cnn
import other_mlp_01
import simple_model
# Models to compare. Each entry must provide predict(traces); its __name__ is
# used for progress messages and as the curve label in the plot.
models = [
	ascad_best_mlp,
	ascad_best_cnn,
	other_mlp_01,
	simple_model,
]
#-----------------------------------------------------------

# Load the attack set into memory. np.array copies the data out of the HDF5
# datasets, so the file can be closed as soon as the block ends.
with h5py.File("ASCAD_data/ASCAD_databases/ASCAD.h5", "r") as hf:
	U = np.array(hf['Attack_traces/traces'])
	# Plaintext byte 2 of each attack trace; mk_rankmat feeds it to the S-box.
	P = np.array(hf['Attack_traces/metadata']['plaintext'][:, 2])

def plot_meanrank(rankmat, maxtrc, label):
	"""Draw one mean-rank curve on the current pyplot axes.

	rankmat -- array of key ranks, one row per attack run and one column per
	           number of accumulated traces; it is averaged column-wise
	maxtrc  -- number of columns, plotted at x = 1..maxtrc
	label   -- legend label for the curve
	"""
	x_traces = 1 + np.arange(maxtrc)
	y_rank = np.mean(rankmat, axis=0)
	plt.xlabel('number of traces')
	plt.ylabel('mean rank')
	plt.plot(x_traces, y_rank, label=label)

def mk_rankmat(model, nruns, maxtrc, batches, realkey=0xe0,
		traces=None, plaintexts=None, sbox=None):
	"""Compute the rank of the correct key byte after each attack trace.

	For each run ``krun``, the traces indexed by ``batches[krun, :maxtrc]``
	are classified by ``model.predict``. The per-key log-likelihoods are then
	accumulated trace by trace, and the rank of ``realkey`` is recorded after
	every addition.

	Parameters:
		model      -- object with ``predict(traces)`` returning one row of 256
		              class probabilities per trace, and a ``__name__`` used
		              in the progress message
		nruns      -- number of attack runs (rows of the result)
		maxtrc     -- number of traces accumulated per run (columns)
		batches    -- 2-D integer array of trace indices, one row per run
		realkey    -- the correct key byte (default 0xe0, previously hard-coded)
		traces     -- traces to attack; defaults to the module-level ``U``
		plaintexts -- plaintext byte per trace; defaults to the module-level ``P``
		sbox       -- 256-entry S-box lookup table; defaults to ``AES_Sbox``

	Returns:
		int array of shape (nruns, maxtrc). Entry [r, i] is the number of key
		guesses whose summed log-likelihood after i+1 traces is >= that of
		``realkey``. A value of 1 means the correct key ranks first; ties count
		against it.
	"""
	traces = U if traces is None else traces
	plaintexts = np.asarray(P if plaintexts is None else plaintexts)
	sbox = np.asarray(AES_Sbox if sbox is None else sbox)
	# A class predicted with probability exactly 0 would give log = -inf. That
	# pins the key guess at -inf for the rest of the run and ties it with every
	# other -inf guess. Flooring at the smallest normal float64 keeps it finite.
	floor = np.finfo(np.float64).tiny
	rows = np.arange(maxtrc)[:, None]
	candidates = np.arange(0x100)
	rankmat = np.zeros((nruns, maxtrc), dtype=int)
	for krun in range(nruns):
		print('%s  run %d of %d' % (model.__name__, krun+1, nruns))
		samp = batches[krun, :maxtrc]
		probs = np.asarray(model.predict(traces[samp]), dtype=np.float64)
		lps = np.log(np.maximum(probs, floor))
		# S[i, k]: predicted class index of trace i under key guess k.
		S = sbox[plaintexts[samp][:, None] ^ candidates]
		# Row i holds the per-key log-likelihood summed over traces 0..i
		# (the same sequential accumulation as adding row by row).
		lpsums = np.cumsum(lps[rows, S], axis=0)
		rankmat[krun, :] = np.count_nonzero(lpsums >= lpsums[:, [realkey]], axis=1)
	return rankmat

def _int_arg(index, default):
	"""Return ``argv[index]`` as an int, or ``default`` if it is missing or not an integer."""
	try:
		return int(argv[index])
	except (IndexError, ValueError):
		# Narrowed from a bare ``except:`` so that e.g. Ctrl-C is not swallowed.
		return default


# Command line: [nruns [maxtrc]] -- number of random attack runs per model and
# number of traces accumulated in each run.
nruns = _int_arg(1, 100)
maxtrc = _int_arg(2, 70)
if nruns < 1 or maxtrc < 1:
	# Zero runs or traces would only produce an empty / NaN mean-rank plot.
	raise SystemExit('nruns and maxtrc must be positive integers')

# One random batch of distinct attack-trace indices per run. Every model is
# evaluated on these same batches, so their curves are directly comparable.
batches = np.zeros((nruns, maxtrc), 'int')
for i in range(nruns):
	batches[i, :] = np.random.choice(U.shape[0], maxtrc, replace=False)

plt.rc('axes', prop_cycle=(plt.cycler('color', ['r', 'y', 'g', 'b'])))
plt.grid()
for model in models:
	rankmat = mk_rankmat(model, nruns, maxtrc, batches)
	plot_meanrank(rankmat, maxtrc, model.__name__)
plt.legend()
plt.title('Model comparison using %d test runs' % nruns)
plt.tight_layout()
# Show without blocking, then wait for the user before the script exits.
plt.show(block=False)
input('hit <RETURN> to exit...')
