import math as m
import socket
import settings as s
import hashlib, json, logging, netifaces, re, logging
from struct import pack
from datetime import datetime

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes

from cryptography.hazmat.primitives.asymmetric import dh

from Crypto.Random import random
from Crypto.Util import number

class Measurement:
	"""Wall-clock timing helpers that log elapsed microseconds.

	Each method keeps its reference point in a module-level global
	(``current`` for measureTime, ``currentTotal`` for measureTotalTime) and
	resets it to "now" on every call.  The globals are not created here: the
	first call of each method must pass ``quiet=True`` (or the caller must
	assign the global first), otherwise a NameError is raised.
	"""

	@staticmethod
	def _elapsedMicroseconds(diff):
		"""Return the timedelta *diff* as an exact whole number of microseconds.

		``diff.days`` is included, so durations of a day or more and negative
		durations (e.g. after the wall clock is stepped back) are reported
		correctly instead of wrapping around.
		"""
		return (diff.days * 86400 + diff.seconds) * 1000000 + diff.microseconds

	@staticmethod
	def measureTime(msg, quiet = False):
		"""Log at DEBUG the microseconds since the previous measureTime call.

		msg - label prepended to the log line
		quiet - if true, only reset the reference point without logging
		"""
		global current
		now = datetime.now()
		if not quiet:
			elapsed = Measurement._elapsedMicroseconds(now - current)
			logging.debug(msg + " duration: " + str(elapsed))
		current = now

	@staticmethod
	def measureTotalTime(msg, quiet = False):
		"""Log at INFO the microseconds since the previous measureTotalTime call.

		msg - label prepended to the log line
		quiet - if true, only reset the reference point without logging
		"""
		global currentTotal
		now = datetime.now()
		if not quiet:
			elapsed = Measurement._elapsedMicroseconds(now - currentTotal)
			logging.info(msg + " duration: " + str(elapsed))
		currentTotal = now

