import json

from sqlalchemy import text

from openlog.core.geo_extractor import SpatialiteGeoExtractor
from openlog.core.trace_splitter import SpatialiteSplitTracesQueries
from openlog.datamodel.assay.categories import CategoriesTableDefinition
from openlog.datamodel.assay.generic_assay import (
    AssayColumn,
    AssayDatabaseDefinition,
    AssayDataExtent,
    AssayDefinition,
    AssayDomainType,
    AssaySeriesType,
    GenericAssay,
)
from openlog.datamodel.connection.interfaces.categories_interface import (
    CategoriesInterface,
)
from openlog.datamodel.connection.sqlalchemy.sqlalchemy_assay_interface import (
    SqlAlchemyAssayInterface,
)


class SpatialiteAssayInterface(SqlAlchemyAssayInterface):
    """AssayInterface implementation for a Spatialite (SQLite) database.

    Assay metadata lives in the ``assays``, ``assay_definition``,
    ``assay_column`` and ``assay_column_definition`` tables; the data of each
    assay lives in a table named after the assay ``variable``.

    Values are passed to SQL as bound parameters wherever SQL allows it
    (identifiers such as table/column names cannot be bound and are still
    interpolated).
    """

    def __init__(self, engine, session, categories_iface: CategoriesInterface):
        """
        Implement AssayInterface for a Spatialite db

        Args:
            engine: sqlalchemy engine
            session: sqlalchemy session created from engine
            categories_iface : CategoriesInterface to get default lith category name;
                also used here to list / create category tables for new
                categorical columns
        """

        self._categories_iface = categories_iface
        super().__init__(engine, session)
        self.trace_splitter = SpatialiteSplitTracesQueries
        self.geo_extractor = SpatialiteGeoExtractor

    def _execute_and_commit(self, statement: str, params=None) -> None:
        """Execute one SQL statement and commit immediately.

        Args:
            statement: SQL string. It is wrapped in ``text()`` so ``:name``
                placeholders are bound from ``params``.
            params: optional mapping of bind-parameter values.
        """
        self.session.execute(text(statement), params or {})
        self.session.commit()

    @staticmethod
    def _column_sql_type(series_type) -> str:
        """Return the SQLite column type used to store ``series_type`` values.

        Raises:
            ValueError: if ``series_type`` has no known SQL type.
        """
        if series_type in (
            AssaySeriesType.CATEGORICAL,
            AssaySeriesType.NOMINAL,
            AssaySeriesType.SPHERICAL,
        ):
            return "TEXT"
        if series_type in (AssaySeriesType.NUMERICAL, AssaySeriesType.POLAR):
            return "REAL"
        if series_type == AssaySeriesType.DATETIME:
            return "DATE"
        if series_type == AssaySeriesType.IMAGERY:
            return "BLOB"
        raise ValueError(f"Unsupported assay series type: {series_type!r}")

    def alter_assay(self, assay_definition):
        """Synchronise the stored columns of an existing assay with ``assay_definition``.

        Columns stored for the assay but absent from ``assay_definition`` are
        dropped, and columns new in ``assay_definition`` are added. Columns
        present in both are left untouched. When anything changes, the assay's
        splitted trace layers and their triggers are dropped first; they are
        not recreated by this method.

        Args:
            assay_definition: new AssayDefinition of an already stored variable.

        Raises:
            IndexError: if no stored definition matches ``assay_definition.variable``.
        """
        matching = [
            assay_def
            for assay_def in self.get_all_available_assay_definitions()
            if assay_def.variable == assay_definition.variable
        ]
        if not matching:
            # IndexError kept for compatibility with the former ``[...][0]`` lookup.
            raise IndexError(
                f"No assay definition found for variable '{assay_definition.variable}'"
            )
        old_def = matching[0]

        to_delete = [
            assay_column
            for name, assay_column in old_def.columns.items()
            if name not in assay_definition.columns
        ]
        to_add = [
            assay_column
            for name, assay_column in assay_definition.columns.items()
            if name not in old_def.columns
        ]
        if not to_delete and not to_add:
            return

        # A single drop is enough: it removes every splitted layer/trigger of the assay.
        self.delete_assay_from_database(
            variable=assay_definition.variable, only_splitted=True
        )
        for assay_column in to_delete:
            self._remove_assay_column(assay_definition.variable, assay_column)
        for assay_column in to_add:
            self._add_assay_column(assay_definition.variable, assay_column)

    def _add_assay_column(self, variable, assay_column):
        """Add ``assay_column`` to the ``variable`` assay table and register it.

        Numerical columns also get one REAL column per uncertainty and
        detection-limit column. The column is recorded in ``assay_column``
        and ``assay_column_definition``. For a categorical column, a category
        table is imported if none with that name exists yet.

        Raises:
            ValueError: if the column series type has no SQL type mapping
                (raised before any database change).
        """
        sql_type = self._column_sql_type(assay_column.series_type)

        # add column to assay table
        self._execute_and_commit(
            f"ALTER TABLE {variable} ADD COLUMN {assay_column.name} {sql_type};"
        )

        if assay_column.series_type == AssaySeriesType.NUMERICAL:
            extra_columns = [
                *assay_column.uncertainty.get_uncertainty_columns(),
                *assay_column.detection_limit.get_detection_columns(),
            ]
            for other_col in extra_columns:
                self._execute_and_commit(
                    f"ALTER TABLE {variable} ADD COLUMN {other_col} REAL;"
                )

        # add column to assay definitions (positional VALUES order matters)
        row = {
            "name": assay_column.name,
            "assay": variable,
            "series_type": assay_column.series_type.value,
            "unit": assay_column.unit,
            "category_name": assay_column.category_name,
            "upper_box": assay_column.uncertainty.upper_box_column,
            "lower_box": assay_column.uncertainty.lower_box_column,
            "upper_whisker": assay_column.uncertainty.upper_whisker_column,
            "lower_whisker": assay_column.uncertainty.lower_whisker_column,
            "image_format": assay_column.image_format_col,
            "detection_min": assay_column.detection_limit.detection_min_col,
            "detection_max": assay_column.detection_limit.detection_max_col,
            "display_name": assay_column.display_name,
        }
        placeholders = ", ".join(f":{key}" for key in row)
        # NOTE(review): values are stored as their text form exactly as the former
        # f-string query did, so None becomes the string 'None'. Readers of
        # assay_column presumably rely on this — confirm before switching to NULL.
        self._execute_and_commit(
            f"INSERT INTO assay_column VALUES({placeholders}, NULL);",
            {key: f"{value}" for key, value in row.items()},
        )

        self._execute_and_commit(
            "INSERT INTO assay_column_definition VALUES(:name, :variable, :column);",
            {
                "name": f"{assay_column.name}",
                "variable": f"{variable}",
                "column": f"{assay_column.name}",
            },
        )

        # add new category table if needed
        if assay_column.series_type == AssaySeriesType.CATEGORICAL:
            existing_tables = {
                table.name
                for table in self._categories_iface.get_available_categories_table()
            }
            if assay_column.category_name not in existing_tables:
                category_def = CategoriesTableDefinition(
                    name=assay_column.category_name,
                    table_name=assay_column.category_name,
                    schema=self.schema,
                )
                self._categories_iface.import_categories_table([category_def])

    def _remove_assay_column(self, variable: str, assay_column: AssayColumn):
        """Drop ``assay_column`` from the ``variable`` table and unregister it.

        The column's uncertainty and detection-limit columns are dropped too.
        Note that SQLite supports ``ALTER TABLE ... DROP COLUMN`` only from
        version 3.35.0.
        """
        columns_to_drop = [
            assay_column.name,
            *assay_column.uncertainty.get_uncertainty_columns(),
            *assay_column.detection_limit.get_detection_columns(),
        ]
        for column in columns_to_drop:
            self._execute_and_commit(f"ALTER TABLE {variable} DROP COLUMN {column};")

        # delete reference in assay definitions
        params = {"variable": variable, "name": assay_column.name}
        self._execute_and_commit(
            "DELETE FROM assay_column WHERE assay = :variable AND name = :name;",
            params,
        )
        self._execute_and_commit(
            "DELETE FROM assay_column_definition WHERE variable = :variable AND name = :name;",
            params,
        )

    def delete_assay_from_database(self, variable: str, only_splitted: bool = False):
        """Drop an assay's splitted trace layers and triggers, and optionally the assay.

        Splitted layers exist in two naming schemes:

        - the new one, with one layer per assay: ``<variable>_<geom>``;
        - the old one, with one layer per numerical/categorical column:
          ``<variable>_<column>_<geom>``.

        Here ``geom`` is ``trace`` or ``planned_trace``. Each layer has
        ``<op>_<table>_<layer prefix>_<geom>`` triggers.

        Args:
            variable: assay variable, also the name of its data table.
            only_splitted: if True, stop after removing splitted layers and
                triggers; the assay table and its definition rows are kept.
        """
        geometries = ("trace", "planned_trace")
        operations = ("update", "insert", "delete")
        trigger_tables = ("collar", "assay")

        # old fashion layers: one per splittable column
        splittable_rows = self.session.execute(
            text(
                "SELECT name FROM assay_column WHERE assay = :variable "
                "AND series_type IN ('numerical', 'categorical');"
            ),
            {"variable": variable},
        ).fetchall()
        prefixes = [variable] + [f"{variable}_{row[0]}" for row in splittable_rows]

        splitted_layers = []
        trigger_names = []
        for prefix in prefixes:
            for geom in geometries:
                splitted_layers.append(f"{prefix}_{geom}")
                trigger_names.extend(
                    f"{op}_{table}_{prefix}_{geom}"
                    for op in operations
                    for table in trigger_tables
                )

        for trigger in trigger_names:
            self._execute_and_commit(f'DROP TRIGGER IF EXISTS "{trigger}";')

        for layer in splitted_layers:
            # trailing 1 is SpatiaLite DropTable's "permissive" flag (tolerate missing tables)
            self._execute_and_commit(
                "SELECT DropTable(NULL, :layer, 1) ;", {"layer": layer}
            )

        if only_splitted:
            return

        # delete assay table
        self._execute_and_commit(f"DROP TABLE IF EXISTS {variable};")

        # delete assay definitions
        params = {"variable": variable}
        for statement in (
            "DELETE FROM assay_column WHERE assay = :variable;",
            "DELETE FROM assay_column_definition WHERE variable = :variable;",
            "DELETE FROM assay_definition WHERE table_name = :variable;",
            "DELETE FROM assays WHERE variable = :variable;",
        ):
            self._execute_and_commit(statement, params)

    def can_administrate_assays(self) -> dict:
        """Report that assays can be both created and deleted in Spatialite."""

        return {"creation": True, "deletion": True}

    def can_save_symbology_in_db(self) -> bool:
        """
        Check if database can store assay symbology.
        """

        return True

    def _convert_spherical_to_dict(self, param: str):
        """Decode a spherical value stored as JSON text."""

        return json.loads(param)

    def _update_assay(self, assay_definition: AssayDefinition, params: list) -> None:
        """
        Update existing assay rows using SQL.

        Each item of ``params`` maps column names to values for one row. The
        row is identified by ``hole`` and ``x``, plus ``x_end`` for extended
        assays. Only keys that are assay, uncertainty or detection-limit
        columns are written. Items with no such key are skipped. Spherical
        values are stored as JSON text.
        """
        # delete splitted layers because triggers will slow massively update (spatialite caveat)
        self.delete_assay_from_database(
            variable=assay_definition.variable, only_splitted=True
        )
        is_time = assay_definition.domain == AssayDomainType.TIME
        bounds = ["x"]
        if assay_definition.data_extent == AssayDataExtent.EXTENDED:
            bounds.append("x_end")

        for param in params:
            set_clause = []
            values = {}
            for col, column_def in assay_definition.columns.items():
                value = param.get(col)
                if (
                    column_def.series_type == AssaySeriesType.SPHERICAL
                    and value is not None
                ):
                    value = json.dumps(value)
                values[col] = value
                if col in param.keys():
                    set_clause.append(f"{col}=:{col}")

                # uncertainties, then detection limits
                extra_columns = [
                    *column_def.uncertainty.get_uncertainty_columns(),
                    *column_def.detection_limit.get_detection_columns(),
                ]
                for extra in extra_columns:
                    values[extra] = param.get(extra)
                    if extra in param.keys():
                        set_clause.append(f"{extra}=:{extra}")

            if not set_clause:
                continue

            # hole id is free text: bind it instead of quoting it inline
            values["_where_hole"] = f"{param.get('hole')}"
            where_clause = "hole=:_where_hole"
            for bound in bounds:
                bound_value = param.get(bound)
                if is_time:
                    key = f"_where_{bound}"
                    values[key] = bound_value.isoformat(
                        sep=" ", timespec="microseconds"
                    )
                    where_clause += f" AND {bound}=:{key}"
                else:
                    # numeric bounds are inlined as before, so scalar types the
                    # sqlite driver might not bind (e.g. numpy ints) keep working
                    where_clause += f" AND {bound}={bound_value}"

            q = text(
                f"UPDATE {assay_definition.variable} SET {','.join(set_clause)} WHERE {where_clause}"
            )
            self.session.execute(q, values)

        self.commit()