# -*- coding: utf-8 -*-
# author:chao.yy
# email:yuyc@ishangqi.com
# date:2021/11/26 3:56 下午
# Copyright (C) 2021 The lesscode Team
import json

from tornado.options import options

import logging
import traceback

from neo4j import WorkspaceConfig

from lesscode.db.page import Page
from lesscode.utils.aes import AES


class Neo4jHelper:
    """
    Neo4j database operations.

    Every query runs against ``self.pool``. Most methods call ``self.pool.begin_transaction()``
    directly, while ``execute_fetchall`` opens ``self.pool.session()`` first.
    NOTE(review): those two access patterns expect different driver objects (a session vs. a
    driver); confirm which one ``options.database`` actually holds.

    Query results are buffered into lists *inside* their transaction. A neo4j result cursor
    cannot be read once its transaction has ended, so returning the raw cursor would hand
    callers an already-consumed object.
    """

    def __init__(self, pool):
        """
        Initialise the helper.
        :param pool: name of an entry in ``options.database`` (which must hold a
            ``(pool, dialect)`` pair), or the pool object itself.
        """
        if isinstance(pool, str):
            self.pool, self.dialect = options.database[pool]
        else:
            self.pool = pool

    def __repr__(self):
        printer = 'o(>﹏<)o ......Neo4j old driver "{0}" carry me fly...... o(^o^)o'.format(self.pool)
        return printer

    # ------------------------------------------------------------------ private helpers

    @staticmethod
    def _node_pattern(variable, label=None, properties=None):
        """
        Build a Cypher node pattern such as ``(n:Label {`key`: $p0})`` plus the parameters it uses.
        Property keys are backtick-quoted (embedded backticks doubled). Values are sent as query
        parameters. ``json.dumps`` output cannot be used here because Cypher map keys cannot be
        double-quoted strings.
        :param variable: node variable name, inserted verbatim.
        :param label: optional label, inserted verbatim.
        :param properties: optional dict of property values.
        :return: tuple (pattern, parameters).
        """
        text = f"{variable}:{label}" if label else variable
        parameters = {}
        entries = []
        for index, (key, value) in enumerate((properties or {}).items()):
            name = f"p{index}"
            escaped_key = str(key).replace("`", "``")
            entries.append(f"`{escaped_key}`: ${name}")
            parameters[name] = value
        if entries:
            text += " {" + ", ".join(entries) + "}"
        return f"({text})", parameters

    @staticmethod
    def _where_clause(node_name, node_property_dict):
        """
        Build `` WHERE n.k1=v1 AND n.k2=v2`` (with a leading space), or "" for an empty/None dict.
        Conditions must be combined with AND; a comma-separated list is not valid Cypher.
        Values are inserted verbatim, so callers must quote string literals themselves and must
        never pass untrusted input.
        """
        conditions = " AND ".join(
            f"{node_name}.{property_name}={property_value}"
            for property_name, property_value in (node_property_dict or {}).items())
        return f" WHERE {conditions}" if conditions else ""

    @staticmethod
    def _first_value(records, default=0):
        """Return the first column of the first record (e.g. a COUNT(*)), or ``default`` if none."""
        return records[0][0] if records else default

    @staticmethod
    def _parse_relation_rows(rows, id_key, name_key):
        """
        Convert rows produced by listreader/sync_listreader into relation dicts.
        A row with more than one column: its second column (a single value or a list) goes
        through parse_relation_item, and the list position (starting at 1) becomes "level".
        Otherwise the first column (a single value or a list) is expanded with
        parse_relation_item_more_relation.
        """
        data = []
        for row in rows:
            if len(row) > 1:
                relations = row[1] if isinstance(row[1], list) else [row[1]]
                for level, relation in enumerate(relations, start=1):
                    data.append(parse_relation_item(relation, id_key, name_key, level))
            else:
                paths = row[0] if isinstance(row[0], list) else [row[0]]
                for path in paths:
                    data.extend(parse_relation_item_more_relation(path, id_key, name_key))
        return data

    async def _run_async(self, session, cypher, parameters=None, commit=False):
        """
        Run ``cypher`` in a new transaction on ``session`` and return all records as a list.
        The transaction is always closed. It is committed first only when ``commit`` is true;
        per the neo4j driver, closing an uncommitted transaction rolls it back.
        """
        tx = await session.begin_transaction()
        try:
            result = await tx.run(cypher, parameters)
            # Buffer while the transaction is still open; the cursor dies with it.
            records = [record async for record in result]
            if commit:
                await tx.commit()
            return records
        finally:
            await tx.close()

    def _run_sync(self, cypher, parameters=None):
        """
        Run ``cypher`` in a transaction from ``self.pool`` and return all records as a list.
        The records are read inside the ``with`` block. Committing or rolling back on exit is
        left to the transaction's context manager, as before.
        """
        with self.pool.begin_transaction() as tx:
            return list(tx.run(cypher, parameters))

    # ------------------------------------------------------------------ async API

    async def listreader(self, cypher, keys):
        """
        Read data from Neo4j with the given cypher.
        :param cypher: string
            Valid query cypher statement.
        :param keys: list
            Cypher query columns to return.
        :return: list
            One list per record with the values of ``keys`` in order, [[...], [...], ...].
        """
        records = await self._run_async(self.pool, cypher)
        return [[record[key] for key in keys] for record in records]

    async def dictreader(self, cypher):
        """
        Read data from Neo4j with the given cypher, returning every returned column.
        :param cypher: string
            Valid query cypher statement.
        :return: list
            One ``{column: value}`` dict per record, [{...}, {...}, ...].
        """
        records = await self._run_async(self.pool, cypher)
        return [dict(record) for record in records]

    async def dictreaderopted(self, cypher, keys=None):
        """
        Read data from Neo4j, keeping only the given columns.
        :param cypher: string
            Valid query cypher statement.
        :param keys: list, default None
            Cypher query columns to return; when empty/None, dictreader() is used instead.
        :return: list
            One ``{key: value}`` dict per record, [{...}, {...}, ...].
        """
        if not keys:
            return await self.dictreader(cypher)
        records = await self._run_async(self.pool, cypher)
        return [{key: record[key] for key in keys} for record in records]

    async def cypherexecuter(self, cypher):
        """
        Execute a cypher statement in its own transaction and commit it.
        :param cypher: string
            Valid handle cypher statement.
        :return: list of records produced by the statement (often empty for writes).
        """
        # begin_transaction() must be awaited (as the readers do) before the tx can be used.
        return await self._run_async(self.pool, cypher, commit=True)

    async def parse_relation_data(self, cql, result_list, id_key="id", name_key="name"):
        """
        Run ``cql`` through listreader and convert the rows into relation dicts.
        Up to 3 attempts are made, and each failure is logged. Every attempt starts from an empty
        result, so a failed attempt cannot leave partial or duplicated entries behind.
        :return: list of relation dicts, or [] when every attempt failed.
        """
        for _ in range(3):
            try:
                rows = await self.listreader(cql, result_list)
                return self._parse_relation_rows(rows, id_key, name_key)
            except Exception:
                logging.error(traceback.format_exc())
        return []

    async def create_node(self, node_name: str, table_name: str, node_property: dict = None):
        """Create one ``node_name:table_name`` node with the given properties (sent as parameters)."""
        pattern, parameters = self._node_pattern(node_name, table_name, node_property)
        return await self.execute_fetchall(f"CREATE {pattern}", parameters)

    async def create_node_relationship(self, relationship_list: list):
        """Run ``CREATE`` followed by the comma-joined pattern fragments in ``relationship_list``."""
        return await self.execute_fetchall("CREATE " + ",".join(relationship_list))

    async def delete_node(self, node_list: list):
        """
        Run ``DELETE`` followed by the comma-joined items in ``node_list``.
        NOTE(review): no MATCH is emitted, so the items must resolve on their own; confirm callers.
        """
        return await self.execute_fetchall("DELETE " + ",".join(node_list))

    async def delete_property(self, node: str, property_node: dict, node_property: str):
        """Remove ``node_property`` from every node matching the ``property_node`` map."""
        pattern, parameters = self._node_pattern(node, properties=property_node)
        sql = f"MATCH {pattern} REMOVE {node}.{node_property} RETURN {node}"
        return await self.execute_fetchall(sql, parameters)

    async def update_node(self, node_name: str, table_name: str, node_property_dict: dict):
        """
        Set properties on every ``node_name:table_name`` node and return the updated nodes.
        Values are inserted verbatim (``n.key=value``), so string values must already be quoted.
        """
        assignments = ",".join(
            f"{node_name}.{property_name}={property_value}"
            for property_name, property_value in node_property_dict.items())
        sql = f"MATCH ({node_name}:{table_name}) SET {assignments} RETURN {node_name}"
        return await self.execute_fetchall(sql)

    async def query_all(self, node_name: str, table_name: str, node_property_dict: dict = None):
        """Return every ``node_name:table_name`` node satisfying the (AND-combined) conditions."""
        sql = (f"MATCH ({node_name}:{table_name})"
               + self._where_clause(node_name, node_property_dict)
               + f" RETURN {node_name}")
        return await self.execute_fetchall(sql)

    async def query_count(self, sql, parameters=None):
        """Run a counting query and return its records."""
        return await self.execute_fetchall(sql, parameters)

    async def query_page(self, node_name: str, table_name: str, node_property_dict: dict = None,
                         page_num: int = 1, page_size: int = 10):
        """
        Return one page of ``node_name:table_name`` nodes as ``Page(...).__dict__``.
        ``records`` is a list of dicts, and ``total`` is the integer result of ``COUNT(*)``.
        ``page_num`` values below 1 are treated as 1.
        """
        sql = f"MATCH ({node_name}:{table_name})" + self._where_clause(node_name, node_property_dict)
        page_num = page_num if page_num > 1 else 1
        skip = (page_num - 1) * page_size
        total = self._first_value(await self.query_count(sql + " RETURN COUNT(*)"))
        result = await self.execute_fetchall(f"{sql} RETURN {node_name} SKIP {skip} LIMIT {page_size}")
        records = [dict(record) for record in result]
        return Page(records=records, current=page_num, page_size=page_size, total=total).__dict__

    async def execute_fetchall(self, sql, parameters=None):
        """
        Run ``sql`` in a committed transaction on a new ``self.pool.session()``.
        :return: list of records (buffered before the transaction ends).
        """
        async with self.pool.session() as session:
            # Commit explicitly: without it, writes were rolled back when the session closed.
            return await self._run_async(session, sql, parameters, commit=True)

    # ------------------------------------------------------------------ sync API

    def sync_listreader(self, cypher, keys):
        """
        Read data from Neo4j with the given cypher.
        :param cypher: string
            Valid query cypher statement.
        :param keys: list
            Cypher query columns to return.
        :return: list
            One list per record with the values of ``keys`` in order, [[...], [...], ...].
        """
        return [[record[key] for key in keys] for record in self._run_sync(cypher)]

    def sync_dictreader(self, cypher):
        """
        Read data from Neo4j with the given cypher, returning every returned column.
        :param cypher: string
            Valid query cypher statement.
        :return: list
            One ``{column: value}`` dict per record, [{...}, {...}, ...].
        """
        return [dict(record) for record in self._run_sync(cypher)]

    def sync_dictreaderopted(self, cypher, keys=None):
        """
        Read data from Neo4j, keeping only the given columns.
        :param cypher: string
            Valid query cypher statement.
        :param keys: list, default None
            Cypher query columns to return; when empty/None, sync_dictreader() is used instead.
        :return: list
            One ``{key: value}`` dict per record, [{...}, {...}, ...].
        """
        if not keys:
            return self.sync_dictreader(cypher)
        return [{key: record[key] for key in keys} for record in self._run_sync(cypher)]

    def sync_cypherexecuter(self, cypher):
        """
        Execute a cypher statement in a transaction from ``self.pool``.
        :param cypher: string
            Valid handle cypher statement.
        :return: list of records produced by the statement (often empty for writes).
        """
        return self._run_sync(cypher)

    def sync_parse_relation_data(self, cql, result_list, id_key="id", name_key="name", database="neo4j"):
        """
        Synchronous counterpart of parse_relation_data (up to 3 attempts, [] if all fail).
        :param database: accepted for API compatibility; not used here.
        """
        for _ in range(3):
            try:
                rows = self.sync_listreader(cql, result_list)
                return self._parse_relation_rows(rows, id_key, name_key)
            except Exception:
                logging.error(traceback.format_exc())
        return []

    def sync_create_node(self, node_name: str, table_name: str, node_property: dict = None, database=""):
        """Create one ``node_name:table_name`` node with the given properties (sent as parameters)."""
        pattern, parameters = self._node_pattern(node_name, table_name, node_property)
        return self.sync_execute_fetchall(f"CREATE {pattern}", database=database, parameters=parameters)

    def sync_create_node_relationship(self, relationship_list: list, database=""):
        """Run ``CREATE`` followed by the comma-joined pattern fragments in ``relationship_list``."""
        return self.sync_execute_fetchall("CREATE " + ",".join(relationship_list), database=database)

    def sync_delete_node(self, node_list: list, database=""):
        """
        Run ``DELETE`` followed by the comma-joined items in ``node_list``.
        NOTE(review): no MATCH is emitted, so the items must resolve on their own; confirm callers.
        """
        return self.sync_execute_fetchall("DELETE " + ",".join(node_list), database=database)

    def sync_delete_property(self, node: str, property_node: dict, node_property: str, database=""):
        """Remove ``node_property`` from every node matching the ``property_node`` map."""
        pattern, parameters = self._node_pattern(node, properties=property_node)
        sql = f"MATCH {pattern} REMOVE {node}.{node_property} RETURN {node}"
        return self.sync_execute_fetchall(sql, database=database, parameters=parameters)

    def sync_update_node(self, node_name: str, table_name: str, node_property_dict: dict, database=""):
        """
        Set properties on every ``node_name:table_name`` node and return the updated nodes.
        Values are inserted verbatim (``n.key=value``), so string values must already be quoted.
        """
        assignments = ",".join(
            f"{node_name}.{property_name}={property_value}"
            for property_name, property_value in node_property_dict.items())
        sql = f"MATCH ({node_name}:{table_name}) SET {assignments} RETURN {node_name}"
        return self.sync_execute_fetchall(sql, database=database)

    def sync_query_all(self, node_name: str, table_name: str, node_property_dict: dict = None, database=""):
        """Return every ``node_name:table_name`` node satisfying the (AND-combined) conditions."""
        sql = (f"MATCH ({node_name}:{table_name})"
               + self._where_clause(node_name, node_property_dict)
               + f" RETURN {node_name}")
        return self.sync_execute_fetchall(sql, database=database)

    def sync_query_count(self, sql, database="neo4j", parameters=None):
        """Run a counting query and return its records."""
        return self.sync_execute_fetchall(sql, database=database, parameters=parameters)

    def sync_query_page(self, node_name: str, table_name: str = None, match_property_dict: dict = None,
                        node_property_dict: dict = None, result_list: list = None, database="neo4j",
                        page_num: int = 1, page_size: int = 10):
        """
        Return one page of matching nodes as ``Page(...).__dict__``.
        :param table_name: optional label for the MATCH pattern.
        :param match_property_dict: optional property map for the MATCH pattern (sent as parameters).
        :param node_property_dict: optional AND-combined WHERE conditions (values inserted verbatim).
        :param result_list: optional RETURN expressions; defaults to ``node_name``.
        :raises Exception: when neither ``table_name`` nor ``match_property_dict`` is given.
        """
        if not table_name and not match_property_dict:
            raise Exception("At least one of table_name and match_property_dict")
        pattern, parameters = self._node_pattern(node_name, table_name, match_property_dict)
        sql = f"MATCH {pattern}" + self._where_clause(node_name, node_property_dict)
        page_num = page_num if page_num > 1 else 1
        skip = (page_num - 1) * page_size
        total = self._first_value(
            self.sync_query_count(sql + " RETURN COUNT(*)", database=database, parameters=parameters))
        returned = ",".join(result_list) if result_list else node_name
        result = self.sync_execute_fetchall(f"{sql} RETURN {returned} SKIP {skip} LIMIT {page_size}",
                                            database=database, parameters=parameters)
        records = [dict(record) for record in result]
        return Page(records=records, current=page_num, page_size=page_size, total=total).__dict__

    def sync_execute_fetchall(self, sql, database=None, parameters=None):
        """
        Run ``sql`` in a transaction from ``self.pool`` and return all records as a list.
        :param database: accepted because every sync_* caller passes it. It is not applied here,
            because the transaction comes straight from ``self.pool``.
            NOTE(review): confirm the pool is already bound to the intended database.
        :param parameters: optional query parameters.
        """
        return self._run_sync(sql, parameters)


