#!/usr/bin/env python3
import h5py
import numpy as np
import pandas as pd

# Number of projection directions kept for the mask R (nredR) and for the
# masked value X (nredX); the final model has nredR + nredX features.
nredR = 5
nredX = 5

# Load the profiling set.  Every array is copied into memory with np.array
# inside the `with` block, so the HDF5 file can be closed as soon as it exits.
with h5py.File("ASCAD_data/ASCAD_databases/ASCAD.h5", "r") as hf:
    # One trace per row, converted to float64 for the statistics below.
    U = np.array(hf['Profiling_traces/traces'], dtype='float64')
    # Plaintext byte 2 (not used further in this script).
    P = np.array(hf['Profiling_traces/metadata']['plaintext'][:, 2])
    # Per-trace label byte (presumably the unmasked target byte -- confirm
    # against the dataset documentation).
    S = np.array(hf['Profiling_traces/labels'], dtype='uint8')
    # First column of the per-trace masks.
    R = np.array(hf['Profiling_traces/metadata']['masks'][:, 0])

# Label masked with R; the model below treats R and X as the two leaking values.
X = S ^ R

# make linear fit -----------------------------------------------
# Model every trace as  U[i] = mR[R[i]] + mX[X[i]] + noise  and recover the two
# mean tables from the class-conditional means uR[r] = E[U | R=r] and
# uX[x] = E[U | X=x].  With the empirical conditionals P(x | r), P(r | x):
#   uR[r] = mR[r] + sum_x P(x | r) * mX[x]
#   uX[x] = sum_r P(r | x) * mR[r] + mX[x]
# i.e. A @ [mR; mX] = [uR; uX] with A = [[I, nRN], [nXN, I]].  A constant can
# be shifted freely between mR and mX, so A is singular; the system is pinned by
# fixing mX[0] = 0 (drop the equation for uX[0] and the unknown mX[0]).
nvals = 256              # R and X are bytes
nsamples = U.shape[1]    # samples per trace

# N[r, x] = number of traces with R == r and X == x.  Unlike pd.crosstab,
# bincount always yields the full 256x256 table even if a value never occurs;
# the int64 cast keeps r * 256 from overflowing a uint8 array.
N = np.bincount(R.astype(np.int64) * nvals + X,
                minlength=nvals * nvals).reshape(nvals, nvals)
cntR = N.sum(axis=1)     # traces per mask value r
cntX = N.sum(axis=0)     # traces per masked value x
if not (cntR.all() and cntX.all()):
    # A missing value would make its conditional mean (and its row of A) undefined.
    raise ValueError(
        "every mask value and every masked value must occur at least once "
        "in the profiling traces for the linear fit to be defined")

nRN = N / cntR[:, None]        # nRN[r, x] = P(X=x | R=r)
nXN = (N / cntX[None, :]).T    # nXN[x, r] = P(R=r | X=x)
A = np.block([[np.eye(nvals), nRN], [nXN, np.eye(nvals)]])
A = np.delete(A, nvals, 0)     # drop the equation for uX[0]
A = np.delete(A, nvals, 1)     # drop the unknown mX[0] (fixed to 0)

# Class-conditional mean traces.
uR = np.zeros([nvals, nsamples])
for r in range(nvals):
    uR[r, :] = U[R == r, :].mean(axis=0)

uX = np.zeros([nvals, nsamples])
for x in range(nvals):
    uX[x, :] = U[X == x, :].mean(axis=0)

uRX = np.vstack([uR, uX[1:, :]])

# Solve directly rather than forming inv(A): same solution, better accuracy.
mRX = np.linalg.solve(A, uRX)

mR = mRX[:nvals, :]
mX = np.vstack([np.zeros(nsamples), mRX[nvals:, :]])   # re-insert mX[0] = 0
# linear fit done -----------------------------------------------

# Noise whitening: V is the sample covariance of the residual left after
# removing the fitted R and X contributions (samples are the variables).
# With V = L L^T (Cholesky), W = L^{-1} satisfies W V W^T = I.
V = np.cov((U - mR[R, :] - mX[X, :]).T)
W = np.linalg.inv(np.linalg.cholesky(V))

# Eigen-decompose the covariance of the whitened mean tables; eigh returns
# ascending eigenvalues and orthonormal eigenvectors (as columns).
egR = np.linalg.eigh(np.cov(np.dot(W, mR.T)))
egX = np.linalg.eigh(np.cov(np.dot(W, mX.T)))

# Indices of the nredR / nredX largest eigenvalues, largest first.
iR = np.argsort(-egR[0])[:nredR]
iX = np.argsort(-egX[0])[:nredX]

# Selected directions as rows.  Each group is orthonormal on its own but the
# two groups need not be mutually orthogonal, so Wred (inverse Cholesky factor
# of their Gram matrix) makes the rows of Wred @ red orthonormal.
red = np.concatenate((egR[1][:, iR], egX[1][:, iX]), axis=1).T
Vred = np.dot(red, red.T)
Wred = np.linalg.inv(np.linalg.cholesky(Vred))

# Projection from a raw trace to nredR + nredX features (one column per
# feature), and both mean tables expressed in that projected space.
Proj = np.dot(Wred, np.dot(red, W)).T
pmR = np.dot(mR, Proj)
pmX = np.dot(mX, Proj)


# Write the model as raw float32 data with no header: Proj, then pmR, then pmX,
# back to back, each in C (row-major) order (ndarray.tofile always writes C
# order).  A reader must know the shapes: Proj is (samples per trace,
# nredR + nredX) and pmR / pmX are (256, nredR + nredX).
outname = 'simple_model.bin'
with open(outname, 'wb') as f:
    Proj.astype(np.float32).tofile(f)
    pmR.astype(np.float32).tofile(f)
    pmX.astype(np.float32).tofile(f)
print('model saved in %s' % outname)
