from collections.abc import Iterable
from math import isinf, isnan
from typing import ClassVar

from forgeo.core import InterpolationMethod, Interpolator, Neighborhood, Variogram
from qgis.PyQt.QtCore import Qt
from qgis.PyQt.QtGui import QDoubleValidator, QIntValidator
from qgis.PyQt.QtWidgets import (
    QComboBox,
    QDialog,
    QGroupBox,
    QHBoxLayout,
    QLabel,
    QLineEdit,
    QPushButton,
    QRadioButton,
    QStackedWidget,
    QVBoxLayout,
    QWidget,
)

from forgeo.io.xml import deep_copy


class VariogramParametersGroupBox(QGroupBox):
    def __init__(
        self,
        authorized_models: Iterable[str],
        variogram: Variogram,
        *,
        show_model: bool = True,
        show_range: bool = True,
        show_sill: bool = True,
        show_nugget: bool = True,
        parent=None,
    ):
        """
        Group box editing the parameters of a single variogram.

        Hidden fields are still created, so that the `model`, `range`, `sill`
        and `nugget` properties can always be read.

        Parameters
        ----------
            authorized_models: Iterable[str]
                Non-empty collection of variogram models to propose in the combobox
            variogram: Variogram
                Variogram used to initialize the widget (e.g. when re-opening it).
                If its model is not authorized, the first authorized model is
                selected instead
            show_model: (optional) bool
                Whether to show the model selection combobox in the widget
            show_range: (optional) bool
                Whether to show the range line edit in the widget
            show_sill: (optional) bool
                Whether to show the sill line edit in the widget
            show_nugget: (optional) bool
                Whether to show the nugget line edit in the widget
            parent: (optional) QWidget
                Parent widget
        """
        super().__init__(parent)
        self.setTitle(self.tr("Variogram parameters"))
        main_layout = QVBoxLayout()
        self.setLayout(main_layout)
        # Materialize once: the models are consumed by `addItems` and checked for
        # emptiness, which a one-shot iterator (e.g. a generator) cannot support
        authorized_models = list(authorized_models)
        assert authorized_models
        # Variogram model
        cbox_model = QComboBox()
        cbox_model.addItems(authorized_models)
        self.cbox_model = cbox_model
        self._select_model(variogram.model)
        if show_model:
            layout = QHBoxLayout()
            layout.addWidget(QLabel(self.tr("Variogram model")))
            layout.addWidget(cbox_model)
            main_layout.addLayout(layout)
        # Range, sill and nugget only accept non-negative reals. Parenting the
        # validator ties its lifetime to this widget.
        validator = QDoubleValidator(self)
        validator.setBottom(0)
        self.edt_range = self._add_value_row(
            main_layout, self.tr("Range"), variogram.range, validator, show_range
        )
        self.edt_sill = self._add_value_row(
            main_layout, self.tr("Sill"), variogram.sill, validator, show_sill
        )
        self.edt_nugget = self._add_value_row(
            main_layout, self.tr("Nugget"), variogram.nugget, validator, show_nugget
        )

    @staticmethod
    def _add_value_row(main_layout, label, value, validator, visible):
        """Create a line edit showing `value`; add it as a labelled row if visible."""
        edit = QLineEdit()
        edit.setValidator(validator)
        edit.setText(_to_str(value))
        if visible:
            layout = QHBoxLayout()
            layout.addWidget(QLabel(label))
            layout.addStretch(1)
            layout.addWidget(edit)
            main_layout.addLayout(layout)
        return edit

    def _select_model(self, model):
        """
        Select `model` in the combobox, falling back to the first entry.

        The fallback handles e.g. switching from elevation kriging to the
        potential field method with a non-spherical variogram model.
        """
        idx = self.cbox_model.findText(model) if isinstance(model, str) else -1
        # findText returns -1 when not found: never leave the combobox empty
        self.cbox_model.setCurrentIndex(max(idx, 0))

    @property
    def model(self):
        return self.cbox_model.currentText()

    @property
    def range(self):
        return _get_value(self.edt_range, float)

    @property
    def sill(self):
        return _get_value(self.edt_sill, float)

    @property
    def nugget(self):
        return _get_value(self.edt_nugget, float)

    def clear(self, default: Variogram):
        """Reset all fields to the values of `default` (missing values shown empty)."""
        self._select_model(default.model)
        self.edt_range.setText(_to_str(default.range))
        self.edt_sill.setText(_to_str(default.sill))
        self.edt_nugget.setText(_to_str(default.nugget))


