import paho.mqtt.client as mqtt
import time
import json
from neo4j import GraphDatabase
import os

# Runtime configuration; every value can be overridden through environment variables.
broker_hostname = os.getenv('mos_host', default="localhost")
broker_port = int(os.getenv('mos_port', default=1883))
# MQTT client ids must be unique per broker connection; make it configurable so
# several consumers can run side by side (default keeps the previous id).
client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION1, os.getenv('mos_client_id', default="Client5"))
db_uri = os.getenv('db_host', default="bolt://localhost:7687")
# Neo4j credentials; both default to empty strings as before
# (presumably for a server with authentication disabled - TODO confirm).
neo4j_auth = (os.getenv('db_user', default=""), os.getenv('db_password', default=""))
# Maximum number of seconds the main loop keeps the consumer running.
abort_time_limit = int(os.getenv('abort_time_limit', default=99999))

def flatten_obj(key, val, target):
    """Flatten a nested value into ``target`` as underscore-joined keys.

    Dict keys and list indices are appended to ``key`` with a ``_`` separator,
    recursively. Plain objects (non-scalar values exposing ``__dict__``) are
    flattened through their attributes. ``None`` values are skipped entirely.
    Any other value is stored as ``target[key] = val``.

    With ``key=""`` every produced key starts with an underscore, e.g.
    ``{"uuid": "x"}`` becomes ``{"_uuid": "x"}``.

    Args:
        key: prefix for the keys produced from ``val``.
        val: value to flatten.
        target: dict receiving the flattened entries (mutated in place).

    Returns:
        None; the result is written into ``target``.
    """
    if val is None:
        # Null values produce no entry at all.
        return None
    if isinstance(val, dict):
        for k, v in val.items():
            flatten_obj(key + "_" + k, v, target)
    elif isinstance(val, list):
        for i, v in enumerate(val):
            flatten_obj(key + "_" + str(i), v, target)
    elif not isinstance(val, (str, bytes, int, float)) and hasattr(val, "__dict__"):
        # Arbitrary objects: flatten their instance attributes. (Checking
        # ``type(val) is object`` never matches an attribute-bearing instance.)
        for k, v in vars(val).items():
            flatten_obj(key + "_" + k, v, target)
    else:
        target[key] = val
    return None

# Key under which CDM18 records wrap UUIDs in union-typed fields,
# e.g. {"com.bbn.tc.schema.avro.cdm18.UUID": "..."}.
_CDM_UUID = "com.bbn.tc.schema.avro.cdm18.UUID"

# Relationship type -> label of the node it points to. ``None`` means the
# target may be of any type, so the lookup is unlabeled. (isPartOf and hasTag
# have no source fields in the data set and are never produced.)
_RELATION_TARGET_LABELS = {
    "runsOn": "Host",
    "isGeneratedBy": "Subject",
    "affects": None,
    "affects2": None,
    "residesOn": "Host",
    "hasOwningPrincipal": "Principal",
    "hasParent": "Subject",
    "hasAccountOn": "Host",
    "hasLocalPrincipal": "Principal",
}


def _quote_label(name):
    """Return ``name`` as a backtick-quoted Cypher identifier (backticks doubled)."""
    return "`" + name.replace("`", "``") + "`"