class Security:
	"""Cryptographic helpers: nonces/RSNs, Diffie-Hellman, HMAC, signatures,
	replay detection and key expansion.

	All methods are static.  Parameters come from the project settings module
	(imported as ``s``): ``s.RSNL`` and ``s.DHL`` are byte lengths, ``s.g`` and
	``s.p`` the DH generator and modulus, ``s.RSNDB`` the path of the RSN
	store, ``s.WINDOW_SIZE`` the replay-window length and ``s.halg`` the
	default hash algorithm name.  ``random`` here is ``Crypto.Random.random``,
	not the standard-library module.
	"""

	@staticmethod
	def generateNonce():
		"""Return a random nonce of s.RSNL bytes."""
		bits = random.getrandbits(s.RSNL * 8)
		return bits.to_bytes(s.RSNL, 'big')

	@staticmethod
	def generateRSN():
		"""Return a random RSN of s.RSNL bytes."""
		bits = random.getrandbits(s.RSNL * 8)
		return bits.to_bytes(s.RSNL, 'big')

	@staticmethod
	def incrementRSN(rsn):
		"""Return *rsn* + 1 as s.RSNL big-endian bytes.

		Raises OverflowError when the incremented value no longer fits in
		s.RSNL bytes.
		"""
		return (int.from_bytes(rsn, 'big') + 1).to_bytes(s.RSNL, 'big')

	@staticmethod
	def addRSNtoDB(rsn):
		"""Append *rsn*, written as a decimal integer, as one line of s.RSNDB."""
		with open(s.RSNDB, 'a+') as f:
			f.write(str(int.from_bytes(rsn, 'big')) + '\n')

	@staticmethod
	def checkDupRSN(rsn):
		"""Return True if *rsn* is already in s.RSNDB; otherwise record it and return False.

		Lines are compared whole, so e.g. RSN 12 is not mistaken for a
		duplicate of a stored 123.  If the store cannot be opened (e.g. it
		does not exist yet) the RSN is recorded and treated as new.
		"""
		wanted = str(int.from_bytes(rsn, 'big'))
		try:
			f = open(s.RSNDB)
		except OSError:
			Security.addRSNtoDB(rsn)
			return False
		with f:
			seen = {line.strip() for line in f}
		if wanted in seen:
			return True
		Security.addRSNtoDB(rsn)
		return False

	@staticmethod
	def generateDH():
		"""Create a finite-field DH key pair.

		Returns ``(g**x mod p as s.DHL big-endian bytes, x)`` where the
		private exponent x is s.DHL*8 random bits.
		"""
		priv = random.getrandbits(s.DHL * 8)
		return pow(s.g, priv, s.p).to_bytes(s.DHL, 'big'), priv

	@staticmethod
	def generateStrongDH():
		"""Like generateDH, but the private exponent is
		``Crypto.Util.number.getStrongPrime(s.DHL*8)``."""
		priv = number.getStrongPrime(s.DHL * 8)
		return pow(s.g, priv, s.p).to_bytes(s.DHL, 'big'), priv

	@staticmethod
	def generateDHnew():
		"""Kept for compatibility; identical to generateStrongDH.

		It no longer calls ``dh.DHParameters.generate_private_key()`` on the
		class itself.  That call had no parameters instance, so it always
		raised TypeError, and its result was never used.
		"""
		return Security.generateStrongDH()

	@staticmethod
	def calculateDH(y, gx):
		"""Return the DH shared secret ``gx**y mod p`` as s.DHL big-endian bytes.

		y - our private exponent
		gx - the peer's public value as big-endian bytes

		Raises ValueError unless 1 < gx < p-1: the values 0, 1 and p-1 (and
		anything that reduces to them mod p) force a trivially guessable
		shared secret.
		"""
		peer = int.from_bytes(gx, 'big')
		if not 1 < peer < s.p - 1:
			raise ValueError("invalid DH public value")
		return pow(peer, y, s.p).to_bytes(s.DHL, 'big')

	@staticmethod
	def generateECDH():
		"""Create a SECP256K1 key pair.

		Returns ``(public key as DER SubjectPublicKeyInfo bytes, private key object)``.
		"""
		priv = ec.generate_private_key(ec.SECP256K1(), default_backend())
		publ = priv.public_key().public_bytes(serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo)
		return publ, priv

	@staticmethod
	def calculateECDH(priv, pub_der):
		"""Return the ECDH shared secret bytes between *priv* and the peer's
		DER-encoded public key *pub_der*."""
		pub = serialization.load_der_public_key(pub_der, default_backend())
		return priv.exchange(ec.ECDH(), pub)

	@staticmethod
	def HMAC(data, key, alg):
		"""Return HMAC(key, data) per RFC 2104.

		data - binary data
		key - binary key
		alg - hash algorithm name accepted by hashlib.new (e.g. 'sha256')

		NOTE: keys longer than the hash block size are hashed first, as the
		RFC requires.  An earlier hand-rolled version XOR-ed such keys
		directly, so for those keys its output differs from this one.
		"""
		import hmac as hmac_lib
		return hmac_lib.new(key, data, alg).digest()

	@staticmethod
	def verifyHMAC(data, key, alg, hmac):
		"""Return True if *hmac* equals HMAC(key, data) under *alg*.

		hmac - tag value to verify

		The comparison is constant-time, so response timing does not reveal
		how many leading bytes of a forged tag were correct.
		"""
		import hmac as hmac_lib
		expected = Security.HMAC(data, key, alg)
		try:
			return hmac_lib.compare_digest(hmac, expected)
		except TypeError:
			# e.g. a str tag: it can never equal the bytes digest.
			return False

	@staticmethod
	def DSIG(data, key, alg):
		"""Sign the *alg* digest of *data* with an RSA key.

		data - binary data
		key - RSA key object with can_sign(), sign(digest, K) and size()
		      (the legacy PyCrypto RSA interface, presumably — confirm)
		alg - hash algorithm name accepted by hashlib.new

		Returns the signature as big-endian bytes sized to the modulus, or
		None if key.can_sign() is false.
		"""
		if not key.can_sign():
			return None
		digest = hashlib.new(alg, data).digest()
		sig_int = key.sign(digest, '')[0]
		# RSA signature size equals the key size; size() returns n-1 bits, hence ceil.
		n = m.ceil(key.size()/8)
		return sig_int.to_bytes(n, 'big')

	@staticmethod
	def verifyDSIG(data, key, alg, sig):
		"""Verify a DSIG signature.

		sig - signature bytes to verify

		Returns the result of ``key.verify(digest, (int(sig), ''))``.
		"""
		digest = hashlib.new(alg, data).digest()
		return key.verify(digest, (int.from_bytes(sig, 'big'), ''))

	@staticmethod
	def ECDSIG(data, key, alg):
		"""Return an ECDSA signature of *data* made with the EC private key *key*.

		alg is accepted for symmetry with DSIG but ignored: SHA-256 is
		always used.
		"""
		return key.sign(data, ec.ECDSA(hashes.SHA256()))

	@staticmethod
	def verifyECDSIG(data, key, alg, sig):
		"""Verify the ECDSA/SHA-256 signature *sig* of *data* with public key *key*.

		alg is ignored (SHA-256 is always used).  Returns True on success.
		Raises cryptography.exceptions.InvalidSignature on mismatch.
		"""
		key.verify(sig, data, ec.ECDSA(hashes.SHA256()))
		return True

	@staticmethod
	def checkDuplicateMsg(hashes, msg):
		"""Replay check over a sliding window of recent message digests.

		hashes - list of MD5 digests of recently seen messages; mutated in
		         place and also returned
		msg - binary message

		Returns (True, hashes) if msg's digest is already in the window.
		Otherwise it drops the oldest entries while the window is full,
		appends the digest, and returns (False, hashes).
		"""
		digest = hashlib.new('MD5', msg).digest()
		if digest in hashes:
			return True, hashes
		while hashes and len(hashes) >= s.WINDOW_SIZE:
			hashes.pop(0)
		hashes.append(digest)
		return False, hashes

	@staticmethod
	def prf(data, key, hash_algorithm=s.halg):
		"""Pseudo-random function: HMAC(key, data) with *hash_algorithm*."""
		return Security.HMAC(data, key, hash_algorithm)

	@staticmethod
	def prfplus(key, data, bitlen):
		"""Expand *key* and *data* into bitlen/8 bytes of keying material.

		IKEv2 prf+ construction: T1 = prf(key, data | 0x01),
		Tn = prf(key, T(n-1) | data | n), output = T1 | T2 | ... truncated.

		NOTE: *key* is now passed as the HMAC key.  Earlier code used the seed
		as the key, so derived material differs from builds made before this
		fix.

		Raises ValueError if more than 255 blocks would be needed, since the
		block counter is a single byte.
		"""
		bytelen = int(bitlen/8)
		out = b''
		block = b''
		counter = 1
		while len(out) < bytelen:
			if counter > 255:
				raise ValueError("prf+ output too long: block counter exceeds 255")
			# prf's signature is (data, key): the key goes second.
			block = Security.prf(block + data + pack("!B", counter), key)
			out += block
			counter += 1
		return out[:bytelen]