class DriftSelectionGroupBox(QGroupBox):
    MAX_DRIFT_ORDER_ALLOWED = 3  # Note: limiting order to 3 is an arbitrary choice
    # TODO This widget might be updated in the future to handle external drifts too

    def __init__(self, max_order: "int | str", parent=None):
        """
        Group box selecting the drift order.

        Parameters
        ----------
            max_order: int | str
                Value to select in the drift order combobox, typically used to
                re-open the widget. Accepted values: `[0, 1, 2, 3]`, as int or str
            parent: (optional) QWidget
                Parent widget
        """
        super().__init__(parent)
        txt_lbl = self.tr("Drift order")
        self.setTitle(txt_lbl)
        # Normalize so that the documented str form ("2") works like 2; comparing a
        # str against ints below would otherwise raise TypeError
        max_order = int(max_order)
        assert 0 <= max_order <= self.MAX_DRIFT_ORDER_ALLOWED
        cbox = QComboBox()
        cbox.addItems([str(i) for i in range(self.MAX_DRIFT_ORDER_ALLOWED + 1)])
        cbox.setCurrentIndex(cbox.findText(str(max_order)))
        self.cbox_order = cbox
        main_layout = QHBoxLayout()
        main_layout.addWidget(QLabel(txt_lbl))
        main_layout.addWidget(cbox)
        self.setLayout(main_layout)

    @property
    def order(self):
        """Currently selected drift order, as an int."""
        return int(self.cbox_order.currentText())

    def clear(self, default):
        """Select the `default` drift order (int or str)."""
        idx = self.cbox_order.findText(str(int(default)))
        self.cbox_order.setCurrentIndex(idx)


class MovingNeighborhoodGroupBox(QGroupBox):
    def __init__(self, neigh: Neighborhood, parent=None):
        """
        Group box editing the parameters of a moving neighborhood.

        Parameters
        ----------
            neigh: Neighborhood
                Neighborhood used to initialize the widget; missing values
                (None, NaN, infinite) are shown as empty fields
            parent: (optional) QWidget
                Parent widget
        """
        super().__init__(parent)
        self.setTitle(self.tr("Neighborhood"))
        # NOTE: a unique neighborhood could later be supported by making this
        # group box checkable (unchecked meaning "unique neighborhood")
        # Counts must be >= 1 and distances >= 0. Parenting the validators ties
        # their lifetime to this widget.
        validator_int = QIntValidator(self)
        validator_int.setBottom(1)
        validator_double = QDoubleValidator(self)
        validator_double.setBottom(0)
        main_layout = QVBoxLayout()
        self.setLayout(main_layout)

        def _add_hlayout(label, value, validator, placeholder=None):
            """Add a labelled line edit row showing `value`, and return the edit."""
            edit = QLineEdit()
            edit.setValidator(validator)
            edit.setText(_to_str(value))
            if placeholder is not None:
                edit.setPlaceholderText(placeholder)
            layout = QHBoxLayout()
            layout.addWidget(QLabel(label))
            layout.addStretch(1)
            layout.addWidget(edit)
            main_layout.addLayout(layout)
            return edit

        # Labels are translated here, on literals, so translation tools find them
        self.edt_max_distance = _add_hlayout(
            self.tr("Max search distance"),
            neigh.max_search_distance,
            validator_double,
            self.tr("Defaults to infinite search distance"),
        )
        self.edt_min_neighs = _add_hlayout(
            self.tr("Min neighbors"), neigh.nb_min_neighbors, validator_int
        )
        self.edt_max_neighs = _add_hlayout(
            self.tr("Max neighbors"), neigh.nb_max_neighbors, validator_int
        )
        self.edt_nb_sectors = _add_hlayout(
            self.tr("Number of angular sectors"),
            neigh.nb_angular_sectors,
            validator_int,
        )
        self.edt_max_neighs_per_sectors = _add_hlayout(
            self.tr("Max neighbors per sector"),
            neigh.nb_max_neighbors_per_sector,
            validator_int,
        )

    @property
    def max_search_distance(self):
        return _get_value(self.edt_max_distance, float)

    @property
    def min_nb_neighbors(self):
        return _get_value(self.edt_min_neighs, int)

    @property
    def max_nb_neighbors(self):
        return _get_value(self.edt_max_neighs, int)

    @property
    def nb_angular_sectors(self):
        return _get_value(self.edt_nb_sectors, int)

    @property
    def max_neighbors_per_sector(self):
        return _get_value(self.edt_max_neighs_per_sectors, int)

    def clear(self, default: Neighborhood):
        """
        Reset all fields to the values of `default`.

        Uses `_to_str` (like `__init__`) so that missing values (None, NaN,
        infinite) give an empty field instead of the text "None"/"nan", which
        could not be parsed back.
        """
        self.edt_max_distance.setText(_to_str(default.max_search_distance))
        self.edt_min_neighs.setText(_to_str(default.nb_min_neighbors))
        self.edt_max_neighs.setText(_to_str(default.nb_max_neighbors))
        self.edt_nb_sectors.setText(_to_str(default.nb_angular_sectors))
        self.edt_max_neighs_per_sectors.setText(
            _to_str(default.nb_max_neighbors_per_sector)
        )


