import numpy as np
from forgeo.core import InterpolationMethod, Interpolator, Variogram
from qgis.core import QgsProject, QgsVectorLayer, QgsWkbTypes
from qgis.gui import QgsMapToolEmitPoint
from qgis.PyQt.QtCore import Qt
from qgis.PyQt.QtGui import QColor, QDoubleValidator
from qgis.PyQt.QtWidgets import (
    QCheckBox,
    QComboBox,
    QDialog,
    QDialogButtonBox,
    QGroupBox,
    QHBoxLayout,
    QLabel,
    QLineEdit,
    QMessageBox,
    QPushButton,
    QRadioButton,
    QSizePolicy,
    QStackedWidget,
    QVBoxLayout,
    QWidget,
)
from qgis.utils import iface

from .adddata_widget import AddItemDataDialog


def _text_to_float(edit_line):
    return float(v) if (v := edit_line.text()) else None


class CenterGroupBox(QGroupBox):
    """Group box to edit the (x, y, z) center of a finite fault.

    The center can be typed in, set to the mean of the fault observations,
    set to the center of the observations' bounding box, or picked on the
    map canvas.
    """

    def __init__(self, fault, fault_prop_dialog, network_dialog, parent=None):
        super().__init__(parent)
        self.fault = fault
        self.faults_prop_dialog = fault_prop_dialog
        self.faultnetwork_dialog = network_dialog
        # Map tool used by `digitize_on_map`. The canvas does not take
        # ownership of its map tool, so a Python reference must be kept or the
        # tool may be garbage collected before the user clicks.
        self._map_tool = None
        self.setTitle(self.tr("Center"))
        main_layout = QVBoxLayout()
        self.setLayout(main_layout)

        # Buttons
        self.btn_meancenter = QPushButton(self.tr("Set to mean center (default)"))
        self.btn_meancenter.clicked.connect(self.set_mean_center)
        self.btn_datacenter = QPushButton(self.tr("Set to data extent center"))
        self.btn_datacenter.clicked.connect(self.set_data_center)
        btn_selectcenter = QPushButton(self.tr("Point a center"))
        btn_selectcenter.clicked.connect(self.digitize_on_map)

        # Default value: the current extension center, else the mean center
        ellipsoid = fault.extension
        center = ellipsoid.center if ellipsoid is not None else self.get_mean_center()

        # Add widgets
        self.edt_x_center = _add_hlayout("X", center[0], main_layout)
        self.edt_y_center = _add_hlayout("Y", center[1], main_layout)
        self.edt_z_center = _add_hlayout("Z", center[2], main_layout)
        main_layout.addWidget(self.btn_meancenter)
        main_layout.addWidget(self.btn_datacenter)
        main_layout.addWidget(btn_selectcenter)

    def _observation_values(self, attribute):
        """Return ``self.fault.item_data.<attribute>.values``, or None.

        None is returned when the fault has no data, when the requested
        observations are None, or when they contain no rows, so callers never
        compute a center from missing or empty data.
        """
        item_data = self.fault.item_data
        if item_data is None:
            return None
        observations = getattr(item_data, attribute)
        if observations is None:
            return None
        values = observations.values
        return values if len(values) > 0 else None

    def _warn_missing_data(self):
        """Tell the user that no center can be computed from the fault data."""
        QMessageBox.warning(
            self,
            "Missing data",
            "No data found in this fault to compute a center",
        )

    def _set_center_text(self, center):
        """Write the first three components of `center` into the X/Y/Z edits."""
        edits = (self.edt_x_center, self.edt_y_center, self.edt_z_center)
        for edit, value in zip(edits, center):
            edit.setText(str(value))

    def _restore_dialogs(self):
        """Show again the dialogs hidden while picking a point on the map."""
        self.faults_prop_dialog.show()
        self.faultnetwork_dialog.show()

    def get_mean_center(self):
        """Return the column-wise mean of the fault's ``all_observations``.

        Returns ("", "", "") (blank line-edit values) when no observation is
        available. The warning is only shown when the call was triggered by
        the "mean center" button, so the silent call from ``__init__`` does
        not pop up a message box.
        """
        values = self._observation_values("all_observations")
        if values is None:
            if self.sender() == self.btn_meancenter:
                self._warn_missing_data()
            return ("", "", "")
        return np.mean(values, axis=0)

    def set_mean_center(self):
        """Fill the X/Y/Z edits with the mean observation center."""
        x, y, z = self.get_mean_center()
        self._set_center_text((x, y, z))

    def set_data_center(self):
        """Fill the X/Y/Z edits with the center of the observations' extent.

        The extent is the axis-aligned bounding box of
        ``item_data.observations``. Edits are blanked when there is no data.
        """
        # NOTE(review): this uses `observations` while the mean center uses
        # `all_observations`; kept as-is, confirm it is intended.
        values = self._observation_values("observations")
        if values is None:
            if self.sender() == self.btn_datacenter:
                self._warn_missing_data()
            center = ("", "", "")
        else:
            upper = np.max(values, axis=0)
            lower = np.min(values, axis=0)
            center = (upper + lower) / 2
        self._set_center_text(center)

    def digitize_on_map(self):
        """Let the user pick the center on the map canvas.

        The fault dialogs are hidden while picking. They are shown again once
        a point has been picked, or as soon as the picking tool is deactivated
        (e.g. the user selects another map tool instead of clicking).
        """
        canvas = iface.mapCanvas()
        # Initialize the tool to select the point on the map canvas
        tool = QgsMapToolEmitPoint(canvas)

        def get_xyz_coordinates(event):
            # Get the (x,y) coordinates picked by the user
            point = tool.toMapCoordinates(event.pos())
            x = point.x()
            y = point.y()
            # Retrieve the associated z coordinate, defaulting to 0
            z = 0
            terrain = QgsProject.instance().elevationProperties().terrainProvider()
            if terrain is not None:
                # NOTE(review): heightAt expects coordinates in the terrain
                # provider's CRS, while (x, y) are in the canvas CRS; they may
                # differ — confirm whether a coordinate transform is needed.
                z_terrain = terrain.heightAt(x, y)
                # heightAt returns NaN where no height can be obtained
                if z_terrain is not None and not np.isnan(z_terrain):
                    z = z_terrain
            # Set coordinates
            self._set_center_text((x, y, z))
            # Finalize
            canvas.unsetMapTool(tool)
            self._restore_dialogs()

        tool.canvasReleaseEvent = get_xyz_coordinates
        # Bring the dialogs back if the tool is dropped without a pick
        tool.deactivated.connect(self._restore_dialogs)
        self._map_tool = tool
        # Hide fault properties and fault network edition widgets, so the user
        # can pick the fault center on the map canvas
        self.faults_prop_dialog.hide()
        self.faultnetwork_dialog.hide()
        # Set the tool so the user can pick on the canvas
        canvas.setMapTool(tool)

    @property
    def center(self):
        """(x, y, z) center as floats; each is None if its field is not a number."""
        x = _text_to_float(self.edt_x_center)
        y = _text_to_float(self.edt_y_center)
        z = _text_to_float(self.edt_z_center)
        return (x, y, z)


