from qgis.core import (
    Qgis,
    QgsExpression,
    QgsFeature,
    QgsField,
    QgsGeometry,
    QgsMarkerSymbol,
    QgsPoint,
    QgsProject,
    QgsSingleSymbolRenderer,
    QgsVectorLayer,
    QgsWkbTypes,
)
from qgis.gui import QgsCheckableComboBox
from qgis.PyQt.QtCore import Qt, QVariant
from qgis.PyQt.QtWidgets import (
    QComboBox,
    QGridLayout,
    QHBoxLayout,
    QLabel,
    QMessageBox,
    QPushButton,
    QSizePolicy,
    QVBoxLayout,
    QWidget,
)
from qgis.utils import iface

from ..utils import (
    ItemDataSelection,
    QgisVectorDataFilter,
    apply_item_filters,
    clearlayout,
    input_data_type,
)
from ..utils import ReservedFieldNames as RFN


class PointsFilterWidget(QWidget):
    """Widget selecting point data (layers, fields and values) for an item.

    The user checks point vector layers of the current project, picks for each
    of them the fields holding the unit (or contact) name, the dip, the dip
    direction, the reverse-polarity flag and the orientation-only flag, then
    checks which unit values to keep. ``process_selection`` returns the result
    as an ``ItemDataSelection``.
    """

    def __init__(self, item, item_filter, color, parent=None):
        """
        Parameters
        ----------
        item: forgeo.Item
            Read-only: only used for accessing "contextual" information
        item_filter: ItemDataSelection | None
            Previous selection used to initialize the widget, if any.
        color
            Marker color given to layers created by "Export selection".
        parent: QWidget | None
        """
        super().__init__(parent)
        self.color = color
        # Selection state. Created before any combobox is filled or connected,
        # because the slots below read these attributes.
        self.checked_layernames = []
        self.fields_of_layer = {}  # layer -> {RFN key: field name}
        self.checked_values = []

        # Layers selection
        lbl_layers = QLabel(
            "Select layers :", alignment=Qt.AlignmentFlag.AlignLeft, height=15
        )
        self.layers_cbox = QgsCheckableComboBox()
        self.fill_layers_cbox()
        # Connected only once filled: adding items with a check state may emit
        # checkedItemsChanged, which must not reach update_layers_selection
        # before the rest of the widget exists.
        self.layers_cbox.checkedItemsChanged.connect(self.update_layers_selection)

        # Fieldsnames
        self.fields_layout = QGridLayout()
        # Title of the first field column, kept to rebuild the header in clear()
        self._filter_on_label = "Contact" if item.is_surface else "Unit"
        self._add_header_row()
        field_text = QLabel("Fields names of each selected layer :", height=15)

        # Values to use
        values_text = QLabel("Values to use :", height=15)
        self.values_cbox = QgsCheckableComboBox()
        self.fill_values_cbox()
        self.values_cbox.checkedItemsChanged.connect(self.update_values_selection)

        # Reload data
        if input_data_type(item) not in ["noData", "rasterData"]:
            self.load_data(item_filter)

        # Export selected data as individual layer
        btn_export_selection = QPushButton(self.tr("Export selection"))
        btn_export_selection.setSizePolicy(
            QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed
        )

        def extract_selected_data_as_a_new_layer():
            name = f"{item.name} - Selected data"
            layer = self._extract_selection_as_a_new_layer(name)
            if layer is not None:
                QgsProject.instance().addMapLayer(layer)

        btn_export_selection.clicked.connect(extract_selected_data_as_a_new_layer)
        btn_export_selection.setToolTip(
            self.tr("Export the currently selected data as an individual vector layer")
        )

        # Layouts
        layers_selection_layout = QHBoxLayout()
        layers_selection_layout.addWidget(lbl_layers, stretch=1)
        layers_selection_layout.addWidget(self.layers_cbox, stretch=4)
        layout = QVBoxLayout()
        layout.addLayout(layers_selection_layout)
        layout.addWidget(field_text)
        layout.addLayout(self.fields_layout)
        layout.addWidget(values_text)
        layout.addWidget(self.values_cbox)
        # Alignment must be given by keyword: the second positional argument
        # of QBoxLayout.addWidget is the stretch factor, not the alignment.
        layout.addWidget(btn_export_selection, alignment=Qt.AlignmentFlag.AlignRight)
        self.setLayout(layout)

    def _add_header_row(self):
        """Add the column titles in row 0 (columns 1-5) of the fields grid."""
        headers = (
            self._filter_on_label,
            "Dip",
            "Dip direction",
            "Reverse polarity",
            "Orientation only",
        )
        for column, text in enumerate(headers, start=1):
            self.fields_layout.addWidget(
                QLabel(text, alignment=Qt.AlignmentFlag.AlignCenter), 0, column
            )

    def _add_layer_row(self, layername, cboxes):
        """Append a grid row: the layer name, then its field comboboxes."""
        nrow = self.fields_layout.rowCount()
        label = QLabel(layername, alignment=Qt.AlignmentFlag.AlignLeft, height=15)
        self.fields_layout.addWidget(label, nrow, 0)
        for column, cbox in enumerate(cboxes, start=1):
            self.fields_layout.addWidget(cbox, nrow, column)

    def _remove_layer_row(self, layername):
        """Remove and delete every grid widget of the row(s) labelled ``layername``."""
        for row in range(self.fields_layout.rowCount()):
            first_cell = self.fields_layout.itemAtPosition(row, 0)
            if first_cell is None or first_cell.widget().text() != layername:
                continue
            for column in range(self.fields_layout.columnCount()):
                cell = self.fields_layout.itemAtPosition(row, column)
                if cell is None:
                    continue
                widget = cell.widget()
                self.fields_layout.removeWidget(widget)
                widget.deleteLater()

    def load_data(self, item_filter):
        """Restore layers, field mappings and checked values from a selection.

        Parameters
        ----------
        item_filter: ItemDataSelection | None
            Iterable of ``QgisVectorDataFilter``. Nothing is restored when it
            is empty or None. Layers missing from the project are reported in
            the message bar and skipped.
        """
        layernames = []
        self.checked_values = []
        if not item_filter:
            return
        last_filter = None
        for datafilter in item_filter:
            last_filter = datafilter
            layer = QgsProject.instance().mapLayer(datafilter.layer_id)
            if layer is None:
                iface.messageBar().pushMessage(
                    "Warning",
                    f"Layer {datafilter.layer_id} not found",
                    level=Qgis.Warning,
                )
                continue
            layernames.append(layer.name())
            # Create widgets
            cboxes = self.fields_cboxes(layer)
            unit_cbox, dip_cbox, dipdir_cbox, revpolarity_cbox, diponly_cbox = cboxes
            self._add_layer_row(layer.name(), cboxes)
            # Get info and fill comboboxes
            fields = {RFN.UNIT: datafilter.value}
            unit_cbox.setCurrentText(fields[RFN.UNIT])
            if datafilter.dip_fields is not None:
                fields[RFN.DIP] = datafilter.dip_fields[RFN.DIP]
                fields[RFN.DIP_DIRECTION] = datafilter.dip_fields[RFN.DIP_DIRECTION]
                fields[RFN.REVERSE_POLARITY] = datafilter.dip_fields.get(
                    RFN.REVERSE_POLARITY, ""
                )  # TODO Field is optional as compared to dip / dip_direction
                fields[RFN.DIP_ONLY] = datafilter.dip_fields.get(RFN.DIP_ONLY, "")
                dip_cbox.setCurrentText(fields[RFN.DIP])
                dipdir_cbox.setCurrentText(fields[RFN.DIP_DIRECTION])
                revpolarity_cbox.setCurrentText(fields[RFN.REVERSE_POLARITY])
                diponly_cbox.setCurrentText(fields[RFN.DIP_ONLY])
            self.fields_of_layer[layer] = fields
        self.checked_layernames = layernames
        # Restore check states with the combobox signals blocked. Otherwise
        # each check runs update_layers_selection() on a partially checked
        # combobox: it then calls remove_layer(), which drops the rows just
        # restored above (and they come back with default fields).
        was_blocked = self.layers_cbox.blockSignals(True)
        try:
            for layername in self.checked_layernames:
                matches = self.layers_cbox.model().findItems(layername)
                if matches:  # Not listed, e.g. if it is not a point layer
                    matches[0].setCheckState(Qt.CheckState.Checked)
        finally:
            self.layers_cbox.blockSignals(was_blocked)
        if last_filter is not None:
            # All filters of a selection are built from the same value list
            self.checked_values = get_values_from_expression(last_filter.expression)
        self.fill_values_cbox()
        for value in self.checked_values:
            idx = self.values_cbox.findText(value)
            if idx != -1:
                self.values_cbox.setItemCheckState(idx, Qt.CheckState.Checked)

    def fill_layers_cbox(self):
        """List the project's point vector layers, all unchecked.

        ``self.row_of_layer`` maps each listed layer name to its combobox row.
        """
        self.row_of_layer = {}
        for layer in QgsProject.instance().mapLayers().values():
            if (
                isinstance(layer, QgsVectorLayer)
                and layer.geometryType() == QgsWkbTypes.PointGeometry
            ):
                self.layers_cbox.addItemWithCheckState(
                    layer.name(), Qt.CheckState.Unchecked, userData=layer
                )
                self.row_of_layer[layer.name()] = self.layers_cbox.count() - 1

    def update_layers_selection(self):
        """Sync the field rows with the checked layers, then refresh the values."""
        more_checked_layers = len(self.checked_layernames) < len(
            self.layers_cbox.checkedItems()
        )
        if more_checked_layers:
            self.add_new_layer()
        else:
            self.remove_layer()
        self.fill_values_cbox()

    def add_new_layer(self):
        """Add a field row and a default field mapping for each newly checked layer."""
        for layername in self.layers_cbox.checkedItems():
            if layername in self.checked_layernames:
                continue  # Process only the newly checked layer(s)
            self.checked_layernames.append(layername)
            layer = self.layers_cbox.itemData(self.row_of_layer[layername])
            cboxes = self.fields_cboxes(layer)
            unit_cbox, dip_cbox, dipdir_cbox, revpolarity_cbox, diponly_cbox = cboxes
            self._add_layer_row(layername, cboxes)
            # Initial mapping: whatever each combobox currently shows
            self.fields_of_layer[layer] = {
                RFN.UNIT: unit_cbox.currentText(),
                RFN.DIP: dip_cbox.currentText(),
                RFN.DIP_DIRECTION: dipdir_cbox.currentText(),
                RFN.REVERSE_POLARITY: revpolarity_cbox.currentText(),
                RFN.DIP_ONLY: diponly_cbox.currentText(),
            }

    def remove_layer(self):
        """Drop the field rows and mappings of layers that are no longer checked."""
        # Iterate over a copy: names are removed from the list in the loop, and
        # mutating the iterated list would skip the element after each removal.
        for layername in list(self.checked_layernames):
            cbox_row = self.row_of_layer[layername]
            if self.layers_cbox.itemCheckState(cbox_row) != Qt.CheckState.Unchecked:
                continue
            self.checked_layernames.remove(layername)
            layer = self.layers_cbox.itemData(cbox_row)
            self.fields_of_layer.pop(layer, None)
            self._remove_layer_row(layername)

    def fields_cboxes(self, layer):
        """Create the five field-name comboboxes for ``layer``.

        Returns
        -------
        tuple of QComboBox
            (unit, dip, dip direction, reverse polarity, orientation only).
            The unit combobox lists the string fields, with no empty entry.
            Dip and dip direction list the numeric fields, and the two flags
            list the boolean fields. These four start with an empty entry.
            Picking an entry updates ``self.fields_of_layer[layer]``, and
            picking a unit field also refreshes the values combobox.
        """

        def create_combobox(fn, name, fields, add_empty_item=True):
            cbox = QComboBox()
            slot = lambda: fn(cbox, name)  # noqa: E731
            cbox.activated.connect(slot)
            if add_empty_item:
                cbox.addItem("")
            for f in fields:
                cbox.addItem(f.name())
            return cbox

        def update_other_cbox(cbox, name):
            self.fields_of_layer[layer][name] = cbox.currentText()

        def update_unit_cbox(cbox, name):
            update_other_cbox(cbox, name)
            self.fill_values_cbox()

        # Compute once lists of text / numeric / boolean fields
        str_fields = [f for f in layer.fields() if f.type() == QVariant.String]
        # FIXME What about long, float, and potentially others?
        num_types = (QVariant.Int, QVariant.Double, QVariant.LongLong)
        num_fields = [f for f in layer.fields() if f.type() in num_types]
        bool_fields = [f for f in layer.fields() if f.type() == QVariant.Bool]

        # Unit fieldname combobox
        unit_cbox = create_combobox(update_unit_cbox, RFN.UNIT, str_fields, False)
        # Dip fieldname combobox
        dip_cbox = create_combobox(update_other_cbox, RFN.DIP, num_fields)
        # Dip direction fieldname combobox
        dipdir_cbox = create_combobox(update_other_cbox, RFN.DIP_DIRECTION, num_fields)
        # Reverse polarity fieldname combobox
        revpolarity_cbox = create_combobox(
            update_other_cbox, RFN.REVERSE_POLARITY, bool_fields
        )
        # Orientation only fieldname combobox (optional)
        diponly_cbox = create_combobox(update_other_cbox, RFN.DIP_ONLY, bool_fields)
        return unit_cbox, dip_cbox, dipdir_cbox, revpolarity_cbox, diponly_cbox

    def fill_values_cbox(self):
        """List the sorted distinct unit values of every layer with a unit field."""
        self.values_cbox.clear()
        values = set()
        for layer, fieldnames in self.fields_of_layer.items():
            if fieldnames[RFN.UNIT] != "":
                for feature in layer.getFeatures():
                    values.add(str(feature[fieldnames[RFN.UNIT]]))
        self.values_cbox.addItems(sorted(values))

    def update_values_selection(self):
        """Mirror the checked items of the values combobox."""
        self.checked_values = self.values_cbox.checkedItems()

    def clear(self):
        """Reset the widget: no layer checked, no field mapping, no value."""
        # Reset the state read by the slots before touching the comboboxes,
        # because their signals trigger update_layers_selection().
        self.checked_layernames = []
        self.fields_of_layer = {}
        self.layers_cbox.deselectAllOptions()
        # Clear fields names
        clearlayout(self.fields_layout)
        self._add_header_row()
        # Clear values: no mapped layer is left, so this empties the list
        self.checked_values = []
        self.fill_values_cbox()
        self.values_cbox.deselectAllOptions()

    def _checked_values_as_literals(self):
        """Return the checked values as comma-separated QGIS string literals."""
        # quotedString escapes quotes and backslashes, so any value is safe
        return ", ".join(
            QgsExpression.quotedString(value)
            for value in self.values_cbox.checkedItems()
        )

    def process_selection(self):
        """Return the current selection as an ``ItemDataSelection``."""
        return self._get_data_selection(self._checked_values_as_literals())

    def _get_data_selection(self, values):
        """Build one ``QgisVectorDataFilter`` per selected layer.

        ``values`` is the already-quoted, comma-separated value list inserted
        into each layer's ``"unit field" in (...)`` expression.
        """

        def _make_filter(layer, fields):
            layer_id = layer.id()
            value = fields[RFN.UNIT]
            expression = get_expression_from_values(value, values)
            dip = None  # Map of dip-related fields (dip, dip-dir, polarity, ...)
            if fields.get(RFN.DIP):
                dip = {key: name for key, name in fields.items() if key != RFN.UNIT}
            return QgisVectorDataFilter(layer_id, value, expression, dip)

        return ItemDataSelection(
            [_make_filter(k, v) for k, v in self.fields_of_layer.items()]
        )

    def _extract_selection_as_a_new_layer(self, new_layer_name):
        """Create a memory point layer holding the currently selected data.

        Shows a warning and returns None when no layer or value is selected,
        or when nothing matched. Otherwise returns the new layer, styled with
        ``self.color``. The layer is NOT added to the project; the caller does
        that.
        """
        values = self._checked_values_as_literals()
        item_filters = self._get_data_selection(values)
        if not item_filters:
            QMessageBox.warning(self, "Invalid selection", "No layer selected")
            return None
        if not values:
            QMessageBox.warning(self, "Invalid selection", "No filter values selected")
            return None
        crs = QgsProject.instance().crs()
        df, geometry = apply_item_filters(item_filters, crs)
        if df is None:
            QMessageBox.warning(
                self,
                "No data extracted",
                "No data with valid geometry matched the input selection",
            )
            return None
        layer = QgsVectorLayer("Point", new_layer_name, "memory")
        layer.setCrs(crs)
        layer.startEditing()
        provider = layer.dataProvider()
        # Fields creation
        provider.addAttributes(
            [QgsField(name, qtype) for name, qtype in RFN.qgis_type_map().items()]
        )
        layer.updateFields()
        # Add features
        features = []
        for idx, row in df.iterrows():
            feature = QgsFeature()
            feature.setGeometry(QgsGeometry.fromPoint(QgsPoint(*geometry[idx])))
            feature.setAttributes(list(row.to_numpy(na_value=None)))
            features.append(feature)
        provider.addFeatures(features)
        layer.commitChanges()
        layer.updateExtents()
        # Set color
        symbol = QgsMarkerSymbol()
        symbol.setColor(self.color)
        layer.setRenderer(QgsSingleSymbolRenderer(symbol))
        layer.triggerRepaint()
        return layer


