#!/usr/bin/env python3

import sys, os, socket, getopt, base64, time, logging, threading
import json
import settings as s

from server import Server
from client import Client

from Crypto.PublicKey import RSA, DSA
from Utils import Security, Format, Communication, Negotiation, Measurement
from datetime import datetime
#from Crypto.Hash import MD5, SHA, SHA512

from cryptography.hazmat.backends import default_backend

def usage():
	"""Print the command-line synopsis for the server and client roles to stdout."""
	synopsis = (
		"Usage:",
		"SERVER: ./acnp.py -S -a algs_file -k private_key_file -d database",
		"CLIENT: ./acnp.py -C -a algs_file -k private_key_file -d database -h host",
	)
	for line in synopsis:
		print(line)

def _read_file(path, mode="r"):
	"""Return the full content of path, closing the file even if reading fails."""
	with open(path, mode) as handle:
		return handle.read()


def main():
	"""Parse the command line and run the selected negotiation role.

	Options are read from sys.argv. The algorithms file (-a/--algs) and the
	key file (-k/--key) are mandatory: when either is missing, a message and
	the usage text are printed and the process exits with status 1. An
	option that getopt rejects prints the usage text plus the getopt error
	and exits with status 2.

	The database text comes from -d/--db when given (and must then be valid
	JSON). Otherwise the default path of the selected role from settings
	(SERVER_DB / CLIENT_DB) is read if that file exists, falling back to
	the empty JSON object '{}'.

	Depending on -S/-C, -A and -v/--version (0, 1 or 2; default 2) the
	matching Server or Client message exchange is executed. Server modes
	loop forever accepting clients; the attacker client (-C -A) loops
	forever sending INIT_I messages.

	Side effects: configures the logging module, and sets the module
	globals NEG_KEYS (only when -K/--keys is given) and current.
	"""
	global NEG_KEYS, current

	try:
		# "eth=" mirrors the short form "e:", which takes the interface name.
		# "key=" (not a second "keys") keeps --key distinct from the --keys flag;
		# two entries spelled "keys" made getopt reject --key as ambiguous.
		opts, _args = getopt.getopt(sys.argv[1:], "tusie:SCK46h:a:k:d:v:l:L:AE",
				["tcp", "udp", "sctp", "ip", "eth=", "server", "client", "keys",
					"4", "6", "host=", "algs=", "key=", "db=", "version=",
					"logging=", "logfile=", "attacker", "elliptic"])
	except getopt.GetoptError as err:
		usage()
		print(str(err))
		sys.exit(2)

	proto = 'tcp'
	server = False
	client = False
	ipv = '4'
	ifc = ''
	ALGS = ''
	HOST = ''
	KEY = ''
	DB = ''
	version = 2
	levl = 20
	attacker = False
	elliptic = False
	logfile = ''

	for o, a in opts:
		if o in ("-t", "--tcp"):
			proto = 'tcp'
		elif o in ("-u", "--udp"):
			proto = 'udp'
		elif o in ("-s", "--sctp"):
			proto = 'sctp'
		elif o in ("-i", "--ip"):
			proto = 'ip'
		elif o in ("-e", "--eth"):
			proto = 'eth'
			ifc = a
		elif o in ("-S", "--server"):
			server = True
		elif o in ("-C", "--client"):
			client = True
		elif o in ("-A", "--attacker"):
			attacker = True
		elif o in ("-E", "--elliptic"):
			elliptic = True
		elif o in ("-K", "--keys"):
			NEG_KEYS = True
		elif o in ("-4", "--4"):
			ipv = '4'
		elif o in ("-6", "--6"):
			ipv = '6'
		elif o in ("-h", "--host"):
			HOST = a
		elif o in ("-a", "--algs"):
			ALGS = a
		elif o in ("-k", "--key"):
			KEY = a
		elif o in ("-d", "--db"):
			DB = a
		elif o in ("-v", "--version"):
			version = int(a)
		elif o in ("-l", "--logging"):
			levl = int(a)
		elif o in ("-L", "--logfile"):
			logfile = a
		else:
			usage()
			sys.exit(2)

	# Log to the requested file, or to stderr when no log file was given.
	log_format = '%(levelname)s - %(message)s'
	if logfile != '':
		logging.basicConfig(filename=logfile, level=levl, format=log_format)
	else:
		logging.basicConfig(stream=sys.stderr, level=levl, format=log_format)

	if ALGS == '':
		print("Wrong algs")
		usage()
		sys.exit(1)
	algs = _read_file(ALGS)

	if KEY == '':
		print("Wrong key")
		usage()
		sys.exit(1)
	# The key file is read as bytes, the other inputs as text.
	key = _read_file(KEY, "rb")

	if DB != '':
		db = _read_file(DB)
		# Parsed only to fail early on malformed JSON; the raw text is what
		# gets handed to Server/Client below.
		json.loads(db)
	else:
		# No explicit database: use the role's default file when it exists.
		db = '{}'
		default_db = None
		if server:
			default_db = s.SERVER_DB
		elif client:
			default_db = s.CLIENT_DB
		if default_db is not None and os.path.exists(default_db):
			db = _read_file(default_db)

	# NOTE(review): return value is discarded; presumably this only forces
	# the cryptography backend to be loaded up front - confirm.
	default_backend()

	bindaddr = ''

	# test if the interface exists
	if proto == 'eth':
		logging.debug ('Using interface ' + ifc + ' (' + Communication.getIfcMAC(ifc)+ ')')
		logging.debug ('Sending to host ' + HOST )

	# ---- attacker client: connect and send INIT_I forever ----
	if (client and attacker):
		c = Client(HOST, algs, key, db, proto, ipv, ifc)
		logging.debug('Client created')
		c.resolve_setup()
		while True:
			logging.debug('Client setup')
			c.connect()
			logging.debug('Client connect')
			logging.debug('Client dh: ' + s.dataToHexString(c.dh_gx))
			logging.debug('Client Nonce: ' + s.dataToHexString(c.ni))
			logging.debug('Client public key:\n' + c.k.publickey().exportKey().decode())
			# sending gx, ni INIT_I
			c.prepare_data3('INIT_I')
			c.send_data()

	# ---- protocol version 0, server ----
	if (server and version == 0):
		srv = Server(bindaddr, algs, key, db, proto, ipv, ifc)
		logging.debug('Server created')
		srv.bind()
		logging.debug('Server listening')
		try:
			while True:
				srv.refreshRSN()
				srv.listen_accept()
				logging.debug('Server accepted')
				logging.debug('Server dh: ' + s.dataToHexString(srv.dh_gy))
				logging.debug('Server public key:\n' + srv.k.publickey().exportKey().decode())
				# recving gx INIT_1
				srv.recv_data()
				srv.parse_data()
				logging.debug('Client dh: ' + s.dataToHexString(srv.dh_gx))
				logging.debug('DH key: ' + s.dataToHexString(srv.dh_key))
				# sending gy, pks INIT_2
				srv.prepare_data('INIT_2')
				srv.send_data()
				# recving pkc INIT_3
				srv.recv_data()
				srv.parse_data()
				logging.debug('Client public key:\n' + srv.k2.publickey().exportKey().decode())
				# sending serv_list LIST
				logging.debug('Server RSN: ' + s.dataToHexString(srv.rsn))
				logging.debug('Server LIST: ' + json.dumps(srv.a))
				srv.prepare_data('LIST')
				srv.send_data()
				# recving clie_list LIST
				srv.recv_data()
				srv.parse_data()
				logging.debug('Client LIST: ' + json.dumps(srv.a2))
				logging.debug('Server negotiated: \n' + json.dumps(srv.negotiate(), sort_keys=True, indent=4))
				srv.client_disconnect()
		finally:
			# The accept loop only ends through an exception (e.g. Ctrl-C),
			# so the disconnect has to live in a finally to ever run.
			srv.disconnect()

	# ---- protocol version 0, client ----
	if (client and version == 0):
		c = Client(HOST, algs, key, db, proto, ipv, ifc)
		logging.debug('Client created')
		c.resolve_setup()
		logging.debug('Client setup')
		c.connect()
		logging.debug('Client connect')
		logging.debug('Client dh: ' + s.dataToHexString(c.dh_gx))
		logging.debug('Client public key:\n' + c.k.publickey().exportKey().decode())
		# sending gx INIT_1
		c.prepare_data('INIT_1')
		c.send_data()
		# recving gy, pks INIT_2
		c.recv_data()
		c.parse_data()
		logging.debug('Server dh: ' + s.dataToHexString(c.dh_gy))
		logging.debug('DH key: ' + s.dataToHexString(c.dh_key))
		logging.debug('Server public key:\n' + c.k2.publickey().exportKey().decode())
		# sending pkc INIT_3
		c.prepare_data('INIT_3')
		c.send_data()
		# recving serv_list LIST
		c.recv_data()
		c.parse_data()
		logging.debug('Client RSN: ' + s.dataToHexString(c.rsn))
		logging.debug('Server LIST: ' + json.dumps(c.a2))
		# sending clie_list LIST
		logging.debug('Client LIST: ' + json.dumps(c.a))
		c.prepare_data('LIST')
		c.send_data()
		logging.debug('Client negotiated: \n' + json.dumps(c.negotiate(), sort_keys=True, indent=4))
		# disconnect
		c.disconnect()

	# Timestamp stored in the module global `current`.
	current = datetime.now()

	# ---- protocol version 1, server ----
	if (server and version == 1):
		srv = Server(bindaddr, algs, key, db, proto, ipv, ifc)
		logging.debug('Server created')
		srv.bind()
		logging.debug('Server listening')
		try:
			while True:
				srv.listen_accept()
				logging.debug('Server accepted')
				logging.debug('Server dh: ' + s.dataToHexString(srv.dh_gy))
				logging.debug('Server Nonce: ' + s.dataToHexString(srv.nr))
				logging.debug('Server public key:\n' + srv.k.publickey().exportKey().decode())
				# recving gx, ni INIT_I
				srv.recv_data()
				srv.parse_data2()
				logging.debug('Client dh: ' + s.dataToHexString(srv.dh_gx))
				logging.debug('Client Nonce: ' + s.dataToHexString(srv.ni))
				logging.debug('DH key: ' + s.dataToHexString(srv.dh_key))
				# sending gy, nr, pks INIT_R
				srv.prepare_data2('INIT_R')
				srv.send_data()
				# recving pkc, clie_list LIST_I
				srv.recv_data()
				srv.parse_data2()
				logging.debug('Client public key:\n' + srv.k2.publickey().exportKey().decode())
				logging.debug('Client LIST: ' + json.dumps(srv.a2))
				# sending serv_list LIST_R
				logging.debug('Server LIST: ' + json.dumps(srv.a))
				srv.prepare_data2('LIST_R')
				srv.send_data()
				# over
				logging.debug('Server negotiated: \n' + json.dumps(srv.negotiate2(),
					sort_keys=True, indent=4))
				srv.client_disconnect()
		finally:
			# The accept loop only ends through an exception (e.g. Ctrl-C),
			# so the disconnect has to live in a finally to ever run.
			srv.disconnect()

	# ---- protocol version 1, client ----
	if (client and version == 1):
		c = Client(HOST, algs, key, db, proto, ipv, ifc)
		logging.debug('Client created')
		c.resolve_setup()
		logging.debug('Client setup')
		c.connect()
		logging.debug('Client connect')
		logging.debug('Client dh: ' + s.dataToHexString(c.dh_gx))
		logging.debug('Client Nonce: ' + s.dataToHexString(c.ni))
		logging.debug('Client public key:\n' + c.k.publickey().exportKey().decode())
		# sending gx, ni INIT_I
		c.prepare_data2('INIT_I')
		c.send_data()
		# recving gy, nr, pks INIT_R
		c.recv_data()
		c.parse_data2()
		logging.debug('Server dh: ' + s.dataToHexString(c.dh_gy))
		logging.debug('DH key: ' + s.dataToHexString(c.dh_key))
		logging.debug('Server Nonce: ' + s.dataToHexString(c.nr))
		logging.debug('Server public key:\n' + c.k2.publickey().exportKey().decode())
		# sending clie_list, pkc LIST_I
		logging.debug('Client LIST: ' + json.dumps(c.a))
		c.prepare_data2('LIST_I')
		c.send_data()
		# recving serv_list LIST_R
		c.recv_data()
		c.parse_data2()
		logging.debug('Server LIST: ' + json.dumps(c.a2))
		# over
		res = c.negotiate2()
		logging.debug('Client negotiated: \n' + json.dumps(res, sort_keys=True, indent=4))
		# disconnect
		c.disconnect()

	# ---- protocol version 2, server ----
	if (server and version == 2):
		srv = Server(bindaddr, algs, key, db, proto, ipv, ifc, elliptic)
		# Starts the background refresher; its thread keeps the object alive,
		# so no local reference is needed.
		refreshDHN(srv)
		# dirty fix that waits for first values to be generated
		time.sleep(2)
		logging.debug('Server created')
		srv.bind()
		logging.debug('Server listening')
		try:
			while True:
				srv.listen_accept()
				logging.debug('Server accepted')
				logging.debug('Server dh: ' + s.dataToHexString(srv.dh_gy))
				logging.debug('Server Nonce: ' + s.dataToHexString(srv.nr))
				logging.info('Server public key:\n' + s.printPublicKeyFromPrivate(srv.k,elliptic))
				# recving gx, ni INIT_I
				# NOTE(review): measureTotalTime("", True) presumably (re)starts
				# the timer; it runs before the first receive for tcp/sctp and
				# after it for the other transports - confirm the intent.
				if proto in ('tcp', 'sctp'):
					Measurement.measureTotalTime("",True)
					srv.recv_data()
				else:
					srv.recv_data()
					Measurement.measureTotalTime("",True)
				srv.parse_data3()
				logging.debug('Client dh: ' + s.dataToHexString(srv.dh_gx))
				logging.debug('Client Nonce: ' + s.dataToHexString(srv.ni))
				logging.debug('DH key: ' + s.dataToHexString(srv.dh_key))
				# sending gy, nr, pks INIT_R
				srv.prepare_data3('INIT_R')
				srv.send_data()
				# recving pkc, clie_list LIST_I
				srv.recv_data()
				srv.parse_data3()
				logging.info('Client public key:\n' + s.printPublicKey(srv.k2,elliptic))
				logging.debug('Client LIST: ' + json.dumps(srv.a2))
				# sending serv_list LIST_R
				logging.debug('Server LIST: ' + json.dumps(srv.a))
				srv.prepare_data3('LIST_R')
				srv.send_data()
				# over
				logging.info('Server negotiated: \n' + json.dumps(srv.negotiate2(),
					sort_keys=True, indent=4))
				srv.client_disconnect()
				srv.calculateSymmetricKey()
				Measurement.measureTotalTime("TOTAL")
		except KeyboardInterrupt:
			print("Caught KeyboardInterrupt, terminating server")
			# disconnect
			srv.disconnect()

	# ---- protocol version 2, client ----
	if (client and version == 2):
		Measurement.measureTotalTime("", True)
		c = Client(HOST, algs, key, db, proto, ipv, ifc, elliptic)
		logging.debug('Client created')
		c.resolve_setup()
		logging.debug('Client setup')
		c.connect()
		logging.debug('Client connect')
		logging.debug('Client dh: ' + s.dataToHexString(c.dh_gx))
		logging.debug('Client Nonce: ' + s.dataToHexString(c.ni))
		logging.info('Client public key:\n' + s.printPublicKeyFromPrivate(c.k, elliptic))
		# sending gx, ni INIT_I
		c.prepare_data3('INIT_I')
		c.send_data()
		# recving gy, nr, pks INIT_R
		c.recv_data()
		c.parse_data4()
		logging.debug('Server dh: ' + s.dataToHexString(c.dh_gy))
		logging.debug('DH key: ' + s.dataToHexString(c.dh_key))
		logging.debug('Server Nonce: ' + s.dataToHexString(c.nr))
		logging.info('Server public key:\n' + s.printPublicKey(c.k2, elliptic))
		# sending clie_list, pkc LIST_I
		logging.debug('Client LIST: ' + json.dumps(c.a))
		c.prepare_data3('LIST_I')
		c.send_data()
		# recving serv_list LIST_R
		c.recv_data()
		c.parse_data4()
		logging.debug('Server LIST: ' + json.dumps(c.a2))
		# over
		res = c.negotiate2()
		# Logged once at info level: an identical debug record would only
		# duplicate this line whenever debug output is enabled.
		logging.info('Client negotiated: \n' + json.dumps(res,
			sort_keys=True, indent=4))
		# disconnect
		c.disconnect()
		c.calculateSymmetricKey()
		Measurement.measureTotalTime("TOTAL")

class refreshDHN(object):
	""" Periodic DH / nonce refresher for a server object

	Creating an instance immediately launches a daemon thread that
	executes run(); being a daemon, the thread does not keep the
	application alive on exit.
	"""
	def __init__(self, server, interval=30):
		""" Store the settings and start the background thread
		:type server: Server
		:param server: object whose refreshDH() and refreshNonce() are called
		:type interval: int
		:param interval: pause between two refresh cycles, in seconds
		"""
		self.server = server
		self.interval = interval

		# daemon=True: the worker must not block interpreter shutdown
		worker = threading.Thread(target=self.run, daemon=True)
		worker.start()

	def _refresh_once(self):
		""" Perform a single refresh cycle and report it on stdout """
		self.server.refreshDH()
		self.server.refreshNonce()
		print("Refreshed DH and Nonce...")

	def run(self):
		""" Thread body: refresh, sleep for the interval, repeat without end """
		while True:
			self._refresh_once()
			time.sleep(self.interval)


if __name__ == "__main__":
	# Script entry point. sys.exit() is used rather than the site-provided
	# exit() helper, which is not guaranteed to exist (e.g. under `python -S`).
	main()
	sys.exit()