class PotentialMethodParametersWidget(QWidget):
    METHOD = InterpolationMethod.POTENTIAL
    FORCED_MODEL = "spherical"
    FORCED_SILL = 1.0
    DEFAULT_DRIFT_ORDER = 1

    def __init__(self, interpolator: Interpolator, parent=None):
        """
        Parameters of the potential field method.

        Only the variogram range and nugget and the drift order are editable:
        the model and the sill are fixed to `FORCED_MODEL` and `FORCED_SILL`.

        Parameters
        ----------
            interpolator: Interpolator
                Interpolator used to initialize the widget (only read). Its first
                variogram and its drift order are used when set, defaults otherwise
            parent: (optional) QWidget
                Parent widget
        """
        super().__init__(parent)
        # Get interpolator initial values, if any. Both None and an empty list
        # mean "no variogram yet" (indexing an empty list would raise IndexError)
        variograms = interpolator.variograms
        variogram = variograms[0] if variograms else self._default_variogram()
        drift_order = interpolator.drift_order
        if drift_order is None:
            drift_order = self.DEFAULT_DRIFT_ORDER
        self.variogram = VariogramParametersGroupBox(
            [self.FORCED_MODEL],
            variogram,
            show_model=False,
            show_sill=False,
        )
        self.drift = DriftSelectionGroupBox(max_order=drift_order)
        main_layout = QVBoxLayout()
        main_layout.addWidget(self.variogram)
        main_layout.addWidget(self.drift)
        main_layout.setContentsMargins(0, 0, 0, 0)
        self.setLayout(main_layout)

    def clear(self):
        """Reset all parameters to their defaults."""
        self.variogram.clear(self._default_variogram())
        self.drift.clear(self.DEFAULT_DRIFT_ORDER)

    def update_interpolator(self, interpolator):
        """
        Write the edited parameters into `interpolator` (modified in place).

        Sets the method to `METHOD`, replaces the variograms by a single
        `FORCED_MODEL` variogram with sill `FORCED_SILL`, and sets the drift order.
        """
        interpolator.method = self.METHOD
        variogram = self.variogram
        interpolator.variograms = [
            Variogram(
                self.FORCED_MODEL,
                range=variogram.range,
                sill=self.FORCED_SILL,
                nugget=variogram.nugget,
            )
        ]
        interpolator.drift_order = self.drift.order

    @classmethod
    def _default_variogram(cls):
        # Use class method rather than attribute to ensure this cannot be modified
        return Variogram(cls.FORCED_MODEL, None, cls.FORCED_SILL, 0.0)