class Format:
	"""Build and parse the protocol's wire messages.

	Two framing styles coexist:

	* prepareData/parseData, prepareData2/parseData2 and
	  prepareData3/parseData3 separate fields with the delimiter ``s.d``.
	  The parsers read the delimiter from the first two bytes of the message,
	  so ``s.d`` is presumably two bytes long (TODO confirm in settings).  A
	  field that happens to contain the delimiter bytes breaks these formats:
	  the split yields the wrong field count and unpacking raises ValueError.
	* prepareData4/parseData4 prefix every field with its 2-byte big-endian
	  length instead, which is binary-safe.

	Every prepareData* result carries an outer 2-byte big-endian total-length
	prefix (addLength).  The parseData* functions expect that prefix to have
	been stripped already.
	"""

	@staticmethod
	def addLength(data):
		"""Return *data* prefixed with its length as a 2-byte big-endian integer.

		Raises OverflowError if *data* is 65536 bytes or longer.
		"""
		return len(data).to_bytes(2, 'big') + data

	@staticmethod
	def splitData(data, nvals):
		"""Split *nvals* length-prefixed fields off the front of *data* (inverse of packData).

		Each field is a 2-byte big-endian length followed by that many bytes.
		Bytes after the last requested field are ignored.  Returns a tuple of
		the field values.

		Raises ValueError if *data* is too short to hold *nvals* fields.
		"""
		values = []
		pos = 0
		for _ in range(nvals):
			header_end = pos + 2
			if header_end > len(data):
				raise ValueError("truncated field length")
			end = header_end + int.from_bytes(data[pos:header_end], 'big')
			if end > len(data):
				raise ValueError("truncated field value")
			values.append(data[header_end:end])
			pos = end
		return tuple(values)

	@staticmethod
	def packData(data):
		"""Concatenate the fields in *data*, each prefixed by addLength (inverse of splitData)."""
		return b''.join(Format.addLength(d) for d in data)

	@staticmethod
	def prepareData(tpe, dh_gx, dh_gy, dh_key, k, rsn, algs, err):
		"""Build a delimiter-framed message, first variant.

		Layout: s.d + type byte + fields joined by s.d.  Fields per *tpe*:
		'INIT_1' (dh_gx); 'INIT_2' (dh_gy, pk, dsig, hmac);
		'INIT_3' (pk, dsig, hmac); 'LIST' (rsn, algs_json, dsig);
		'ABORT' (rsn, err).  Returns the message with its outer length prefix.

		Raises ValueError for any other *tpe*.
		"""
		if tpe == 'INIT_1':
			t = b'\x00'
			data = s.d + t + dh_gx
		elif tpe == 'INIT_2':
			t = b'\x01'
			pk = k.publickey().exportKey()
			dsig = Security.DSIG(dh_gx + dh_gy + s.d + t, k, s.halg)
			hmac = Security.HMAC(pk, dh_key, s.halg)
			data = s.d + t + dh_gy + s.d + pk + s.d + dsig + s.d + hmac
		elif tpe == 'INIT_3':
			t = b'\x02'
			pk = k.publickey().exportKey()
			dsig = Security.DSIG(dh_gy + dh_gx + s.d + t, k, s.halg)
			hmac = Security.HMAC(pk, dh_key, s.halg)
			data = s.d + t + pk + s.d + dsig + s.d + hmac
		elif tpe == 'LIST':
			t = b'\x03'
			a = json.dumps(algs).encode()
			# TODO(review): the signed bytes (s.d + t + rsn + a) omit the s.d
			# that the transmitted payload has between rsn and a.  The original
			# author flagged this ("this is wrong, XXX fix it").  It is left
			# unchanged so signatures stay compatible with their verifier.
			dsig = Security.DSIG(s.d + t + rsn + a, k, s.halg)
			data = s.d + t + rsn + s.d + a + s.d + dsig
		elif tpe == 'ABORT':
			t = b'\x04'
			data = s.d + t + rsn + s.d + err
		else:
			raise ValueError("unknown message type: %r" % (tpe,))
		return Format.addLength(data)

	@staticmethod
	def parseData(data):
		"""Parse a prepareData message (outer length prefix already removed).

		The first two bytes are the delimiter and the third is the type byte.
		Returns (delim, t, *fields).  For type 0x00 the single field is the
		unsplit remainder (dh_gx).  For an unknown type, returns the bytes
		after the type byte.  Raises ValueError if the field count does not
		match the type.
		"""
		delim = data[0:2]
		t = data[2:3]
		data = data[3:]
		dsplit = data.split(delim)
		if t == b'\x00':
			return delim, t, data  # dh_gx, not split
		if t == b'\x01':
			dh_gy, pkb, dsig, hmac = dsplit
			return delim, t, dh_gy, pkb, dsig, hmac
		if t == b'\x02':
			pka, dsig, hmac = dsplit
			return delim, t, pka, dsig, hmac
		if t == b'\x03':
			rsn, algs, dsig = dsplit
			return delim, t, rsn, algs, dsig
		if t == b'\x04':
			rsn, err = dsplit
			return delim, t, rsn, err
		return data

	@staticmethod
	def prepareData2(tpe, dh_gx, dh_gy, dh_key, k, ni, nr, algs, err, rsn=b''):
		"""Build a delimiter-framed message carrying nonces ni/nr.

		Fields per *tpe*: 'INIT_I' (dh_gx, ni); 'INIT_R' (dh_gy, nr, pk, dsig,
		hmac); 'LIST_I' (pk, algs_json, dsig, hmac); 'LIST_R' (algs_json,
		dsig); 'ABORT' (rsn, err).

		rsn - first field of an ABORT message (optional, defaults to b'')

		Raises ValueError for an unknown *tpe*.
		"""
		if tpe == 'INIT_I':
			t = b'\x00'
			data = s.d + t + dh_gx + s.d + ni
		elif tpe == 'INIT_R':
			t = b'\x01'
			pk = k.publickey().exportKey()
			dsig = Security.DSIG(dh_gx + dh_gy + ni + nr + s.d + t, k, s.halg)
			hmac = Security.HMAC(pk, dh_key, s.halg)
			data = s.d + t + dh_gy + s.d + nr + s.d + pk + s.d + dsig + s.d + hmac
		elif tpe == 'LIST_I':
			t = b'\x02'
			pk = k.publickey().exportKey()
			a = json.dumps(algs).encode()
			dsig = Security.DSIG(dh_gx + dh_gy + ni + nr + a + s.d + t, k, s.halg)
			hmac = Security.HMAC(pk, dh_key, s.halg)
			data = s.d + t + pk + s.d + a + s.d + dsig + s.d + hmac
		elif tpe == 'LIST_R':
			t = b'\x03'
			a = json.dumps(algs).encode()
			dsig = Security.DSIG(dh_gx + dh_gy + ni + nr + a + s.d + t, k, s.halg)
			data = s.d + t + a + s.d + dsig
		elif tpe == 'ABORT':
			t = b'\x04'
			data = s.d + t + rsn + s.d + err
		else:
			raise ValueError("unknown message type: %r" % (tpe,))
		return Format.addLength(data)

	@staticmethod
	def parseData2(data):
		"""Parse a prepareData2 message (outer length prefix already removed).

		Returns (delim, t, *fields) with fields in prepareData2 order:
		0x00 (dh_gx, ni); 0x01 (dh_gy, nr, pkb, dsig, hmac);
		0x02 (pka, algs, dsig, hmac); 0x03 (algs, dsig); 0x04 (rsn, err).
		For an unknown type, returns the bytes after the type byte.  Raises
		ValueError if the field count does not match the type.
		"""
		delim = data[0:2]
		t = data[2:3]
		data = data[3:]
		dsplit = data.split(delim)
		if t == b'\x00':
			dh_gx, ni = dsplit
			return delim, t, dh_gx, ni
		if t == b'\x01':
			dh_gy, nr, pkb, dsig, hmac = dsplit
			return delim, t, dh_gy, nr, pkb, dsig, hmac
		if t == b'\x02':
			pka, algs, dsig, hmac = dsplit
			return delim, t, pka, algs, dsig, hmac
		if t == b'\x03':
			algs, dsig = dsplit
			return delim, t, algs, dsig
		if t == b'\x04':
			rsn, err = dsplit
			return delim, t, rsn, err
		return data

	@staticmethod
	def prepareData3(tpe, dh_gx, dh_gy, dh_key, k, ni, nr, algs, err, dsig_init = "", rsn=b''):
		"""Build a delimiter-framed message; INIT_R carries a precomputed signature.

		Fields per *tpe* are as in prepareData2, except for the signed/MACed
		inputs.  'INIT_R' sends *dsig_init* as its signature, and its HMAC
		covers pk + ni + nr + s.d + t.

		rsn - first field of an ABORT message (optional, defaults to b'')

		Logs the total encoded length at DEBUG.  Raises ValueError for an
		unknown *tpe*.
		"""
		if tpe == 'INIT_I':
			t = b'\x00'
			data = s.d + t + dh_gx + s.d + ni
		elif tpe == 'INIT_R':
			t = b'\x01'
			pk = k.publickey().exportKey()
			hmac = Security.HMAC(pk + ni + nr + s.d + t, dh_key, s.halg)
			data = s.d + t + dh_gy + s.d + nr + s.d + pk + s.d + dsig_init + s.d + hmac
		elif tpe == 'LIST_I':
			t = b'\x02'
			pk = k.publickey().exportKey()
			a = json.dumps(algs).encode()
			dsig = Security.DSIG(a + dh_gx + dh_gy , k, s.halg)
			hmac = Security.HMAC(pk + nr + ni + s.d + t, dh_key, s.halg)
			data = s.d + t + pk + s.d + a + s.d + dsig + s.d + hmac
		elif tpe == 'LIST_R':
			t = b'\x03'
			a = json.dumps(algs).encode()
			dsig = Security.DSIG(dh_gx + dh_gy + ni + nr + a + s.d + t, k, s.halg)
			data = s.d + t + a + s.d + dsig
		elif tpe == 'ABORT':
			t = b'\x04'
			data = s.d + t + rsn + s.d + err
		else:
			raise ValueError("unknown message type: %r" % (tpe,))
		data = Format.addLength(data)
		# len(data) already includes the 2-byte outer length prefix.
		logging.debug(tpe + " length: " + str(len(data)))
		return data

	@staticmethod
	def parseData3(data):
		"""Parse a prepareData3 message (outer length prefix already removed).

		Same result layout as parseData2.  Logs the type and length (plus 2
		for the stripped prefix) at DEBUG.
		"""
		delim = data[0:2]
		t = data[2:3]
		logging.debug(str(t) + " length: " + str(len(data)+2))
		data = data[3:]
		dsplit = data.split(delim)
		if t == b'\x00':
			dh_gx, ni = dsplit
			return delim, t, dh_gx, ni
		if t == b'\x01':
			dh_gy, nr, pkb, dsig, hmac = dsplit
			return delim, t, dh_gy, nr, pkb, dsig, hmac
		if t == b'\x02':
			pka, algs, dsig, hmac = dsplit
			return delim, t, pka, algs, dsig, hmac
		if t == b'\x03':
			algs, dsig = dsplit
			return delim, t, algs, dsig
		if t == b'\x04':
			rsn, err = dsplit
			return delim, t, rsn, err
		return data

	@staticmethod
	def prepareData4(tpe, dh_gx, dh_gy, dh_key, k, ni, nr, algs, err, dsig_init = "", ec = False, rsn=b''):
		"""Build a length-framed message: type byte + packData(fields).

		Fields per *tpe*: 'INIT_I' (dh_gx, ni); 'INIT_R' (dh_gy, nr, pk,
		dsig_init, hmac); 'LIST_I' (pk, algs_json, dsig, hmac);
		'LIST_R' (algs_json, dsig); 'ABORT' (rsn, err).

		ec - if true, sign with Security.ECDSIG, otherwise with Security.DSIG;
		     also passed to s.getPublicKey
		rsn - first field of an ABORT message (optional, defaults to b'')

		Logs the total encoded length at DEBUG.  Raises ValueError for an
		unknown *tpe*.
		"""
		if tpe == 'INIT_I':
			t = b'\x00'
			data = t + Format.packData((dh_gx, ni))
		elif tpe == 'INIT_R':
			t = b'\x01'
			pk = s.getPublicKey(k,ec)
			hmac = Security.HMAC(pk + ni + nr + t, dh_key, s.halg)
			data = t + Format.packData((dh_gy, nr, pk, dsig_init, hmac))
		elif tpe == 'LIST_I':
			t = b'\x02'
			pk = s.getPublicKey(k,ec)
			a = json.dumps(algs).encode()
			if ec:
				dsig = Security.ECDSIG(a + dh_gx + dh_gy , k, s.halg)
			else:
				dsig = Security.DSIG(a + dh_gx + dh_gy , k, s.halg)
			hmac = Security.HMAC(pk + nr + ni + t, dh_key, s.halg)
			data = t + Format.packData((pk, a, dsig, hmac))
		elif tpe == 'LIST_R':
			t = b'\x03'
			a = json.dumps(algs).encode()
			if ec:
				dsig = Security.ECDSIG(dh_gx + dh_gy + ni + nr + a + t, k, s.halg)
			else:
				dsig = Security.DSIG(dh_gx + dh_gy + ni + nr + a + t, k, s.halg)
			data = t + Format.packData((a, dsig))
		elif tpe == 'ABORT':
			t = b'\x04'
			# Length-framed like every other type so parseData4 can decode it.
			data = t + Format.packData((rsn, err))
		else:
			raise ValueError("unknown message type: %r" % (tpe,))
		data = Format.addLength(data)
		# len(data) already includes the 2-byte outer length prefix.
		logging.debug(tpe + " length: " + str(len(data)))
		return data

	@staticmethod
	def parseData4(data):
		"""Parse a prepareData4 message (outer length prefix already removed).

		Returns (0, t, *fields) with fields in prepareData4 order:
		0x00 (dh_gx, ni); 0x01 (dh_gy, nr, pkb, dsig, hmac);
		0x02 (pka, algs, dsig, hmac); 0x03 (algs, dsig); 0x04 (rsn, err).
		The leading 0 stands in for the delimiter returned by the older
		parsers.  For an unknown type, returns the bytes after the type byte.
		Raises ValueError on truncated fields.
		"""
		t = data[0:1]
		logging.debug(str(t) + " length: " + str(len(data)+2))
		data = data[1:]
		if t == b'\x00':
			dh_gx, ni = Format.splitData(data, 2)
			return 0, t, dh_gx, ni
		if t == b'\x01':
			dh_gy, nr, pkb, dsig, hmac = Format.splitData(data, 5)
			return 0, t, dh_gy, nr, pkb, dsig, hmac
		if t == b'\x02':
			pka, algs, dsig, hmac = Format.splitData(data, 4)
			return 0, t, pka, algs, dsig, hmac
		if t == b'\x03':
			algs, dsig = Format.splitData(data, 2)
			return 0, t, algs, dsig
		if t == b'\x04':
			rsn, err = Format.splitData(data, 2)
			return 0, t, rsn, err
		return data