class RadiusGroupBox(QGroupBox):
    """Group box holding the three radii of a finite fault ellipsoid.

    The fields are pre-filled from ``ellipsoid.radius`` when an ellipsoid is
    given, and left blank otherwise.
    """

    def __init__(self, ellipsoid, parent=None):
        super().__init__(parent)
        self.setTitle(self.tr("Radius"))
        layout = QVBoxLayout()
        self.setLayout(layout)

        # Initial values: existing radii, or blanks when there is no ellipsoid
        if ellipsoid is None:
            initial = ("", "", "")
        else:
            initial = ellipsoid.radius

        # One labelled line edit per axis, in strike / dip direction / vertical order
        labels = ("Along strike", "Along dip direction", "Vertical")
        self.edt_strike, self.edt_dip_dir, self.edt_vertical = (
            _add_hlayout(text, initial[index], layout)
            for index, text in enumerate(labels)
        )

    @property
    def radius(self):
        """(strike, dip direction, vertical) radii; None for non-numeric fields."""
        edits = (self.edt_strike, self.edt_dip_dir, self.edt_vertical)
        return tuple(_text_to_float(edit) for edit in edits)


class FiniteFaultParamsWidget(QWidget):
    """Widget stacking the center and radius editors of a finite fault."""

    def __init__(self, fault, fault_prop_dialog, network_dialog, parent=None):
        super().__init__(parent)
        self.fault = fault

        center_box = CenterGroupBox(fault, fault_prop_dialog, network_dialog)
        radius_box = RadiusGroupBox(fault.extension)
        self.center_groupbox = center_box
        self.radius_groupbox = radius_box

        # NOTE: this attribute shadows QWidget.layout(); it is kept for
        # backward compatibility with code that may read it.
        self.layout = QVBoxLayout()
        for box in (center_box, radius_box):
            self.layout.addWidget(box)
        self.setLayout(self.layout)