class ElevationKrigingParametersWidget(QWidget):
    METHOD = InterpolationMethod.ELEVATION_KRIGING
    AUTHORIZED_MODELS: ClassVar[list[str]] = [
        "spherical",
        "exponential",
        "cubic",
        "linear",
    ]
    DEFAULT_MODEL = "spherical"
    DEFAULT_DRIFT_ORDER = 1

    def __init__(self, interpolator: Interpolator, parent=None):
        """
        Parameters of the elevation kriging method: variogram, drift order and
        moving neighborhood.

        Parameters
        ----------
            interpolator: Interpolator
                Interpolator used to initialize the widget (only read). Its first
                variogram, drift order and neighborhood are used when set,
                defaults otherwise
            parent: (optional) QWidget
                Parent widget
        """
        super().__init__(parent)
        # Both None and an empty list mean "no variogram yet" (indexing an empty
        # list would raise IndexError)
        variograms = interpolator.variograms
        variogram = variograms[0] if variograms else self._default_variogram()
        drift_order = interpolator.drift_order
        if drift_order is None:
            drift_order = self.DEFAULT_DRIFT_ORDER
        neighborhood = interpolator.neighborhood
        if neighborhood is None:
            neighborhood = self._default_neighborhood()
        # Sub-widgets
        self.variogram = VariogramParametersGroupBox(self.AUTHORIZED_MODELS, variogram)
        self.drift = DriftSelectionGroupBox(max_order=drift_order)
        self.neighborhood = MovingNeighborhoodGroupBox(neighborhood)
        # Layout
        main_layout = QVBoxLayout()
        main_layout.addWidget(self.variogram)
        main_layout.addWidget(self.drift)
        main_layout.addWidget(self.neighborhood)
        main_layout.setContentsMargins(0, 0, 0, 0)
        self.setLayout(main_layout)

    def clear(self):
        """Reset all parameters to their defaults."""
        self.variogram.clear(self._default_variogram())
        self.drift.clear(self.DEFAULT_DRIFT_ORDER)
        self.neighborhood.clear(self._default_neighborhood())

    def update_interpolator(self, interpolator):
        """
        Write the edited parameters into `interpolator` (modified in place).

        Sets the method to `METHOD`, replaces the variograms by the edited one,
        and sets the drift order and a moving neighborhood.
        """
        interpolator.method = self.METHOD
        variogram = self.variogram
        interpolator.variograms = [
            Variogram(
                variogram.model,
                range=variogram.range,
                sill=variogram.sill,
                nugget=variogram.nugget,
            )
        ]
        interpolator.drift_order = self.drift.order
        neigh = self.neighborhood
        interpolator.neighborhood = Neighborhood.create_moving(
            max_search_distance=neigh.max_search_distance,
            nb_max_neighbors=neigh.max_nb_neighbors,
            nb_min_neighbors=neigh.min_nb_neighbors,
            nb_angular_sectors=neigh.nb_angular_sectors,
            nb_max_neighbors_per_sector=neigh.max_neighbors_per_sector,
        )

    @classmethod
    def _default_variogram(cls):
        # Use class method rather than attribute to ensure this cannot be modified
        return Variogram(cls.DEFAULT_MODEL, None, None, 0.0)

    @classmethod
    def _default_neighborhood(cls):
        # Use class method rather than attribute to ensure this cannot be modified
        return Neighborhood.create_moving(
            nb_max_neighbors=48,
            nb_min_neighbors=3,
            nb_angular_sectors=8,
            nb_max_neighbors_per_sector=3,
        )