def _extract_relations(node_type, value):
    """Map relationship type -> target UUID for one CDM record of ``node_type``.

    Only relationships whose source field is set are included. Missing fields
    raise KeyError/TypeError, which ``parse_json_to_cypher`` reports.
    """
    relations = {}
    if node_type == "Subject":
        if value["parentSubject"] is not None:
            relations["hasParent"] = value["parentSubject"][_CDM_UUID]
        if value["hostId"] is not None:
            relations["runsOn"] = value["hostId"]
        if value["localPrincipal"] is not None:
            relations["hasLocalPrincipal"] = value["localPrincipal"]
        # Subject -[affects]-> Event is not modelled; presumably it is implied
        # by Event -[isGeneratedBy]-> Subject.
    elif node_type == "FileObject":
        if value["baseObject"] is not None:
            relations["residesOn"] = value["baseObject"]["hostId"]
        if value["localPrincipal"]:
            # NOTE(review): if this field is union-wrapped (a dict keyed by
            # _CDM_UUID) in the feed, the lookup will not match - confirm.
            relations["hasOwningPrincipal"] = value["localPrincipal"]
    elif node_type == "Event":
        if value["hostId"] is not None:
            relations["runsOn"] = value["hostId"]
        if value["subject"] is not None:
            relations["isGeneratedBy"] = value["subject"][_CDM_UUID]
        if value["predicateObject"] is not None:
            relations["affects"] = value["predicateObject"][_CDM_UUID]
        if value["predicateObject2"] is not None:
            relations["affects2"] = value["predicateObject2"][_CDM_UUID]
    elif node_type == "Principal":
        if value["hostId"] is not None:
            relations["hasAccountOn"] = value["hostId"]
    elif node_type == "UnnamedPipeObject":
        if value["sourceUUID"] is not None:
            relations["affects"] = value["sourceUUID"][_CDM_UUID]
        if value["sinkUUID"] is not None:
            relations["affects2"] = value["sinkUUID"][_CDM_UUID]
        if value["baseObject"] is not None:
            relations["residesOn"] = value["baseObject"]["hostId"]
    elif node_type in ("NetFlowObject", "SrcSinkObject"):
        if value["baseObject"] is not None:
            relations["residesOn"] = value["baseObject"]["hostId"]
    return relations


def _relationship_clause(rel, uuid):
    """Build the Cypher fragment linking ``new`` to every node whose ``_uuid`` is ``uuid``.

    OPTIONAL MATCH + collect + FOREACH keeps exactly one row for ``new``
    whether zero, one or several targets exist. A plain MATCH would drop the
    row when the target is missing, skipping all later relationships and
    emptying ``RETURN new``. When several targets match, a plain MATCH would
    also multiply rows, and with them every later relationship.
    """
    label = _RELATION_TARGET_LABELS[rel]
    pattern = "existing" if label is None else "existing:" + label
    # json.dumps yields an escaped double-quoted literal that Cypher accepts,
    # so a crafted UUID in an incoming message cannot inject Cypher.
    literal = json.dumps(str(uuid), ensure_ascii=False)
    return f"""
        WITH new
        OPTIONAL MATCH ({pattern}) WHERE existing._uuid = {literal}
        WITH new, collect(existing) AS targets
        FOREACH (target IN targets | CREATE (new)-[:{rel}]->(target))
        """


def parse_json_to_cypher(input):
    """Translate one CDM18 JSON record into a Cypher CREATE query.

    The record is expected as ``{"datum": {"<...>.<Type>": {...fields...}}}``.
    The short type name becomes the node label. The flattened fields are
    bound as the node's properties via ``$attributes``. Relationships to
    already-existing nodes are matched by their ``_uuid`` property.

    Args:
        input: decoded JSON message.

    Returns:
        ``(query, attributes)`` on success. Returns ``None`` if the record
        cannot be translated; the problem is printed in that case.
    """
    relations = {}
    try:
        json_type = list(input["datum"].keys())[0]
        # Short type string, e.g. "Event" from "com.bbn...cdm18.Event".
        node_type = json_type.rsplit(".", 1)[1]
        value = input["datum"][json_type]

        relations = _extract_relations(node_type, value)
        q_rel = "".join(
            _relationship_clause(rel, uuid)
            for rel, uuid in relations.items()
            if uuid != ""
        )

        value_flat = {}
        flatten_obj("", value, value_flat)

        # The label comes from the message, so it is quoted as an identifier.
        query = f"""
        CREATE (new:{_quote_label(node_type)} $attributes)
        {q_rel}
        RETURN new
        """
        return query, value_flat
    except Exception as err:
        # Best-effort per message: report what failed and signal it with None.
        print('could not build query: ', repr(err))
        print('input: ', input)
        print('relations: ', relations)
        return None


