import sys, os, socket, base64, logging, threading
import json
import settings as s

from Crypto.PublicKey import RSA, DSA
from Utils import Security, Format, Communication, Negotiation, Measurement

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend

class Server:
	def __init__(self, bindaddr, algs, key, db, proto, ipv, ifc, ec):
		"""Store the server configuration, load the private key and reset state.

		algs and db are JSON documents and are decoded here. key is the
		serialized private key: it is loaded with the cryptography PEM
		loader when ec is set and with RSA.importKey otherwise.

		The DH values (dh_gy, y, init_dsig) and the nonce nr are not
		created here; refreshDH() and refreshNonce() assign them.
		"""
		Measurement.measureTime("", True)

		# Transport configuration
		self.b = bindaddr
		self.port = s.PORT
		# Algorithm list and database arrive as JSON text
		self.a = json.loads(algs)
		self.d = json.loads(db)
		self.p = proto
		self.ipv = ipv
		self.ifc = ifc

		# Private key, in the representation matching the signature scheme
		self.ec = ec
		if ec == True:
			loaded_key = serialization.load_pem_private_key(
				key, password=None, backend=default_backend())
		else:
			loaded_key = RSA.importKey(key)
		self.k = loaded_key

		self.rsn = Security.generateRSN()

		# Per-handshake values start out empty
		self.ni = self.dh_gx = self.dh_key = b''
		self.err = ""
		self.state = 0
		self.msghashes = []

		Measurement.measureTime("CREATE")

	def refreshDH(self):
		"""Generate new DH values and sign dh_gy with the server key.

		The generator result is stored in dh_gy and y, and the signature
		over dh_gy (made with self.k and s.halg) in init_dsig. The EC
		helpers are used when self.ec is set, the classic ones otherwise.
		"""
		if not (self.ec == True):
			gy, y = Security.generateStrongDH()
			self.dh_gy = gy
			self.y = y
			self.init_dsig = Security.DSIG(self.dh_gy, self.k, s.halg)
			return

		gy, y = Security.generateECDH()
		self.dh_gy = gy
		self.y = y
		self.init_dsig = Security.ECDSIG(self.dh_gy, self.k, s.halg)

	def refreshNonce(self):
		"""Replace self.nr with a new value from Security.generateNonce()."""
		fresh_nonce = Security.generateNonce()
		self.nr = fresh_nonce

	def refreshRSN(self):
		"""Replace self.rsn with a new value from Security.generateRSN()."""
		fresh_rsn = Security.generateRSN()
		self.rsn = fresh_rsn

	def bind(self):
		"""Create the server socket for the configured transport and bind it.

		self.p selects the transport:
			'tcp', 'sctp': stream socket
			'udp': datagram socket
			'ip': raw socket using protocol number s.IPPROTO
			'eth': PF_PACKET raw socket bound to (self.ifc, s.ETHTPE)
		For every transport except 'eth', self.ipv ('4' or '6') selects the
		address family and the socket is bound to (self.b, self.port).

		The resulting socket is stored in self.s.

		Raises:
			ValueError: self.ipv or self.p holds an unsupported value.
		"""
		if self.p == 'eth':
			# Link-layer socket: bound to an interface, so no address
			# family or port is involved
			self.s = socket.socket(socket.PF_PACKET, socket.SOCK_RAW)
			self.s.bind((self.ifc, s.ETHTPE))
			return

		# Reject unknown values up front instead of failing later on an
		# unbound local variable
		if self.ipv == '4':
			family = socket.AF_INET
		elif self.ipv == '6':
			family = socket.AF_INET6
		else:
			raise ValueError('Unsupported IP version: ' + repr(self.ipv))

		if self.p == 'tcp':
			stype = socket.SOCK_STREAM
			proto = socket.IPPROTO_TCP
		elif self.p == 'sctp':
			stype = socket.SOCK_STREAM
			proto = socket.IPPROTO_SCTP
		elif self.p == 'udp':
			stype = socket.SOCK_DGRAM
			proto = socket.IPPROTO_UDP
		elif self.p == 'ip':
			stype = socket.SOCK_RAW
			proto = s.IPPROTO
		else:
			raise ValueError('Unsupported protocol: ' + repr(self.p))

		self.s = socket.socket(family, stype, proto)
		if self.p == 'ip':
			# IP_HDRINCL is set to 0 for the raw IP socket
			self.s.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 0)
		else:
			self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
		self.s.bind((self.b, self.port))

	def listen_accept(self):
		"""Reset the handshake state and accept one client on stream transports.

		self.state is always set back to 0. For 'tcp' and 'sctp' the server
		socket listens with a backlog of 1 and the accepted connection is
		stored in self.cs; for every other transport nothing else happens.
		"""
		self.state = 0
		if self.p not in ('tcp', 'sctp'):
			return

		Measurement.measureTime("", True)
		self.s.listen(1)
		connection, _peer = self.s.accept()
		self.cs = connection
		Measurement.measureTime("LISTEN_ACCEPT")

	def send_data(self):
		"""Send self.data (with self.h) over the configured transport.

		Stream transports send on the accepted client socket self.cs, the
		others on self.s. For 'eth' the payload is first prefixed with
		s.stringToHexData(self.h), the MAC bytes of self.ifc and
		s.ETHTPE_B (presumably the Ethernet header), and self.data is
		overwritten with the framed bytes.
		"""
		Measurement.measureTime("", True)
		transport = self.p
		if transport in ('tcp', 'sctp'):
			Communication.sendData(self.cs, transport, self.data, self.h)
		elif transport in ('udp', 'ip'):
			Communication.sendData(self.s, transport, self.data, self.h)
		elif transport == 'eth':
			frame = s.stringToHexData(self.h) + \
				Communication.getIfcMACbytes(self.ifc) + \
				s.ETHTPE_B + self.data
			self.data = frame
			Communication.sendData(self.s, transport, frame, self.h)
		Measurement.measureTime("SEND")

	def recv_data(self):
		"""Receive into self.data / self.h until a message is accepted.

		Each received message is passed to Security.checkDuplicateMsg()
		together with self.msghashes; the call returns a flag and the
		updated hash list. Receiving repeats while the flag is true
		(presumably meaning the message was already seen).
		"""
		Measurement.measureTime("", True)
		while True:
			if self.p in ('tcp', 'sctp'):
				# Stream transports read from the accepted client socket
				self.data, self.h = Communication.recvData(self.cs, self.p)
			elif self.p in ('udp', 'ip', 'eth'):
				self.data, self.h = Communication.recvData(self.s, self.p)
			again, self.msghashes = Security.checkDuplicateMsg(self.msghashes, self.data)
			if not again:
				break
		Measurement.measureTime("RECV")

	def client_disconnect(self):
		"""Close the accepted client socket when a stream transport is in use."""
		Measurement.measureTime("", True)
		uses_stream = self.p in ('tcp', 'sctp')
		if not uses_stream:
			return
		self.cs.close()

	def disconnect(self):
		"""Close the server socket held in self.s."""
		server_socket = self.s
		server_socket.close()

	def prepare_data(self,tpe):
		"""Build the outgoing message of type tpe and store it in self.data.

		The message is produced by Format.prepareData() from the DH values,
		the derived key, the server key, the RSN, the algorithm list and
		the error string. tpe is also used in the measurement label.
		"""
		Measurement.measureTime("", True)
		message = Format.prepareData(
			tpe, self.dh_gx, self.dh_gy, self.dh_key,
			self.k, self.rsn, self.a, self.err)
		self.data = message
		Measurement.measureTime("PREP_" + tpe)

	def parse_data(self):
		"""Parse the message in self.data and advance the handshake state.

		Format.parseData() returns a sequence; item 0 (d) and item 1 (the
		type byte t) are always read, the remaining items depend on t:

			type 0x00, state 0 -> 1: item 2 is dh_gx; dh_key is computed
				from self.y and dh_gx.
			type 0x02, state 1 -> 2: items 2-4 are the peer key, a signature
				over dh_gy + dh_gx + d + t and an HMAC over the peer key.
			type 0x03, state 2 -> 3: items 2-4 are the RSN, the peer's
				algorithm list (JSON) and a signature over d + t + rsn + a2.

		self.y and self.dh_gy must have been set (see refreshDH()).

		The process is terminated through sys.exit() when a message
		arrives in the wrong state, when its type is unknown or when
		Security.checkDupRSN() reports a duplicate RSN.
		"""
		Measurement.measureTime("", True)
		data = Format.parseData(self.data)
		d = data[0]
		t = data[1]
		if t == b'\x00':
			if self.state != 0:
				logging.error('Server wrong state ' + str(self.state))
				# sys.exit() instead of the site-provided exit(), which is
				# not guaranteed to exist outside interactive sessions
				sys.exit()
			self.dh_gx = data[2]
			self.dh_key = Security.calculateDH(self.y, self.dh_gx)
			self.state = 1
		elif t == b'\x02':
			if self.state != 1:
				logging.error('Server wrong state ' + str(self.state))
				sys.exit()
			k2 = data[2]
			dsig = data[3]
			mac = data[4]
			self.k2 = RSA.importKey(k2)
			# NOTE(review): failed signature / HMAC checks are only logged;
			# the state still advances. Confirm this is intended.
			r = Security.verifyDSIG(self.dh_gy + self.dh_gx + d + t, self.k2, s.halg, dsig)
			if not r:
				logging.error('DSIG failed: type: ' + str(t))
			r = Security.verifyHMAC(k2, self.dh_key, s.halg, mac)
			if not r:
				logging.error('HMAC failed: type: ' + str(t))
			self.state = 2
		elif t == b'\x03':
			if self.state != 2:
				logging.error('Server wrong state ' + str(self.state))
				sys.exit()
			rsn = data[2]
			a2 = data[3]
			dsig = data[4]
			# NOTE(review): a failed signature or an unexpected RSN value is
			# only logged; only a duplicate RSN aborts. Confirm intent.
			r = Security.verifyDSIG(d + t + rsn + a2, self.k2, s.halg, dsig)
			if not r:
				logging.error('DSIG failed: type: ' + str(t))
			if Security.checkDupRSN(rsn):
				logging.error('Client duplicate RSN detected' + str(rsn))
				sys.exit()
			if rsn != Security.incrementRSN(self.rsn):
				logging.error('RSN wrong: type: ' + str(t))
			self.a2 = json.loads(a2.decode())
			self.state = 3
		else:
			logging.error('Wrong message type: ' + str(t))
			logging.error('State: ' + str(self.state))
			sys.exit()
		Measurement.measureTime("PROC_"+str(t))

	def prepare_data2(self,tpe):
		"""Build the outgoing message of type tpe and store it in self.data.

		Same as prepare_data() but uses Format.prepareData2(), which takes
		the two nonces (ni, nr) instead of the RSN. self.nr must have been
		set (see refreshNonce()).
		"""
		Measurement.measureTime("", True)
		message = Format.prepareData2(
			tpe, self.dh_gx, self.dh_gy, self.dh_key,
			self.k, self.ni, self.nr, self.a, self.err)
		self.data = message
		Measurement.measureTime("PREP_" + tpe)

	def parse_data2(self):
		"""Parse the message in self.data (nonce-based format) and advance state.

		Format.parseData2() returns a sequence; item 0 (d) and item 1 (the
		type byte t) are always read, the remaining items depend on t:

			type 0x00 (INIT_I), state 0 -> 1: items 2-3 are dh_gx and the
				nonce ni; dh_key is computed from self.y and dh_gx.
			type 0x02 (LIST_I), state 1 -> 2: items 2-5 are the peer key,
				the peer's algorithm list (JSON), a signature over
				dh_gx + dh_gy + ni + nr + a2 + d + t and an HMAC over the
				peer key.

		self.y, self.dh_gy and self.nr must have been set (see refreshDH()
		and refreshNonce()).

		The process is terminated through sys.exit() when a message
		arrives in the wrong state or when its type is unknown.
		"""
		Measurement.measureTime("", True)
		data = Format.parseData2(self.data)
		d = data[0]
		t = data[1]
		if t == b'\x00':
			tpe = "INIT_I"
			if self.state != 0:
				logging.error('Server wrong state ' + str(self.state))
				# sys.exit() instead of the site-provided exit(), which is
				# not guaranteed to exist outside interactive sessions
				sys.exit()
			self.dh_gx = data[2]
			self.ni = data[3]
			self.dh_key = Security.calculateDH(self.y, self.dh_gx)
			self.state = 1
		elif t == b'\x02':
			tpe = "LIST_I"
			if self.state != 1:
				logging.error('Server wrong state ' + str(self.state))
				sys.exit()
			k2 = data[2]
			a2 = data[3]
			dsig = data[4]
			mac = data[5]
			self.k2 = RSA.importKey(k2)
			headOfData = self.dh_gx + self.dh_gy + self.ni + self.nr
			# NOTE(review): failed signature / HMAC checks are only logged;
			# the state still advances. Confirm this is intended.
			r = Security.verifyDSIG(headOfData + a2 + d + t, self.k2, s.halg, dsig)
			if not r:
				logging.error('DSIG failed: type: ' + str(t))
			r = Security.verifyHMAC(k2, self.dh_key, s.halg, mac)
			if not r:
				logging.error('HMAC failed: type: ' + str(t))
			self.a2 = json.loads(a2.decode())
			self.state = 2
		else:
			tpe = 'ERROR'
			logging.error('Wrong message type: ' + str(t))
			logging.error('State: ' + str(self.state))
			sys.exit()
		Measurement.measureTime("PROC_"+tpe)

	def prepare_data3(self,tpe):
		"""Build the outgoing message of type tpe and store it in self.data.

		Uses Format.prepareData4(), which in addition to the arguments of
		prepare_data2() takes the signature made in refreshDH()
		(init_dsig) and the ec flag.
		"""
		Measurement.measureTime("", True)
		message = Format.prepareData4(
			tpe, self.dh_gx, self.dh_gy, self.dh_key,
			self.k, self.ni, self.nr, self.a, self.err,
			self.init_dsig, self.ec)
		self.data = message
		Measurement.measureTime("PREP_" + tpe)

	def parse_data3(self):
		"""Parse the message in self.data (key-derivation format) and advance state.

		Format.parseData4() returns a sequence; item 1 is the type byte t,
		the remaining items depend on t:

			type 0x00 (INIT_I), state 0 -> 1: items 2-3 are dh_gx and the
				nonce ni. The shared secret is computed from self.y and
				dh_gx (EC or classic DH depending on self.ec), then
				keyseed = prf(ni + nr, shared secret) and
				dh_key = prfplus(keyseed, nr + ni + public key, s.DHL*8).
			type 0x02 (LIST_I), state 1 -> 2: items 2-5 are the peer key,
				the peer's algorithm list (JSON), a signature over
				a2 + dh_gx + dh_gy and an HMAC over
				peer key + nr + ni + t keyed with dh_key.

		self.y, self.dh_gy and self.nr must have been set (see refreshDH()
		and refreshNonce()).

		The process is terminated through sys.exit() when a message
		arrives in the wrong state or when its type is unknown.
		"""
		Measurement.measureTime("", True)
		data = Format.parseData4(self.data)
		t = data[1]
		if t == b'\x00':
			tpe = "INIT_I"
			if self.state != 0:
				logging.error('Server wrong state ' + str(self.state))
				# sys.exit() instead of the site-provided exit(), which is
				# not guaranteed to exist outside interactive sessions
				sys.exit()
			self.dh_gx = data[2]
			self.ni = data[3]

			if self.ec == True:
				self.dh_shared_secret = Security.calculateECDH(self.y, self.dh_gx)
			else:
				self.dh_shared_secret = Security.calculateDH(self.y, self.dh_gx)
			self.keyseed = Security.prf(self.ni+self.nr,self.dh_shared_secret)
			self.dh_key = Security.prfplus(self.keyseed, (self.nr + self.ni + s.getPublicKey(self.k,self.ec)), s.DHL*8)
			self.state = 1
		elif t == b'\x02':
			tpe = "LIST_I"
			if self.state != 1:
				logging.error('Server wrong state ' + str(self.state))
				sys.exit()
			k2 = data[2]
			a2 = data[3]
			dsig = data[4]
			mac = data[5]
			# NOTE(review): failed signature / HMAC checks are only logged;
			# the state still advances. Confirm this is intended.
			if self.ec == True:
				self.k2 = serialization.load_pem_public_key(k2,backend=default_backend())
				r = Security.verifyECDSIG(a2 + self.dh_gx + self.dh_gy, self.k2, s.halg, dsig)
			else:
				self.k2 = RSA.importKey(k2)
				r = Security.verifyDSIG(a2 + self.dh_gx + self.dh_gy, self.k2, s.halg, dsig)
			if not r:
				logging.error('DSIG failed: type: ' + str(t))
			r = Security.verifyHMAC(k2 + self.nr + self.ni + t, self.dh_key, s.halg, mac)
			if not r:
				logging.error('HMAC failed: type: ' + str(t))
			self.a2 = json.loads(a2.decode())
			self.state = 2
		else:
			tpe = 'ERROR'
			logging.error('Wrong message type: ' + str(t))
			logging.error('State: ' + str(self.state))
			sys.exit()
		Measurement.measureTime("PROC_"+tpe)

	def negotiate(self):
		"""Run Negotiation.algorithm() on the local and peer algorithm lists.

		Returns whatever Negotiation.algorithm(self.a, self.a2) returns;
		self.a2 is set by the parse_* methods.
		"""
		Measurement.measureTime("", True)
		outcome = Negotiation.algorithm(self.a, self.a2)
		Measurement.measureTime("NEGOTIATE")
		return outcome

	def negotiate2(self):
		"""Run Negotiation.algorithm2() and remember the result.

		The peer list is passed first: Negotiation.algorithm2(self.a2,
		self.a). The result is stored in self.negotiated and returned.
		"""
		Measurement.measureTime("", True)
		outcome = Negotiation.algorithm2(self.a2, self.a)
		self.negotiated = outcome
		Measurement.measureTime("NEGOTIATE")
		return outcome

	def calculateSymmetricKey(self):
		"""Derive the symmetric key for the negotiated secret-key algorithm.

		The key length is the integer in the second "_"-separated field of
		self.negotiated["secret_key"]["algorithm"]; the log message below
		labels it as bits. The key is
		Security.prfplus(self.keyseed, b"protocol data", keylen).

		Requires self.negotiated (set by negotiate2()) and self.keyseed
		(set by parse_data3()).

		Returns:
			The derived key, so that callers can actually use it; it was
			previously only logged and then discarded.
		"""
		algorithm = self.negotiated["secret_key"]["algorithm"]
		keylen = int(algorithm.split("_")[1])
		key = Security.prfplus(self.keyseed, "protocol data".encode(), keylen)
		# NOTE(review): this writes secret key material to the log at INFO
		# level; acceptable for a demo only. Confirm before production use.
		logging.info("Shared secret key (" + str(keylen) + " bit): " + s.dataToHexString(key))
		return key