class FromTraceWidget(QWidget):
    """Widget to create a fault from a trace (not implemented yet).

    Offers three cascading choices: a line layer of the current project, one
    of that layer's fields, and one of the values taken by that field.
    """

    def __init__(self, parent=None):
        super().__init__(parent)

        # Layer selection
        label_layer = QLabel(
            "Select layer :", alignment=Qt.AlignmentFlag.AlignLeft, height=15
        )
        self.layer_cbox = QComboBox()
        self.layer_cbox.currentTextChanged.connect(self.fill_filter_cbox)

        # Filter on
        label_filter = QLabel(
            "Filter on :", alignment=Qt.AlignmentFlag.AlignLeft, height=15
        )
        self.filter_cbox = QComboBox()
        self.filter_cbox.currentTextChanged.connect(self.fill_value_cbox)

        # Value
        label_value = QLabel("Value :", alignment=Qt.AlignmentFlag.AlignLeft, height=15)
        self.value_cbox = QComboBox()

        # Reload
        # TODO
        self.fill_layer_cbox()

        # Layouts: a warning label, then one "label | combo box" row per choice
        layout = QVBoxLayout()
        layout.addWidget(QLabel("Not implemented yet"))
        for label, cbox in (
            (label_layer, self.layer_cbox),
            (label_filter, self.filter_cbox),
            (label_value, self.value_cbox),
        ):
            row = QHBoxLayout()
            row.addWidget(label)
            row.addWidget(cbox)
            layout.addLayout(row)
        self.setLayout(layout)

    def fill_layer_cbox(self):
        """List the project's line vector layers, then refresh the fields."""
        for layer in QgsProject.instance().mapLayers().values():
            if (
                isinstance(layer, QgsVectorLayer)
                and layer.geometryType() == QgsWkbTypes.LineGeometry
            ):
                self.layer_cbox.addItem(layer.name(), userData=layer)
        self.fill_filter_cbox()

    def fill_filter_cbox(self):
        """List the fields of the selected layer, then refresh the values."""
        layer = self.layer_cbox.currentData()
        if layer is None:
            return
        self.filter_cbox.clear()
        for field in layer.fields():
            self.filter_cbox.addItem(field.name())
        self.fill_value_cbox()

    def fill_value_cbox(self):
        """List the distinct values (as text) of the selected field, sorted."""
        layer = self.layer_cbox.currentData()
        if layer is None:
            return
        self.value_cbox.clear()
        fieldname = self.filter_cbox.currentText()
        if not fieldname:
            return
        values = {str(feature[fieldname]) for feature in layer.getFeatures()}
        # Sort so the list order is stable; set iteration order is arbitrary
        # and, for strings, changes between sessions.
        self.value_cbox.addItems(sorted(values))