def parse_relation_item(item, id_key, name_key, level=0):
    """
    Convert a relationship into a plain dict.
    :param item: relationship object exposing ``start_node``, ``end_node``, ``type`` and a
        mapping-style ``items()`` over its properties.
    :param id_key: property of each end node used as its (encrypted) id, see parse_node_dict.
    :param name_key: property of each end node used as its name.
    :param level: value stored under "level" (callers pass the position in a relationship list).
    :return: dict with keys start_node, end_node, type, level, properties.
    """
    return {
        "start_node": parse_node_dict(item.start_node, id_key, name_key),
        "end_node": parse_node_dict(item.end_node, id_key, name_key),
        "type": item.type,
        "level": level,
        # Public API instead of the private ``_properties``; returns a copy, so callers
        # cannot mutate the driver object's internal property dict.
        "properties": dict(item.items()),
    }


def parse_relation_item_more_relation(item, id_key, name_key):
    """
    Convert every relationship of a path into a dict via parse_relation_item.
    :param item: path object exposing ``relationships`` (public API; replaces ``_relationships``).
    :param id_key: property used as each end node's id.
    :param name_key: property used as each end node's name.
    :return: list of relation dicts in path order, each with parse_relation_item's default level.
    """
    return [parse_relation_item(relationship, id_key, name_key) for relationship in item.relationships]


def parse_node_dict(node, id_key, name_key):
    """
    Convert a node into ``{"id", "name", "label"}``.
    :param node: node object exposing ``items()`` over its properties and a ``labels`` collection.
    :param id_key: property whose value is AES-encrypted with ``options.aes_key`` as "id".
    :param name_key: property returned as "name".
    :return: dict; "label" is the alphabetically first label, or None for an unlabeled node.
    :raises KeyError: if ``id_key`` or ``name_key`` is missing from the node's properties.
    """
    properties = dict(node.items())
    # labels is a set: taking "element 0" of it depends on per-process string hashing, so sort
    # to get a deterministic label for multi-label nodes.
    labels = sorted(node.labels)
    return {
        "id": AES.encrypt(options.aes_key, properties[id_key]),
        "name": properties[name_key],
        "label": labels[0] if labels else None,
    }
