import pandas as pd
import sqlalchemy
from qgis.core import QgsCsException
from qgis.PyQt.QtWidgets import QLabel, QMessageBox, QVBoxLayout, QWidget, QWizardPage
from xplordb.datamodel.collar import Collar
from xplordb.datamodel.metadata import RawCollarMetadata

from openlog.datamodel.assay.generic_assay import AssaySeriesType
from openlog.datamodel.connection.openlog_connection import OpenLogConnection
from openlog.gui.import_assay.utils import sanitize_sql_name
from openlog.gui.utils.column_definition import ColumnDefinition
from openlog.gui.utils.delimited_text_import_widget import DelimitedTextImportWidget

# Settings key prefix passed to DelimitedTextImportWidget.restore_settings / save_setting
# so the collar import page remembers its previous configuration.
BASE_SETTINGS_KEY: str = "/OpenLog/gui/import/collar"


class CollarsImportPageWizard(QWizardPage):
    """Wizard page importing collars (plus any extra columns as collar metadata) into OpenLog."""

    def __init__(self, parent: QWidget, openlog_connection: OpenLogConnection) -> None:
        """
        QWizard page to import collars into xplordb from a csv file or a QGIS layer

        Args:
            parent : QWidget parent
            openlog_connection: OpenLogConnection used to import collar
        """
        super().__init__(parent)

        self._openlog_connection = openlog_connection
        self.setTitle("Collar import")

        # Displayed column names; the trailing "*" marks the mandatory (non-optional) columns.
        self.HOLE_ID_COL = "HoleID*"
        self.X_COL = "Easting*"
        self.Y_COL = "Northing*"
        self.Z_COL = "Elevation*"
        self.EOH_COL = "EOH"
        self.PLANNED_X_COL = "Pld. East."
        self.PLANNED_Y_COL = "Pld. North."
        self.PLANNED_Z_COL = "Pld. Eleva."
        self.PLANNED_EOH_COL = "Pld. EOH"
        self.DIP = "Dip"
        self.AZIMUTH = "Azimuth"

        column_definitions = self._column_definitions()

        label = QLabel("Select a .csv file or a QGIS layer in Source table section.")
        label.setWordWrap(True)

        layout = QVBoxLayout()
        layout.addWidget(label)
        self.dataset_edit = DelimitedTextImportWidget(
            self, self._openlog_connection, "collar"
        )
        # One index per fixed column (presumably in column definition order), derived
        # from the definition list instead of a hard-coded count.
        self.dataset_edit.column_mapping_model.set_nodata_readonly(
            list(range(len(column_definitions)))
        )
        self.dataset_edit.enable_crs_selection(True)
        self.dataset_edit.enable_elevation_from_dtm(True)
        self.dataset_edit.set_coordinates_columns(
            self.X_COL,
            self.Y_COL,
            self.Z_COL,
            self.PLANNED_X_COL,
            self.PLANNED_Y_COL,
            self.PLANNED_Z_COL,
            self.DIP,
            self.AZIMUTH,
        )
        # connect to red flag signal to enable/disable next button
        self.dataset_edit.red_flag_signal.connect(self.completeChanged.emit)

        self.dataset_edit.set_column_definition(column_definitions)
        layout.addWidget(self.dataset_edit)
        self.setLayout(layout)
        self.dataset_edit.button_frame.show()
        self.dataset_edit.restore_settings(BASE_SETTINGS_KEY)

    def _column_definitions(self) -> list:
        """Return the fixed ColumnDefinition list: mandatory columns first, then optional ones."""

        def numerical(column: str, unit: str, **extra) -> ColumnDefinition:
            # Fixed numerical column; `extra` carries e.g. optional=True only when needed.
            return ColumnDefinition(
                column=column,
                unit=unit,
                fixed=True,
                series_type=AssaySeriesType.NUMERICAL,
                **extra,
            )

        return [
            ColumnDefinition(
                column=self.HOLE_ID_COL,
                fixed=True,
                series_type=AssaySeriesType.NOMINAL,
            ),
            numerical(self.X_COL, "m"),
            numerical(self.Y_COL, "m"),
            numerical(self.Z_COL, "m"),
            numerical(self.EOH_COL, "m", optional=True),
            numerical(self.PLANNED_X_COL, "m", optional=True),
            numerical(self.PLANNED_Y_COL, "m", optional=True),
            numerical(self.PLANNED_Z_COL, "m", optional=True),
            numerical(self.PLANNED_EOH_COL, "m", optional=True),
            numerical(self.DIP, "°", optional=True),
            numerical(self.AZIMUTH, "°", optional=True),
        ]

    def _default_columns(self) -> list:
        """Return the names of the fixed collar columns (everything else is extra metadata)."""
        return [
            self.HOLE_ID_COL,
            self.X_COL,
            self.Y_COL,
            self.Z_COL,
            self.EOH_COL,
            self.PLANNED_X_COL,
            self.PLANNED_Y_COL,
            self.PLANNED_Z_COL,
            self.PLANNED_EOH_COL,
            self.DIP,
            self.AZIMUTH,
        ]

    def isComplete(self) -> bool:
        """
        Override of QWizardPage.isComplete method to enable next button only if there is no red flag.
        """
        return not self.dataset_edit.red_flag

    def data_label(self) -> str:
        """
        Returns label to be used in confirmation dialog

        Returns: imported data label

        """
        return "Collars"

    def data_count(self) -> int:
        """
        Returns expected imported data count to be displayed in confirmation dialog

        Returns: expected imported data count (0 when no data is loaded)

        """
        df = self.dataset_edit.get_dataframe()
        return df.shape[0] if df is not None else 0

    @staticmethod
    def _is_missing(value) -> bool:
        """
        Tell whether a cell value should be treated as "not provided".

        None, NaN / NA values and blank strings are missing. Numeric zero is NOT
        missing, so e.g. a 0.0 dip (horizontal hole) or 0.0 elevation is kept.
        """
        if isinstance(value, str):
            return not value.strip()
        return bool(pd.isna(value))

    @classmethod
    def _value_or(cls, value, default):
        """Return `value`, or `default` when `value` is missing (see _is_missing)."""
        return default if cls._is_missing(value) else value

    def _build_collars(self, df) -> list:
        """
        Convert each dataframe row into a Collar, applying defaults for missing values.

        Defaults: Z -> 0.0, planned X/Y/Z -> the (resolved) X/Y/Z, dip -> -90.0,
        azimuth -> 0.0, EOH / planned EOH -> None.
        """
        # Identical for every row: resolve once rather than per row.
        data_set = self.field("dataset")
        loaded_by = self.field("person")
        srid = self.dataset_edit.crs().postgisSrid()
        project_srid = self._openlog_connection.default_srid

        collars = []
        for _, row in df.iterrows():
            x = row[self.X_COL]
            y = row[self.Y_COL]
            z = self._value_or(row[self.Z_COL], 0.0)
            collars.append(
                Collar(
                    hole_id=row[self.HOLE_ID_COL],
                    data_set=data_set,
                    loaded_by=loaded_by,
                    x=x,
                    y=y,
                    z=z,
                    srid=srid,
                    project_srid=project_srid,
                    eoh=self._value_or(row[self.EOH_COL], None),
                    planned_x=self._value_or(row[self.PLANNED_X_COL], x),
                    planned_y=self._value_or(row[self.PLANNED_Y_COL], y),
                    # Fall back to the resolved Z so planned_z is never left missing.
                    planned_z=self._value_or(row[self.PLANNED_Z_COL], z),
                    planned_eoh=self._value_or(row[self.PLANNED_EOH_COL], None),
                    dip=self._value_or(row[self.DIP], -90.0),
                    azimuth=self._value_or(row[self.AZIMUTH], 0.0),
                )
            )
        return collars

    def _build_metadatas(self, df, collars: list) -> list:
        """
        Build one RawCollarMetadata per collar from the non-fixed (extra) columns.

        Keys are the extra column names passed through sanitize_sql_name and lowercased.
        `collars` must be in the same row order as `df`.
        """
        default_cols = set(self._default_columns())
        extra_cols = [col for col in df.columns if col not in default_cols]

        metadatas = []
        for collar, (_, values) in zip(collars, df[extra_cols].iterrows()):
            # lowercase for column name
            extra_values = {
                sanitize_sql_name(col).lower(): values[col] for col in extra_cols
            }
            metadatas.append(
                RawCollarMetadata(hole_id=collar.hole_id, extra_cols=extra_values)
            )
        return metadatas

    def import_data(self):
        """
        Import data into openlog database (does nothing if no data is loaded).

        Collars are written first, then their extra-column metadata. On any failure the
        connection is rolled back.

        Raises:
            OpenLogConnection.ImportException: if collar or metadata import fails
                (original error chained as the cause)
        """
        df = self.dataset_edit.get_dataframe()
        if df is None:
            return

        collars = self._build_collars(df)
        metadatas = self._build_metadatas(df, collars)

        try:
            self._openlog_connection.get_write_iface().import_collar(collars)
        except QgsCsException as e:
            self._openlog_connection.rollback()
            raise OpenLogConnection.ImportException(
                "Coordinate are invalid regarding selected CRS"
            ) from e
        except Exception as e:
            self._openlog_connection.rollback()
            raise OpenLogConnection.ImportException(e) from e

        try:
            self._openlog_connection.get_write_iface().import_collar_metadata(
                metadatas
            )
        except sqlalchemy.exc.IntegrityError as e:
            self._openlog_connection.rollback()
            raise OpenLogConnection.ImportException(
                "Some collars are already in database"
            ) from e
        except Exception as e:
            self._openlog_connection.rollback()
            raise OpenLogConnection.ImportException(e) from e

    def validatePage(self) -> bool:
        """
        Validate current page content.

        The page is valid when the dataset widget reports valid data and, if data is
        loaded, a valid CRS is selected (a warning dialog is shown otherwise).
        Widget settings are saved whatever the result.

        Returns: True if the page can be accepted, False otherwise

        """
        valid = self.dataset_edit.data_is_valid
        df = self.dataset_edit.get_dataframe()
        if df is not None and not self.dataset_edit.crs().isValid():
            valid = False
            QMessageBox.warning(self, "No CRS defined", "Define imported data CRS.")

        self.dataset_edit.save_setting(BASE_SETTINGS_KEY)

        return valid