class Communication:
	"""Transport helpers for length-prefixed protocol messages.

	Transports are selected by name: 'tcp', 'sctp', 'udp', 'ip' and 'eth'.
	Every message starts with a 2-byte big-endian length that recvData
	strips.
	"""

	@staticmethod
	def sendData(s, p, data, h):
		"""Send *data* on socket *s* using transport *p*.

		h - destination address for 'udp' and 'ip'; ignored otherwise

		Unknown transports are silently ignored.  Returns None.
		"""
		if p in ('tcp', 'sctp'):
			# send() may write only part of the buffer on a stream socket.
			s.sendall(data)
		elif p == 'eth':
			s.send(data)
		elif p in ('udp', 'ip'):
			s.sendto(data, h)

	@staticmethod
	def _recvUpTo(so, n):
		"""Read from stream socket *so* until *n* bytes arrive or the peer closes.

		May return fewer than *n* bytes; returns b'' if the peer had already
		closed.
		"""
		chunks = []
		received = 0
		while received < n:
			chunk = so.recv(n - received)
			if not chunk:
				break
			chunks.append(chunk)
			received += len(chunk)
		return b''.join(chunks)

	@staticmethod
	def recvData(so, p):
		"""Receive one message on socket *so* using transport *p*.

		Returns (payload, addr) with the 2-byte length prefix removed.  addr is:
		'' for 'tcp'/'sctp'; the recvfrom() address for 'udp'/'ip'; and
		s.dataToHexString of the Ethernet source address for 'eth'.  Returns
		None for an unknown transport.

		tcp/sctp: if the peer closed before a new message started, returns
		(b'', '').  If it closes partway through a message, raises
		ConnectionError.
		udp: on socket.timeout, prints a message and exits with status 2.
		"""
		if p in ('tcp', 'sctp'):
			header = Communication._recvUpTo(so, 2)
			if not header:
				return b'', ''
			if len(header) < 2:
				raise ConnectionError("connection closed inside message length prefix")
			dlen = int.from_bytes(header, 'big')
			msg = Communication._recvUpTo(so, dlen)
			if len(msg) < dlen:
				raise ConnectionError("connection closed after %d of %d message bytes" % (len(msg), dlen))
			return msg, ''
		if p == 'udp':
			try:
				data, addr = so.recvfrom(s.BUF)
			except socket.timeout:
				print("Received timed out. Exiting...")
				raise SystemExit(2)
			dlen = int.from_bytes(data[0:2], 'big')
			return data[2:dlen+2], addr
		if p == 'ip':
			data, addr = so.recvfrom(s.BUF)
			# Presumably a raw IPv4 socket, which delivers the IP header as well.
			# The header is IHL*4 bytes long (at least 20, more with options).
			hdrlen = max(20, (data[0] & 0x0F) * 4) if data else 20
			payload = data[hdrlen:]
			dlen = int.from_bytes(payload[0:2], 'big')
			return payload[2:dlen+2], addr
		if p == 'eth':
			data = so.recv(s.BUF)
			# Bytes 6..11 of the 14-byte Ethernet header hold the sender address.
			addr = s.dataToHexString(data[6:12])
			payload = data[14:]
			dlen = int.from_bytes(payload[0:2], 'big')
			return payload[2:dlen+2], addr
		return None

	@staticmethod
	def getIfcMAC(ifc):
		"""Return the first link-layer (AF_LINK) address string of interface *ifc*."""
		return netifaces.ifaddresses(ifc)[netifaces.AF_LINK][0]['addr']

	@staticmethod
	def getIfcMACbytes(ifc):
		"""Return getIfcMAC(ifc) converted to bytes via s.stringToHexData."""
		return s.stringToHexData(Communication.getIfcMAC(ifc))