class FromInterpolationWidget(QWidget):
    """Widget to create a fault by interpolation of its data.

    Lets the user select the fault data and, optionally, mark the fault as
    finite and edit its center and radius.
    """

    def __init__(self, fault, filter, network_dialog, parent=None):
        super().__init__(parent)
        self.fault = fault
        self.filter = filter
        # FIXME Terrible hack, remove as soon as possible...
        self.network_edition_widget = network_dialog
        # Is it a finite fault ?
        # Note: self.parent() is the dialog given as `parent`; CenterGroupBox
        # hides and re-shows it while the user picks a center on the map.
        self.finite_fault_params_widget = FiniteFaultParamsWidget(
            self.fault, self.parent(), network_dialog
        )
        self.finite_fault_params_widget.setVisible(False)
        self.is_finite_checkbox = QCheckBox("is finite")
        # Use the boolean `toggled` signal: `stateChanged` emits a plain int,
        # which never compares equal to Qt.CheckState.Checked with PyQt6
        # enums, so the finite-fault parameters would never be shown.
        self.is_finite_checkbox.toggled.connect(
            self.finite_fault_params_widget.setVisible
        )
        if self.fault.extension is not None:
            self.is_finite_checkbox.setCheckState(Qt.CheckState.Checked)

        # Data
        add_data_button = QPushButton("Data")
        add_data_button.clicked.connect(self.add_data)

        # Layouts
        self.layout = QVBoxLayout()
        self.layout.addWidget(add_data_button)
        self.layout.addWidget(self.is_finite_checkbox)
        self.layout.addWidget(self.finite_fault_params_widget)
        self.setLayout(self.layout)

    def add_data(self):
        """Open the data selection dialog for the fault.

        Unless the dialog is rejected, the fault's ``item_data`` is replaced by
        the selection extracted in the CRS of the fault network layer, and the
        updated filter is kept.
        """
        color = self.fault.info["color"] if self.fault.info is not None else None
        dlg = AddItemDataDialog(self.fault, self.filter, color)

        def _update_item_data(result):
            if result == QDialog.DialogCode.Rejected:
                return
            layer = self.network_edition_widget.layer
            updated_filter = dlg.filter
            self.fault.item_data = updated_filter.extract_selection(layer.crs())
            self.filter = updated_filter

        dlg.finished.connect(_update_item_data)
        dlg.open()


