import sqlalchemy
from sqlalchemy import DDL, event
from xplordb.sqlalchemy.assay.structural_types import pgAzimuthType, pgSphericalType

from openlog.core.geo_extractor import XplordbGeoExtractor
from openlog.core.trace_splitter import XplordbSplitTracesQueries
from openlog.datamodel.assay.categories import CategoriesTableDefinition
from openlog.datamodel.assay.generic_assay import (
    AssayColumn,
    AssayDatabaseDefinition,
    AssayDataExtent,
    AssayDefinition,
    AssayDomainType,
    AssaySeriesType,
)
from openlog.datamodel.connection.interfaces.categories_interface import (
    CategoriesInterface,
)
from openlog.datamodel.connection.sqlalchemy.sqlalchemy_assay_interface import (
    SqlAlchemyAssayInterface,
)


class XplordbAssayInterface(SqlAlchemyAssayInterface):
    """
    AssayInterface implementation for an Xplordb (PostgreSQL) database.

    Assay data tables are created in the ``assay`` schema and split trace
    layers are stored in the ``display`` schema. Spherical and polar series
    use the Xplordb types ``assay.spherical_data`` and ``assay.azimuth``.

    NOTE(review): table and column identifiers (assay variable, column names)
    are interpolated unquoted into SQL throughout this class, so PostgreSQL
    folds them to lower case and they must be trusted, valid identifiers.
    Values, on the other hand, are always passed as bind parameters.
    """

    def __init__(self, engine, session, categories_iface: CategoriesInterface):
        """
        Implement AssayInterface for a Xplordb db

        Args:
            engine: sqlalchemy engine
            session: sqlalchemy session created from engine
            categories_iface: CategoriesInterface to get default lith category
                name; also used to create a missing category table when a
                categorical column is added
        """
        self._categories_iface = categories_iface
        super().__init__(engine, session, "assay")
        # the classes themselves are stored, not instances
        self.trace_splitter = XplordbSplitTracesQueries
        self.geo_extractor = XplordbGeoExtractor

    def default_assay_schema(self) -> str:
        """
        Return default schema for assay table creation.

        Returns:
            str: always ``"assay"``
        """
        return "assay"

    def _execute_sql_and_commit(self, sql: str, params=None) -> None:
        """
        Execute one SQL statement on the session, then commit.

        Args:
            sql: SQL text; values must be passed as ``:name`` bind parameters
                instead of being formatted into the string
            params: optional dict of bind parameter values
        """
        self.session.execute(sqlalchemy.text(sql), params or {})
        self.session.commit()

    def _cast_column_type(
        self, assay_table: sqlalchemy.Table, assay_definition: AssayDefinition
    ) -> sqlalchemy.Table:
        """
        Replace generic columns of ``assay_table`` by Xplordb-specific types.

        Spherical series columns become ``pgSphericalType`` columns (composite
        type ``assay.spherical_data``) and polar series columns become
        ``pgAzimuthType`` columns (``assay.azimuth``). Replaced columns are
        appended at the end of the table.

        Args:
            assay_table: sqlalchemy table built from the generic definition
            assay_definition: definition giving each column's series type

        Returns:
            sqlalchemy.Table: the same table object, modified in place
        """

        def _replace_columns(
            table: sqlalchemy.Table,
            source_type: AssaySeriesType,
            target_type: type,
        ) -> sqlalchemy.Table:
            # definition columns having the given series type
            names = [
                name
                for name, column in assay_definition.columns.items()
                if column.series_type == source_type
            ]
            # remove the generically typed columns through the private
            # ``_columns`` collection
            for col in [c for c in table._columns if c.name in names]:
                table._columns.remove(col)
            # add them back with the backend-specific type
            for name in names:
                table.append_column(
                    sqlalchemy.Column(name, target_type(), primary_key=False)
                )
            return table

        assay_table = _replace_columns(
            assay_table, AssaySeriesType.SPHERICAL, pgSphericalType
        )
        assay_table = _replace_columns(
            assay_table, AssaySeriesType.POLAR, pgAzimuthType
        )
        return assay_table

    def _convert_composite_string(
        self, params: list, assay_definition: AssayDefinition
    ) -> list:
        """
        Convert spherical values of each row for insertion in the backend.

        Every non-None value of a spherical column is replaced by the result of
        ``pgSphericalType.insert_from_json``. Rows that do not contain a given
        spherical column are left untouched for that column.

        Args:
            params: list of row dicts, modified in place
            assay_definition: definition used to find the spherical columns

        Returns:
            list: ``params``
        """
        spherical_cols = [
            name
            for name, column in assay_definition.columns.items()
            if column.series_type == AssaySeriesType.SPHERICAL
        ]

        for row in params:
            for name in spherical_cols:
                value = row.get(name)
                if value is None:
                    continue
                row[name] = pgSphericalType.insert_from_json(value)

        return params

    def _convert_spherical_to_dict(self, param: str):
        """
        Parse the text form of a spherical composite value into a dict.

        ``"(azimuth,dip,polarity,type)"`` gives the keys ``azimuth``, ``dip``,
        ``polarity`` and ``type``. Values stay strings; ``type`` is upper-cased.

        Raises:
            ValueError: if ``param`` does not hold exactly four comma-separated
                fields once parentheses are removed.
        """
        azimuth, dip, polarity, type_ = (
            param.replace("(", "").replace(")", "").split(",")
        )
        return {
            "azimuth": azimuth,
            "dip": dip,
            "type": type_.upper(),
            "polarity": polarity,
        }

    def _before_sqlalchemy_table_creation(
        self, table: sqlalchemy.Table, schema: str
    ) -> None:
        """
        Register an ``after_create`` event setting permissions on a new table.

        Once ``table`` is created, its owner becomes ``xdb_admin``; SELECT and
        TRIGGER are granted to ``xdb_viewer``; SELECT, UPDATE, INSERT and
        TRIGGER to ``xdb_logger``; ALL to ``xdb_admin``.

        Args:
            table: sqlalchemy table about to be created
            schema: (str) database schema used; falsy for no schema prefix
        """
        schema_str = ""
        if schema:
            schema_str = schema + "."

        event.listen(
            table,
            "after_create",
            DDL(
                f"ALTER TABLE {schema_str}{table.name} OWNER TO xdb_admin;"
                f"GRANT SELECT, TRIGGER ON TABLE {schema_str}{table.name} TO xdb_viewer;"
                f"GRANT SELECT, UPDATE, INSERT, TRIGGER ON TABLE {schema_str}{table.name} TO xdb_logger;"
                f"GRANT ALL ON TABLE {schema_str}{table.name} TO xdb_admin;"
            ),
        )

    def alter_assay(self, assay_definition):
        """
        Align a stored assay with ``assay_definition``.

        Columns of the stored definition missing from ``assay_definition`` are
        removed, then columns only present in ``assay_definition`` are added.
        Columns present in both are left untouched (type or metadata changes
        are not applied).

        Raises:
            IndexError: if no stored assay definition has the same variable.
        """
        matches = [
            assay_def
            for assay_def in self.get_all_available_assay_definitions()
            if assay_def.variable == assay_definition.variable
        ]
        if not matches:
            raise IndexError(
                f"No assay definition found for variable '{assay_definition.variable}'"
            )
        old_def = matches[0]

        to_delete = [
            assay_column
            for name, assay_column in old_def.columns.items()
            if name not in assay_definition.columns
        ]
        to_add = [
            assay_column
            for name, assay_column in assay_definition.columns.items()
            if name not in old_def.columns
        ]
        for assay_column in to_delete:
            self._remove_assay_column(assay_definition.variable, assay_column)
        for assay_column in to_add:
            self._add_assay_column(assay_definition.variable, assay_column)

    def _add_assay_column(self, variable, assay_column):
        """
        Add a column to an existing assay.

        Adds the column (plus, for numerical series, its uncertainty and
        detection-limit columns) to ``assay.<variable>``, registers it in
        ``assay.assay_column`` and ``assay.assay_column_definition``, and creates
        its category table when the column is categorical and the table is
        missing. Each statement is committed individually.

        Args:
            variable: assay variable (table name in the ``assay`` schema)
            assay_column: AssayColumn to add

        Raises:
            ValueError: if the column series type has no SQL type mapping.
        """
        sql_types = {
            AssaySeriesType.CATEGORICAL: "TEXT",
            AssaySeriesType.NOMINAL: "TEXT",
            AssaySeriesType.NUMERICAL: "REAL",
            # NOTE(review): DATE drops the time of day; confirm TIMESTAMP is
            # not what the regular table creation uses
            AssaySeriesType.DATETIME: "DATE",
            # PostgreSQL has no BLOB type; BYTEA is its binary type
            AssaySeriesType.IMAGERY: "BYTEA",
            AssaySeriesType.SPHERICAL: "assay.spherical_data",
            AssaySeriesType.POLAR: "assay.azimuth",
        }
        try:
            sql_type = sql_types[assay_column.series_type]
        except KeyError:
            raise ValueError(
                f"Unsupported series type for a new assay column: {assay_column.series_type}"
            ) from None

        # add column to assay table
        self._execute_sql_and_commit(
            f"ALTER TABLE assay.{variable} ADD COLUMN {assay_column.name} {sql_type};"
        )

        if assay_column.series_type == AssaySeriesType.NUMERICAL:
            extra_columns = [
                *assay_column.uncertainty.get_uncertainty_columns(),
                *assay_column.detection_limit.get_detection_columns(),
            ]
            for other_col in extra_columns:
                self._execute_sql_and_commit(
                    f"ALTER TABLE assay.{variable} ADD COLUMN {other_col} REAL;"
                )

        # add column to assay definitions (None values are stored as NULL)
        self._execute_sql_and_commit(
            "INSERT INTO assay.assay_column VALUES("
            ":name, :assay, :series_type, :unit, :category_name, "
            ":upper_box_column, :lower_box_column, "
            ":upper_whisker_column, :lower_whisker_column, "
            ":image_format_col, :detection_min_col, :detection_max_col, "
            ":display_name, NULL);",
            {
                "name": assay_column.name,
                "assay": variable,
                "series_type": assay_column.series_type.value,
                "unit": assay_column.unit,
                "category_name": assay_column.category_name,
                "upper_box_column": assay_column.uncertainty.upper_box_column,
                "lower_box_column": assay_column.uncertainty.lower_box_column,
                "upper_whisker_column": assay_column.uncertainty.upper_whisker_column,
                "lower_whisker_column": assay_column.uncertainty.lower_whisker_column,
                "image_format_col": assay_column.image_format_col,
                "detection_min_col": assay_column.detection_limit.detection_min_col,
                "detection_max_col": assay_column.detection_limit.detection_max_col,
                "display_name": assay_column.display_name,
            },
        )

        self._execute_sql_and_commit(
            "INSERT INTO assay.assay_column_definition VALUES(:name, :variable, :column);",
            {
                "name": assay_column.name,
                "variable": variable,
                "column": assay_column.name,
            },
        )

        # add new category table if needed
        if assay_column.series_type == AssaySeriesType.CATEGORICAL:
            category_table_names = [
                table.name
                for table in self._categories_iface.get_available_categories_table()
            ]
            if assay_column.category_name not in category_table_names:
                category_def = CategoriesTableDefinition(
                    name=assay_column.category_name,
                    table_name=assay_column.category_name,
                )
                self._categories_iface.import_categories_table([category_def])

    def _drop_xplordb_split_layers(self, variable: str, columns: list) -> None:
        """
        Drop the split trace layers of some assay columns and their triggers.

        For each column and geometry (``trace``, ``planned_trace``), the
        triggers ``<op>_<source>_<variable>_<column>_<geom>`` (op in update,
        insert, delete; source ``collar`` on ``display.display_collar`` or
        ``assay`` on ``assay.<variable>``) are dropped first, then the tables
        ``display."<variable>_<column>_<geom>"``. Every drop uses IF EXISTS.

        Args:
            variable: assay variable
            columns: names of the assay columns whose layers are dropped
        """
        triggers = []  # (trigger name, table the trigger is attached to)
        layers = []
        for column in columns:
            for geom in ("trace", "planned_trace"):
                layers.append("_".join([variable, column, geom]))
                for op in ("update", "insert", "delete"):
                    for source, table_name in (
                        ("collar", "display.display_collar"),
                        ("assay", f"assay.{variable}"),
                    ):
                        trigger = "_".join([op, source, variable, column, geom])
                        triggers.append((trigger, table_name))

        # delete triggers; the trigger may exist under its exact or lower-cased
        # name, so both are dropped
        for trigger, table_name in triggers:
            self.session.execute(
                sqlalchemy.text(f'DROP TRIGGER IF EXISTS "{trigger}" ON {table_name};')
            )
            self.session.execute(
                sqlalchemy.text(
                    f'DROP TRIGGER IF EXISTS "{trigger.lower()}" ON {table_name};'
                )
            )
            self.session.commit()

        # delete trace layers
        for layer in layers:
            self._execute_sql_and_commit(f'DROP TABLE IF EXISTS display."{layer}";')

    def _remove_assay_column(self, variable: str, assay_column: AssayColumn):
        """
        Remove a column from an existing assay.

        For numerical and categorical series, the split trace layers and their
        triggers are dropped first. Then the column and its uncertainty and
        detection-limit columns are dropped from ``assay.<variable>`` (if they
        exist), and the column is removed from ``assay.assay_column`` and
        ``assay.assay_column_definition``.

        Args:
            variable: assay variable (table name in the ``assay`` schema)
            assay_column: AssayColumn to remove
        """
        # delete splitted layers
        if assay_column.series_type in (
            AssaySeriesType.NUMERICAL,
            AssaySeriesType.CATEGORICAL,
        ):
            self._drop_xplordb_split_layers(variable, [assay_column.name])

        # column itself, then uncertainties and detection limits
        columns_to_drop = [
            assay_column.name,
            *assay_column.uncertainty.get_uncertainty_columns(),
            *assay_column.detection_limit.get_detection_columns(),
        ]
        for column in columns_to_drop:
            # IF EXISTS: tolerate a column whose creation did not complete
            self._execute_sql_and_commit(
                f"ALTER TABLE assay.{variable} DROP COLUMN IF EXISTS {column};"
            )

        # delete reference in assay definitions
        self._execute_sql_and_commit(
            "DELETE FROM assay.assay_column WHERE assay = :variable and name = :name;",
            {"variable": variable, "name": assay_column.name},
        )
        self._execute_sql_and_commit(
            "DELETE FROM assay.assay_column_definition WHERE variable = :variable and name = :name;",
            {"variable": variable, "name": assay_column.name},
        )

    def delete_assay_from_database(self, variable: str, only_splitted: bool = False):
        """
        Delete an assay, or only its split trace layers.

        Split layers (and their triggers) of the assay's numerical and
        categorical columns are always dropped. Unless ``only_splitted`` is
        True, the assay table and its rows in ``assay.assay_column``,
        ``assay.assay_column_definition``, ``assay.assay_definition`` and
        ``assay.assays`` are deleted as well.

        Args:
            variable: assay variable
            only_splitted: only drop the split layers when True
        """
        rows = self.session.execute(
            sqlalchemy.text(
                "SELECT name FROM assay.assay_column WHERE assay = :variable "
                "AND series_type IN ('numerical', 'categorical');"
            ),
            {"variable": variable},
        ).fetchall()
        splittable_columns = [row[0] for row in rows]

        self._drop_xplordb_split_layers(variable, splittable_columns)

        if only_splitted:
            return

        # delete assay table
        self._execute_sql_and_commit(f"DROP TABLE IF EXISTS assay.{variable};")

        # delete assay definitions
        for sql in (
            "DELETE FROM assay.assay_column WHERE assay = :variable;",
            "DELETE FROM assay.assay_column_definition WHERE variable = :variable;",
            "DELETE FROM assay.assay_definition WHERE table_name = :variable;",
            "DELETE FROM assay.assays WHERE variable = :variable;",
        ):
            self._execute_sql_and_commit(sql, {"variable": variable})

    def can_administrate_assays(self) -> dict:
        """
        Tell whether the session user may create and delete assays.

        Both rights are granted to superusers and to direct members of the
        ``xdb_admin`` role. On a database error, both are denied and the
        session is rolled back so it stays usable.

        Returns:
            dict: ``{"creation": bool, "deletion": bool}``
        """
        result = {"creation": False, "deletion": False}

        try:
            row = self.session.execute(
                sqlalchemy.text(
                    "SELECT rolsuper FROM pg_roles WHERE rolname = session_user;"
                )
            ).fetchone()
            is_superuser = bool(row[0]) if row is not None else False

            # all roles the session user is a direct member of
            parent_rows = self.session.execute(
                sqlalchemy.text(
                    """
                    SELECT r_parent.rolname
                    FROM pg_roles r_child
                    JOIN pg_auth_members m ON r_child.oid = m.member
                    JOIN pg_roles r_parent ON r_parent.oid = m.roleid
                    WHERE r_child.rolname = session_user;
                    """
                )
            ).fetchall()
            parent_roles = {parent[0] for parent in parent_rows}
        except sqlalchemy.exc.SQLAlchemyError:
            # a failed query aborts the PostgreSQL transaction
            self.session.rollback()
            return result

        if is_superuser or "xdb_admin" in parent_roles:
            result["creation"] = True
            result["deletion"] = True
        return result

    def can_save_symbology_in_db(self) -> bool:
        """
        Check if database can store assay symbology.

        Returns:
            bool: always True for Xplordb
        """
        return True

    def _update_assay(self, assay_definition: AssayDefinition, params: list) -> None:
        """
        Update existing assay rows using SQL.

        Each entry of ``params`` describes one row (dict-like). Only the columns
        present as keys in the entry are written; the row is matched on
        ``hole``, ``x`` and, for extended assays, ``x_end``. Entries without any
        updatable key are skipped. A single commit is issued at the end.

        Spherical values are expected as dicts with ``azimuth``, ``dip``,
        ``polarity`` and ``type`` keys and are written as a ROW(...) value.

        Args:
            assay_definition: definition of the updated assay
            params: list of row dicts
        """
        is_time = assay_definition.domain == AssayDomainType.TIME
        is_extended = assay_definition.data_extent == AssayDataExtent.EXTENDED

        def _sql_scalar(value):
            # unwrap numpy scalars (e.g. numpy.int64), which the DB driver may
            # not be able to adapt
            item = getattr(value, "item", None)
            return item() if callable(item) else value

        def _domain_value(value):
            # time domain values are expected to provide to_pydatetime()
            return value.to_pydatetime() if is_time else _sql_scalar(value)

        for param in params:
            assignments = []
            values = {}
            for col, column_def in assay_definition.columns.items():
                value = param.get(col)
                if (
                    column_def.series_type == AssaySeriesType.SPHERICAL
                    and value is not None
                ):
                    # one bind parameter per attribute, prefixed by the column
                    # name so several spherical columns do not share values
                    keys = {
                        attr: f"{col}__{attr}"
                        for attr in ("azimuth", "dip", "polarity", "type")
                    }
                    assignment = (
                        f"{col}=ROW(:{keys['azimuth']}, :{keys['dip']}, "
                        f":{keys['polarity']}, :{keys['type']})"
                    )
                    values[keys["azimuth"]] = value["azimuth"]
                    values[keys["dip"]] = value["dip"]
                    values[keys["polarity"]] = value["polarity"]
                    values[keys["type"]] = value["type"].lower()
                else:
                    assignment = f"{col}=:{col}"
                    values[col] = value

                if col in param:
                    assignments.append(assignment)

                # uncertainties and detection limits
                extra_columns = [
                    *column_def.uncertainty.get_uncertainty_columns(),
                    *column_def.detection_limit.get_detection_columns(),
                ]
                for extra_col in extra_columns:
                    values[extra_col] = param.get(extra_col)
                    if extra_col in param:
                        assignments.append(f"{extra_col}=:{extra_col}")

            if not assignments:
                continue

            # row identification, bound under names that cannot clash with
            # assay column names
            conditions = ["hole=:_where_hole", "x=:_where_x"]
            values["_where_hole"] = str(param.get("hole"))
            values["_where_x"] = _domain_value(param.get("x"))
            if is_extended:
                conditions.append("x_end=:_where_x_end")
                values["_where_x_end"] = _domain_value(param.get("x_end"))

            q = (
                f"UPDATE {self.schema}.{assay_definition.variable} "
                f"SET {','.join(assignments)} WHERE {' AND '.join(conditions)};"
            )
            self.session.execute(sqlalchemy.text(q), values)

        self.commit()