def create_cypher_query_from_cdm(json):
    '''
    Turn one decoded publisher message into a Cypher query and its parameters.

    Thin wrapper around ``parse_json_to_cypher``. The result is unpacked
    here, so a failed parse (which yields None) raises TypeError at this point.

    Returns:
        (query, attributes): Cypher text and the property map bound as
        ``$attributes``.
    '''
    parsed = parse_json_to_cypher(json)
    cypher_text, attributes = parsed
    return cypher_text, attributes

def on_message(client, userdata, message):
    '''
    Callback for incoming MQTT messages.

    Decodes the UTF-8 JSON payload, translates it into a Cypher query and
    executes it against Neo4j. An error while handling a single message is
    printed and that message is skipped. paho-mqtt can re-raise exceptions
    from callbacks (see its ``suppress_exceptions`` setting), which would end
    the network loop thread started with ``loop_start()``. One malformed
    message would then stop all further ingestion.
    '''
    try:
        data = json.loads(message.payload.decode("utf-8"))
        print("Received message from: ", message.topic)
        q, attr = create_cypher_query_from_cdm(data)
        execute_query(q, attr)
    except Exception as err:
        print("Skipping message from", message.topic, "-", repr(err))

def on_connect(client, userdata, flags, return_code):
    '''
    Connection callback.

    On success (return code 0) subscribes to the "neo4j" topic. Otherwise it
    prints the broker's return code and sets ``client.failed_connect`` so the
    main wait loop can stop.
    '''
    if return_code != 0:
        print("could not connect, return code:", return_code)
        client.failed_connect = True
        return
    print("connected")
    client.subscribe("neo4j")

def connect_to_db(uri, auth):
    '''
    Open a Neo4j driver, wipe the database and check that queries run.

    WARNING: every call deletes ALL nodes and their relationships in the
    target database (``MATCH (n) DETACH DELETE n``).

    Parameters:
        uri: Bolt URI of the server.
        auth: credentials passed to ``GraphDatabase.driver``.

    Returns:
        The open driver; the caller is responsible for closing it.
    '''
    driver = GraphDatabase.driver(uri, auth=auth)
    try:
        with driver.session() as session:
            print("Cleanup existing data...")
            # consume() forces each query to finish (or raise) before moving
            # on, so the success message is only printed after both completed.
            session.run("MATCH (n) detach delete n").consume()
            session.run("RETURN 1 as result").consume()
            print("Successfully connected to DB...")
    except Exception:
        # Do not leak the driver's connection pool when setup fails.
        driver.close()
        raise
    return driver
    
   
def execute_query(query:str, attributes):
    '''
    Run one Cypher query in a fresh session of the module-level ``driver``.

    Parameters:
        query: Cypher text; it may reference the ``$attributes`` parameter.
        attributes: value bound to ``$attributes``.

    Returns:
        The records as returned by ``Result.data()``, fetched before the
        session closes.
    '''
    with driver.session() as neo4j_session:
        return neo4j_session.run(query, attributes=attributes).data()

driver = connect_to_db(db_uri, neo4j_auth)

# client.username_pw_set(username="user_name", password="password") # uncomment if you use password auth
client.on_connect = on_connect
client.on_message = on_message
# Set by on_connect when the broker refuses the connection.
client.failed_connect = False

# The try/finally guarantees cleanup on every exit path: the time limit is
# reached, the broker refuses the connection, client.connect() raises, or the
# program is interrupted with Ctrl+C. Cleanup stops the MQTT loop and closes
# the Neo4j driver.
try:
    client.connect(broker_hostname, broker_port)
    client.loop_start()

    i = 0
    while i < abort_time_limit and not client.failed_connect:
        time.sleep(1)
        i += 1
    # Checked after the loop so a failure reported before the first
    # iteration is also announced.
    if client.failed_connect:
        print('Connection failed, exiting...')
finally:
    try:
        client.disconnect()
        client.loop_stop()
    finally:
        driver.close()