class FaultPropertiesDialog(QDialog):
    """Dialog to define parameters of a fault.

    When the user presses OK, ``save`` stores the edited parameters in
    ``self.fault`` and ``self.filter`` and builds ``self.interpolator``. The
    dialog stays open if the finite-fault parameters are incomplete.
    """

    def __init__(self, fault, filter, parent=None):
        super().__init__(parent)
        self.fault = fault
        self.filter = filter
        self.interpolator = None

        # Title, colored with the fault color when one is defined
        title = QLabel(fault.name, alignment=Qt.AlignmentFlag.AlignCenter)
        title.setFixedHeight(50)
        # fault.info may be None: guard it before looking up the color
        color = fault.info["color"] if fault.info is not None else None
        if color is not None:
            title.setStyleSheet(f"background-color: {QColor(color).name()};")

        # Fault creation method
        from_trace_rbutton = QRadioButton("Create from trace")
        from_trace_rbutton.clicked.connect(lambda: self.props_widget.setCurrentIndex(0))
        interpolate_rbutton = QRadioButton("Interpolate")
        interpolate_rbutton.clicked.connect(
            lambda: self.props_widget.setCurrentIndex(1)
        )
        self.props_widget = QStackedWidget()
        self.props_widget.setSizePolicy(
            QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding
        )
        self.props_widget.addWidget(FromTraceWidget())
        # Retrieve parent fault network edition dialog: the highest ancestor
        # below the QGIS main window.
        # Note: We need it to access the FaultNetworkLayer to get its CRS
        # This hack is dirty, but at least functional...
        network_dialog = self.parent()
        ancestor = network_dialog.parent()
        # Also stop at the root of the widget tree instead of failing on
        # None.parent() when the dialog is not under the QGIS main window.
        while ancestor is not None and ancestor is not iface.mainWindow():
            network_dialog = ancestor
            ancestor = network_dialog.parent()
        # Note: 'parent=self' is mandatory, as FromInterpolationWidget.__init__
        # explicitly uses self.parent()... Yet another temporary dirty hack
        self.props_widget.addWidget(
            FromInterpolationWidget(fault, filter, network_dialog, parent=self)
        )

        # Reload
        # TODO
        interpolate_rbutton.setChecked(True)
        self.props_widget.setCurrentIndex(1)  # interpolation

        # Buttons
        buttons = QDialogButtonBox(self)
        buttons.setStandardButtons(
            QDialogButtonBox.StandardButton.Cancel | QDialogButtonBox.StandardButton.Ok
        )
        buttons.accepted.connect(self.accept)
        buttons.rejected.connect(self.reject)

        # Signals
        # Note: parameters are saved by `accept`, before the dialog closes
        self.finished.connect(self.deleteLater)

        # Layout
        layout = QVBoxLayout()
        layout.addWidget(title)
        layout_rbutton = QHBoxLayout()
        layout_rbutton.addWidget(from_trace_rbutton)
        layout_rbutton.addWidget(interpolate_rbutton)
        layout.addLayout(layout_rbutton)
        layout.addWidget(self.props_widget)
        layout.addWidget(buttons)
        self.setLayout(layout)

    def accept(self):
        """Save the parameters, then close the dialog unless they are invalid.

        Saving happens before closing so that the user can still fix missing
        values after the warning: the dialog is deleted once it finishes.
        """
        if self.save() is False:
            return
        super().accept()

    def save(self):
        """Store the edited parameters and build the default interpolator.

        Returns:
            False when the fault is marked finite but its center or radius is
            incomplete (a warning is shown and no interpolator is built),
            True otherwise.
        """
        # Fault creation method
        fromXwidget = self.props_widget.currentWidget()
        if isinstance(fromXwidget, FromTraceWidget):
            # TODO: creation from trace is not implemented yet.
            # Checked first because FromTraceWidget has no `fault` attribute.
            return True
        self.fault = fromXwidget.fault
        if isinstance(fromXwidget, FromInterpolationWidget):
            # Finite faults parameters
            if fromXwidget.is_finite_checkbox.checkState() == Qt.CheckState.Checked:
                params_widget = fromXwidget.finite_fault_params_widget
                center = params_widget.center_groupbox.center
                radius = params_widget.radius_groupbox.radius
                if any(v is None for v in (*center, *radius)):
                    QMessageBox.warning(
                        self, "Missing data", "Please enter a center and a range"
                    )
                    return False
                self.fault.set_extension(center, radius)
            else:
                self.fault.extension = None
        # Update filter
        self.filter = fromXwidget.filter
        # Interpolation parameters (defaults)
        vario = Variogram("spherical", self._default_range(self.fault.item_data), 1)
        self.interpolator = Interpolator(
            InterpolationMethod.POTENTIAL, [self.fault], variograms=[vario], drift_order=1
        )
        return True

    @staticmethod
    def _default_range(item_data, fallback=1000):
        """Return a default variogram range for the fault data.

        The range is the diagonal length of the bounding box of
        ``item_data.all_observations``. `fallback` is returned when there is
        no data, fewer than two observations, or a zero-size extent.
        """
        if item_data is None or (obs := item_data.all_observations) is None:
            return fallback
        values = obs.values
        if len(values) < 2:
            return fallback
        extent = np.linalg.norm(np.max(values, 0) - np.min(values, 0))
        # NOTE(review): a zero range (all observations coincide) looks
        # degenerate for a variogram, hence the fallback — confirm with forgeo.
        return extent if extent > 0 else fallback


def _add_hlayout(label, value, main_layout):
    """Append a "label ... line edit" row to `main_layout` and return the edit.

    The line edit only accepts floating-point input (QDoubleValidator) and is
    pre-filled with ``str(value)``.
    """
    lbl = QLabel(label)
    edit = QLineEdit()
    # Parent the validator to the edit so it lives exactly as long as the edit
    edit.setValidator(QDoubleValidator(edit))
    edit.setText(str(value))
    layout = QHBoxLayout()
    layout.addWidget(lbl)
    layout.addStretch(1)
    layout.addWidget(edit)
    main_layout.addLayout(layout)
    return edit
