import sys, os, socket, base64, logging, threading
import json
import settings as s

from Crypto.PublicKey import RSA, DSA
from Utils import Security, Format, Communication, Negotiation, Measurement

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend

class Client:
	"""Client side of the handshake implemented by this module.

	Outgoing messages are built with Format.prepareData*, the peer's replies are
	parsed with Format.parseData*, and key material is derived via Security.

	Handshake state (self.state): 0 initially, 1 after an INIT_R (type 0x01)
	reply, 2 after a LIST_R (type 0x03) reply. A reply arriving in the wrong
	state, or an unknown message type, terminates the process with exit status 1.

	NOTE(review): DSIG/HMAC verification failures are only logged; the handshake
	continues regardless. Confirm this is intended (e.g. for measurements only)
	before relying on these checks for authentication.

	Several protocol variants coexist (prepare_data/parse_data,
	prepare_data2/parse_data2, prepare_data3, parse_data3, parse_data4).
	"""

	def __init__(self, host, algs, key, db, proto, ipv, ifc, ec):
		"""Store configuration and generate this side's DH key pair.

		Args:
			host: peer host passed to getaddrinfo(); for proto 'eth' the
				destination address string given to s.stringToHexData().
			algs: JSON text with this client's algorithm list (self.a).
			key: private key bytes; loaded with cryptography's
				load_pem_private_key when ec == True, else with RSA.importKey.
			db: JSON text stored as self.d (not used within this class).
			proto: transport name: 'tcp', 'sctp', 'udp', 'ip' or 'eth'.
			ipv: IP version string, '4' or '6' (ignored for 'eth').
			ifc: network interface name (used for 'eth').
			ec: True selects ECDH and EC keys; any other value selects
				finite-field DH and RSA.
		"""
		Measurement.measureTime("", True)
		self.h = host
		self.port = s.PORT
		self.a = json.loads(algs)        # own algorithm list
		self.d = json.loads(db)
		self.p = proto
		self.ipv = ipv
		self.ifc = ifc
		self.ni = Security.generateNonce()  # own nonce
		self.rsn = b''
		self.nr = b''                   # peer nonce, filled from INIT_R
		self.ec = ec
		# Equality test, not truthiness: any value other than True/1 selects RSA.
		if self.ec == True:
			self.k = serialization.load_pem_private_key(key, password=None, backend=default_backend())
			self.dh_gx, self.x = Security.generateECDH()
		else:
			self.k = RSA.importKey(key)
			self.dh_gx, self.x = Security.generateStrongDH()
		self.dh_gy = b''                # peer DH public value, filled from INIT_R
		self.dh_key = b''
		self.dh_shared_secret = b''
		self.keyseed = None             # set by parse_data3 / parse_data4
		self.k2 = None                  # peer public key, set from INIT_R
		self.a2 = None                  # peer algorithm list, set from LIST_R
		self.negotiated = None          # set by negotiate2()
		self.symmetric_key = None       # set by calculateSymmetricKey()
		self.s = None                   # socket, created by resolve_setup()
		self.data = b''                 # last message sent/received
		self.err = ""
		self.state = 0
		self.msghashes = []
		Measurement.measureTime("CREATE")

	def resolve_setup(self):
		"""Create self.s for the configured transport.

		For 'eth' a raw PF_PACKET socket is bound to (self.ifc, s.ETHTPE). For
		the IP-based transports self.h is resolved with getaddrinfo(), the first
		result is used to create the socket, and self.h is replaced with that
		result's socket address.

		Raises:
			ValueError: if self.p is not a supported transport, or (for a
				non-'eth' transport) self.ipv is not '4' or '6'.
		"""
		if self.p == 'eth':
			# Layer-2 socket: no name resolution and no IP version involved.
			self.s = socket.socket(socket.PF_PACKET, socket.SOCK_RAW)
			self.s.bind((self.ifc, s.ETHTPE))
			return

		if self.ipv == '4':
			family = socket.AF_INET
		elif self.ipv == '6':
			family = socket.AF_INET6
		else:
			raise ValueError('Unsupported IP version: ' + str(self.ipv))

		# Constants are looked up only for the selected transport, so
		# platform-specific ones (e.g. IPPROTO_SCTP) are not needed otherwise.
		if self.p == 'tcp':
			stype, proto = socket.SOCK_STREAM, socket.IPPROTO_TCP
		elif self.p == 'sctp':
			stype, proto = socket.SOCK_STREAM, socket.IPPROTO_SCTP
		elif self.p == 'udp':
			stype, proto = socket.SOCK_DGRAM, socket.IPPROTO_UDP
		elif self.p == 'ip':
			stype, proto = socket.SOCK_RAW, s.IPPROTO
			self.port = ''
		else:
			raise ValueError('Unsupported protocol: ' + str(self.p))

		ai_list = socket.getaddrinfo(self.h, self.port, family, stype, proto)
		logging.debug('ai_list: %s', ai_list)
		ai_family, ai_type, ai_proto, _canonname, sockaddr = ai_list[0]
		self.s = socket.socket(ai_family, ai_type, ai_proto)
		self.h = sockaddr
		logging.debug('host: %s', self.h)
		if self.p == 'udp':
			# settimeout() alone determines the blocking mode of the socket.
			self.s.settimeout(s.TIMEOUT)
		elif self.p == 'ip':
			self.s.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 0)

	def connect(self):
		"""Connect the socket to self.h for connection-oriented transports only."""
		if self.p in ('tcp', 'sctp'):
			self.s.connect(self.h)

	def send_data(self):
		"""Send self.data to the peer via Communication.sendData.

		For 'eth', self.data is first replaced by a frame consisting of
		s.stringToHexData(self.h), the MAC bytes of self.ifc, s.ETHTPE_B and the
		original payload.
		"""
		if self.p == 'eth':
			header = s.stringToHexData(self.h) + \
				Communication.getIfcMACbytes(self.ifc) + s.ETHTPE_B
			self.data = header + self.data
		Communication.sendData(self.s, self.p, self.data, self.h)

	def recv_data(self):
		"""Receive a message into self.data, skipping repeated messages.

		Keeps receiving while Security.checkDuplicateMsg returns a truthy flag
		(presumably: message hash already in self.msghashes).
		"""
		duplicate = True
		while duplicate:
			self.data, _addr = Communication.recvData(self.s, self.p)
			duplicate, self.msghashes = Security.checkDuplicateMsg(self.msghashes, self.data)

	def disconnect(self):
		"""Close the socket created by resolve_setup(), if any."""
		if self.s is not None:
			self.s.close()

	# --- internal helpers -------------------------------------------------

	def _fail(self, *messages):
		"""Log each message as an error and terminate with exit status 1."""
		for message in messages:
			logging.error(message)
		# sys.exit, unlike the interactive exit() builtin, is always available.
		sys.exit(1)

	def _require_state(self, expected):
		"""Terminate via _fail unless the handshake is in state `expected`."""
		if self.state != expected:
			self._fail('Client wrong state ' + str(self.state))

	def _reject_message_type(self, t):
		"""Terminate via _fail for an unexpected message type `t`."""
		self._fail('Wrong message type: ' + str(t), 'State: ' + str(self.state))

	@staticmethod
	def _log_verification(ok, label, t):
		"""Log '<label> failed: type: <t>' when a verification result is falsy.

		NOTE(review): only logs; processing continues after a failed check.
		"""
		if not ok:
			logging.error(label + ' failed: type: ' + str(t))

	def _transcript_head(self):
		"""Return dh_gx + dh_gy + ni + nr, the common prefix of signed data."""
		return self.dh_gx + self.dh_gy + self.ni + self.nr

	def _verify_peer_signature(self, message, dsig):
		"""Verify `dsig` over `message` with the peer key self.k2.

		Uses Security.verifyECDSIG when self.ec == True, otherwise
		Security.verifyDSIG; both are called with hash algorithm s.halg.
		"""
		if self.ec == True:
			return Security.verifyECDSIG(message, self.k2, s.halg, dsig)
		return Security.verifyDSIG(message, self.k2, s.halg, dsig)

	# --- protocol variant 1 -----------------------------------------------

	def prepare_data(self, tpe):
		"""Build an outgoing message of type `tpe` into self.data (variant 1)."""
		Measurement.measureTime("", True)
		self.data = Format.prepareData(tpe, self.dh_gx, self.dh_gy,
			self.dh_key, self.k, self.rsn, self.a, self.err)
		Measurement.measureTime("PREP_" + str(tpe))

	def parse_data(self):
		"""Parse and process the reply in self.data (variant 1).

		Type 0x01 (state 0 -> 1): stores the peer DH value, computes self.dh_key,
		imports the peer RSA key and checks its signature and HMAC.
		Type 0x03 (state 1 -> 2): stores the RSN, checks the signature, aborts on
		a duplicate RSN, decodes the peer algorithm list and increments the RSN.
		"""
		Measurement.measureTime("", True)
		data = Format.parseData(self.data)
		d, t = data[0], data[1]
		if t == b'\x01':
			self._require_state(0)
			self.dh_gy = data[2]
			k2, dsig, hmac = data[3], data[4], data[5]
			self.dh_key = Security.calculateDH(self.x, self.dh_gy)
			self.k2 = RSA.importKey(k2)
			signed = self.dh_gx + self.dh_gy + d + t
			self._log_verification(Security.verifyDSIG(signed, self.k2, s.halg, dsig), 'DSIG', t)
			self._log_verification(Security.verifyHMAC(k2, self.dh_key, s.halg, hmac), 'HMAC', t)
			self.state = 1
		elif t == b'\x03':
			self._require_state(1)
			self.rsn = data[2]
			a2, dsig = data[3], data[4]
			# XXX: this signed-data layout is wrong; fix it.
			signed = d + t + self.rsn + a2
			self._log_verification(Security.verifyDSIG(signed, self.k2, s.halg, dsig), 'DSIG', t)
			if Security.checkDupRSN(self.rsn):
				self._fail('Client duplicate RSN detected ' + str(self.rsn))
			self.a2 = json.loads(a2.decode())
			self.rsn = Security.incrementRSN(self.rsn)
			self.state = 2
		else:
			self._reject_message_type(t)
		Measurement.measureTime("PROC_" + str(t))

	# --- protocol variant 2 -----------------------------------------------

	def prepare_data2(self, tpe):
		"""Build an outgoing message of type `tpe` into self.data (variant 2)."""
		Measurement.measureTime("", True)
		self.data = Format.prepareData2(tpe, self.dh_gx, self.dh_gy,
			self.dh_key, self.k, self.ni, self.nr, self.a, self.err)
		Measurement.measureTime("PREP_" + str(tpe))

	def parse_data2(self):
		"""Parse and process the reply in self.data (variant 2, with nonces).

		Type 0x01 / INIT_R (state 0 -> 1): stores peer DH value and nonce,
		computes self.dh_key, imports the peer RSA key, checks signature and HMAC.
		Type 0x03 / LIST_R (state 1 -> 2): checks the signature and decodes the
		peer algorithm list into self.a2.
		"""
		Measurement.measureTime("", True)
		data = Format.parseData2(self.data)
		d, t = data[0], data[1]
		if t == b'\x01':
			tpe = "INIT_R"
			self._require_state(0)
			self.dh_gy = data[2]
			self.nr = data[3]
			k2, dsig, hmac = data[4], data[5], data[6]
			self.dh_key = Security.calculateDH(self.x, self.dh_gy)
			self.k2 = RSA.importKey(k2)
			signed = self._transcript_head() + d + t
			self._log_verification(Security.verifyDSIG(signed, self.k2, s.halg, dsig), 'DSIG', t)
			self._log_verification(Security.verifyHMAC(k2, self.dh_key, s.halg, hmac), 'HMAC', t)
			self.state = 1
		elif t == b'\x03':
			tpe = "LIST_R"
			self._require_state(1)
			a2, dsig = data[2], data[3]
			signed = self._transcript_head() + a2 + d + t
			self._log_verification(Security.verifyDSIG(signed, self.k2, s.halg, dsig), 'DSIG', t)
			self.a2 = json.loads(a2.decode())
			self.state = 2
		else:
			self._reject_message_type(t)
		Measurement.measureTime("PROC_" + tpe)

	# --- protocol variants 3/4 --------------------------------------------

	def prepare_data3(self, tpe):
		"""Build an outgoing message of type `tpe` into self.data.

		NOTE(review): calls Format.prepareData4 (not prepareData3) and passes
		ec=self.ec — confirm this pairing is intended.
		"""
		Measurement.measureTime("", True)
		self.data = Format.prepareData4(tpe, self.dh_gx, self.dh_gy,
			self.dh_key, self.k, self.ni, self.nr, self.a, self.err, ec=self.ec)
		Measurement.measureTime("PREP_" + str(tpe))

	def parse_data3(self):
		"""Parse and process the reply in self.data (variant 3, RSA only).

		Type 0x01 / INIT_R (state 0 -> 1): computes the DH shared secret,
		self.keyseed = prf(ni + nr, secret) and a 32-unit self.dh_key via
		prfplus, then checks the signature over dh_gy and the HMAC.
		Type 0x03 / LIST_R (state 1 -> 2): as in parse_data2.
		"""
		Measurement.measureTime("", True)
		data = Format.parseData3(self.data)
		d, t = data[0], data[1]
		if t == b'\x01':
			tpe = "INIT_R"
			self._require_state(0)
			self.dh_gy = data[2]
			self.nr = data[3]
			k2, dsig, hmac = data[4], data[5], data[6]
			self.dh_shared_secret = Security.calculateDH(self.x, self.dh_gy)
			self.keyseed = Security.prf(self.ni + self.nr, self.dh_shared_secret)
			self.dh_key = Security.prfplus(self.keyseed, self.nr + self.ni + k2, 32)
			self.k2 = RSA.importKey(k2)
			self._log_verification(Security.verifyDSIG(self.dh_gy, self.k2, s.halg, dsig), 'DSIG', t)
			mac_input = k2 + self.ni + self.nr + d + t
			self._log_verification(Security.verifyHMAC(mac_input, self.dh_key, s.halg, hmac), 'HMAC', t)
			self.state = 1
		elif t == b'\x03':
			tpe = "LIST_R"
			self._require_state(1)
			a2, dsig = data[2], data[3]
			signed = self._transcript_head() + a2 + d + t
			self._log_verification(Security.verifyDSIG(signed, self.k2, s.halg, dsig), 'DSIG', t)
			self.a2 = json.loads(a2.decode())
			self.state = 2
		else:
			self._reject_message_type(t)
		Measurement.measureTime("PROC_" + tpe)

	def parse_data4(self):
		"""Parse and process the reply in self.data (variant 4, RSA or EC).

		Type 0x01 / INIT_R (state 0 -> 1): computes the (EC)DH shared secret,
		loads the peer key (PEM via cryptography when self.ec == True, else RSA),
		checks the signature over dh_gy, derives self.keyseed and an
		s.DHL*8-unit self.dh_key, then checks the HMAC.
		Type 0x03 / LIST_R (state 1 -> 2): checks the signature over the
		transcript head + a2 + type and decodes the peer algorithm list.
		"""
		Measurement.measureTime("", True)
		data = Format.parseData4(self.data)
		d, t = data[0], data[1]
		if t == b'\x01':
			tpe = "INIT_R"
			self._require_state(0)
			self.dh_gy = data[2]
			self.nr = data[3]
			k2, dsig, hmac = data[4], data[5], data[6]
			if self.ec == True:
				self.dh_shared_secret = Security.calculateECDH(self.x, self.dh_gy)
				self.k2 = serialization.load_pem_public_key(k2, backend=default_backend())
			else:
				self.dh_shared_secret = Security.calculateDH(self.x, self.dh_gy)
				self.k2 = RSA.importKey(k2)
			self._log_verification(self._verify_peer_signature(self.dh_gy, dsig), 'DSIG', t)
			self.keyseed = Security.prf(self.ni + self.nr, self.dh_shared_secret)
			self.dh_key = Security.prfplus(self.keyseed, self.nr + self.ni + k2, s.DHL * 8)
			mac_input = k2 + self.ni + self.nr + t
			self._log_verification(Security.verifyHMAC(mac_input, self.dh_key, s.halg, hmac), 'HMAC', t)
			self.state = 1
		elif t == b'\x03':
			tpe = "LIST_R"
			self._require_state(1)
			a2, dsig = data[2], data[3]
			signed = self._transcript_head() + a2 + t
			self._log_verification(self._verify_peer_signature(signed, dsig), 'DSIG', t)
			self.a2 = json.loads(a2.decode())
			self.state = 2
		else:
			self._reject_message_type(t)
		Measurement.measureTime("PROC_" + tpe)

	# --- negotiation and key derivation -----------------------------------

	def negotiate(self):
		"""Return Negotiation.algorithm(peer list, own list)."""
		Measurement.measureTime("", True)
		res = Negotiation.algorithm(self.a2, self.a)
		Measurement.measureTime("NEGOTIATE")
		return res

	def negotiate2(self):
		"""Return Negotiation.algorithm2(own list, peer list) and keep it in self.negotiated.

		NOTE(review): argument order is the reverse of negotiate(); confirm intended.
		"""
		Measurement.measureTime("", True)
		res = Negotiation.algorithm2(self.a, self.a2)
		self.negotiated = res
		Measurement.measureTime("NEGOTIATE")
		return res

	def calculateSymmetricKey(self):
		"""Derive the session secret key from self.keyseed.

		The key length is the integer after the first '_' in
		self.negotiated["secret_key"]["algorithm"]. The key is
		prfplus(self.keyseed, b"protocol data", keylen); it is stored in
		self.symmetric_key and returned.

		Raises:
			RuntimeError: if negotiate2() has not been run yet.
		"""
		if self.negotiated is None:
			raise RuntimeError('calculateSymmetricKey() called before negotiate2()')
		keylen = int(self.negotiated["secret_key"]["algorithm"].split("_")[1])
		key = Security.prfplus(self.keyseed, "protocol data".encode(), keylen)
		# NOTE(review): this logs secret key material; suitable for experiments only.
		logging.info("Shared secret key (" + str(keylen) + " bit): " + s.dataToHexString(key))
		self.symmetric_key = key
		return key