class Negotiation:
	"""Pick one algorithm per algorithm type from two peers' offers.

	Each offer is a dict whose 'algorithm types' entry maps a type name to a
	dict keyed by algorithm name.  algorithm() additionally reads a numeric
	'value' from each algorithm entry.  Both methods return, for every type
	in the first offer, {'timestamp': 'YYYY-MM-DD HH:MM', 'algorithm': name},
	where name is '' when nothing matches.
	"""

	@staticmethod
	def algorithm(list1, list2):
		"""Choose, per type, the common algorithm with the best worst-case 'value'.

		An algorithm offered in both lists scores the smaller of its two
		values, and the highest score wins.  Ties go first to the candidate
		whose larger value is higher, then to the higher list1 value.  On a
		complete tie the candidate earlier in list1 order is kept.  Raises
		KeyError if list2 lacks a type for which list1 offers algorithms.
		"""
		chosen = {}
		types1 = list1['algorithm types']
		for algType in types1:
			offered1 = types1[algType]
			best = ''
			bestScore = -1
			for name in offered1:
				v1 = offered1[name]['value']
				offered2 = list2['algorithm types'][algType]
				if name not in offered2:
					continue
				v2 = offered2[name]['value']
				score = min(v1, v2)
				if score > bestScore:
					best, bestScore = name, score
					continue
				if score != bestScore:
					continue
				# Equal worst case: compare best cases, then list1's own preference.
				best1 = offered1[best]['value']
				best2 = offered2[best]['value']
				candHigh = max(v1, v2)
				bestHigh = max(best1, best2)
				if candHigh > bestHigh or (candHigh == bestHigh and v1 > best1):
					best = name
			chosen[algType] = {
				'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M"),
				'algorithm': best,
			}
		return chosen

	@staticmethod
	def algorithm2(listI, listR):
		"""Choose, per type, the first algorithm in listI's order that listR also offers."""
		offersI = listI['algorithm types']
		offersR = listR['algorithm types']
		chosen = {}
		for algType, namesI in offersI.items():
			pick = next((name for name in namesI if name in offersR[algType]), '')
			chosen[algType] = {
				'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M"),
				'algorithm': pick,
			}
		return chosen