def get_values_from_expression(expression):
    r"""Return the values listed in an ``"field" in (...)`` expression.

    This is the inverse of ``get_expression_from_values``. The rules follow
    how QGIS reads its expressions:

    * Inside single quotes, ``''`` stands for one quote.
    * Inside single quotes, ``\n`` and ``\t`` stand for newline and tab, and
      any other backslash-escaped character stands for itself.
    * Commas inside quotes are kept.
    * Unquoted items are returned with their whitespace removed.
    * A quoted field name at the start is skipped (``""`` escapes a quote in
      it), so a name such as ``"Domain (x)"`` cannot be mistaken for the
      ``in (`` marker.

    Returns
    -------
    list of str
        The listed values. The list is empty when the marker is missing or
        the value list is empty.
    """
    marker = "in ("
    search_from = 0
    if expression.startswith('"'):
        # Skip the quoted field identifier
        pos = 1
        while pos < len(expression):
            if expression[pos] == '"':
                if expression.startswith('"', pos + 1):
                    pos += 2
                    continue
                break
            pos += 1
        search_from = pos + 1
    start = expression.find(marker, search_from)
    if start == -1:
        return []
    body = expression[start + len(marker) :].rstrip()
    if body.endswith(")"):
        body = body[:-1]

    escapes = {"n": "\n", "t": "\t"}
    values = []
    chars = []
    in_quotes = False
    has_item = False  # Distinguishes "()" (no value) from "('')" (one empty value)
    pos = 0
    while pos < len(body):
        char = body[pos]
        if in_quotes:
            if char == "\\" and pos + 1 < len(body):
                escaped = body[pos + 1]
                chars.append(escapes.get(escaped, escaped))
                pos += 2
                continue
            if char == "'":
                if body.startswith("'", pos + 1):  # '' is an escaped quote
                    chars.append("'")
                    pos += 2
                    continue
                in_quotes = False
            else:
                chars.append(char)
        elif char == "'":
            in_quotes = True
            has_item = True
        elif char == ",":
            values.append("".join(chars))
            chars = []
        elif not char.isspace():
            chars.append(char)
            has_item = True
        pos += 1
    if has_item or values:
        values.append("".join(chars))
    return values


def get_expression_from_values(unit, values):
    """Build the QGIS expression ``"<unit>" in (<values>)``.

    Parameters
    ----------
    unit: str
        Name of the field to filter on. A double quote inside it is doubled,
        which is how QGIS escapes quotes in a quoted column reference.
    values: str
        Already-quoted, comma-separated list of literals, inserted as-is.
    """
    quoted_field = unit.replace('"', '""')
    return f'"{quoted_field}" in ({values})'