class InterpolationParametersDialog(QDialog):
    MAP_METHOD_TO_STACKED_WIDGET: ClassVar[dict[InterpolationMethod, int]] = {
        InterpolationMethod.POTENTIAL: 0,
        InterpolationMethod.ELEVATION_KRIGING: 1,
    }

    def __init__(self, interpolator, parent=None):
        """
        Dialog to choose the interpolation method and edit its parameters.

        Parameters
        ----------
            interpolator: Interpolator
                Interpolator used to initialize the dialog. It is not modified:
                the dialog works on a deep copy, stored in `self.interpolator`,
                which is updated from the current page when the dialog is accepted
            parent: (optional) QWidget
                Parent widget
        """
        super().__init__(parent)
        self.interpolator = interpolator
        self.set_interpolator_working_copy()
        # Check interpolator consistency
        method = interpolator.method
        assert method in self.MAP_METHOD_TO_STACKED_WIDGET

        self.setWindowTitle(self.tr("Interpolation parameterization"))
        lbl_items = QLabel(
            self.tr("Interpolated item(s): ")
            + ", ".join(item.name for item in interpolator.dataset)
        )
        lbl_items.setWordWrap(True)
        # Interpolation parameters: one page per method, in the order given by
        # MAP_METHOD_TO_STACKED_WIDGET
        self.stack_parameters = QStackedWidget()
        self.wgt_potential = PotentialMethodParametersWidget(interpolator)
        self.stack_parameters.addWidget(self.wgt_potential)
        self.wgt_kriging = ElevationKrigingParametersWidget(interpolator)
        self.stack_parameters.addWidget(self.wgt_kriging)
        # Groupbox "Method selection": both buttons go through the same setters,
        # so the page and the working copy's method stay in sync
        gbox_method_selection = QGroupBox()
        gbox_method_selection.setTitle(self.tr("Interpolation method"))
        btn_potential = QRadioButton(self.tr("Potential method"))
        btn_potential.clicked.connect(self.set_potential_method)
        btn_kriging = QRadioButton(self.tr("Elevation map (kriging)"))
        btn_kriging.clicked.connect(self.set_elevation_kriging)
        gbox_layout = QHBoxLayout()
        gbox_layout.addWidget(btn_potential)
        gbox_layout.addWidget(btn_kriging)
        gbox_method_selection.setLayout(gbox_layout)
        # Select the page and radio button matching the initial method
        if method == PotentialMethodParametersWidget.METHOD:
            btn_potential.setChecked(True)
            self.set_potential_method()
        else:
            assert method == ElevationKrigingParametersWidget.METHOD
            btn_kriging.setChecked(True)
            self.set_elevation_kriging()
        # Dialog buttons
        btn_clear = QPushButton(self.tr("Clear"))
        btn_clear.clicked.connect(self.clear)
        btn_cancel = QPushButton(self.tr("Cancel"))
        btn_cancel.clicked.connect(self.reject)
        btn_ok = QPushButton(self.tr("OK"))
        btn_ok.clicked.connect(self.accept)
        layout_buttons = QHBoxLayout()
        layout_buttons.addWidget(btn_clear)
        layout_buttons.addStretch(1)
        layout_buttons.addWidget(btn_cancel)
        layout_buttons.addWidget(btn_ok)
        # Signals
        self.accepted.connect(self.on_accept)
        self.finished.connect(self.deleteLater)
        # Layout
        main_layout = QVBoxLayout()
        # The 2nd positional argument of addWidget is the stretch factor, so the
        # alignment must be passed as the 3rd one
        main_layout.addWidget(lbl_items, 0, Qt.AlignmentFlag.AlignLeft)
        main_layout.addWidget(gbox_method_selection)
        main_layout.addWidget(self.stack_parameters)
        main_layout.addLayout(layout_buttons)
        self.setLayout(main_layout)
        self.setMinimumWidth(300)
        self.resize(300, 0)

    def set_interpolator_working_copy(self):
        """Replace `self.interpolator` by a deep copy, so edits never leak out."""
        assert self.interpolator is not None
        self.interpolator = deep_copy(self.interpolator)

    def set_potential_method(self):
        """Show the potential method page and record the choice on the working copy."""
        method = InterpolationMethod.POTENTIAL
        self.stack_parameters.setCurrentIndex(self.MAP_METHOD_TO_STACKED_WIDGET[method])
        # Keeps the working copy in sync with the selection; `update_interpolator`
        # also sets the method when the dialog is accepted
        self.interpolator.method = method

    def set_elevation_kriging(self):
        """Show the elevation kriging page and record the choice on the working copy."""
        method = InterpolationMethod.ELEVATION_KRIGING
        self.stack_parameters.setCurrentIndex(self.MAP_METHOD_TO_STACKED_WIDGET[method])
        # Keeps the working copy in sync with the selection; `update_interpolator`
        # also sets the method when the dialog is accepted
        self.interpolator.method = method

    def clear(self):
        """Reset the parameters of every method page to their defaults."""
        for i in range(self.stack_parameters.count()):
            page = self.stack_parameters.widget(i)
            if page is not None:
                page.clear()

    def on_accept(self):
        """Write the parameters of the currently displayed page into the working copy."""
        self.stack_parameters.currentWidget().update_interpolator(self.interpolator)


def _get_value(line_edit, type_):
    return type_(v) if (v := line_edit.text()) else None


def _to_str(v):
    if v is None or isnan(v) or isinf(v):
        return ""
    return str(v)
