from dataclasses import dataclass
from enum import StrEnum
from math import nan
from pathlib import Path

import numpy as np
import pandas as pd
from forgeo.core import (
    DipDirMeasurements,
    ItemData,
    Locations,
    Orientations,
    RasterDescription,
)
from qgis.core import (
    QgsCoordinateReferenceSystem,
    QgsCoordinateTransform,
    QgsGeometry,
    QgsProject,
    QgsRasterLayer,
)
from qgis.PyQt.QtCore import QSize, Qt, QVariant
from qgis.PyQt.QtGui import QIcon, QPainter, QPixmap
from qgis.PyQt.QtWidgets import (
    QApplication,
    QDialog,
    QDialogButtonBox,
    QFileDialog,
    QLabel,
    QStyle,
    QVBoxLayout,
    QWidget,
)
from qgis.utils import iface

import forgeo.io.xml as fxml

from .settings import PluginSettings

settings = PluginSettings()

# Folder searched for the plugin's icons and images
ROOT = Path(__file__).parent.joinpath("assets")
# Main plugin icon: first "icon.<ext>" file found anywhere below ROOT, or ""
# (an empty string, not a Path) when there is none
ICON = next(iter(ROOT.glob("**/icon.*")), "")


def qicon(name: str | None = None) -> QIcon:
    """Return a QIcon from its name.

    Parameters
    ----------
    name: str | None
        - None: the main plugin icon (ICON);
        - "SP_...": the Qt standard icon `QStyle.SP_...` of the application style;
        - otherwise: the first image file named `name` (any extension) found
          under the main icon folder, then anywhere under ROOT.
    Returns
    -------
    QIcon
        A null QIcon when nothing matches `name`.
    """
    if name is None:
        return QIcon(str(ICON))
    if name.startswith("SP_"):
        # Compare to None explicitly: some QStyle.StandardPixmap values are 0
        # (e.g. SP_TitleBarMenuButton), which is falsy
        enum = getattr(QStyle, name, None)
        if enum is not None:
            return QApplication.style().standardIcon(enum)
        return QIcon()
    pattern = f"**/{name}.*"
    # ICON is "" (not a Path) when no icon file was found under ROOT
    icon_dir = ICON.parent if isinstance(ICON, Path) else ROOT
    file = next(icon_dir.glob(pattern), "") or next(ROOT.glob(pattern), "")
    if file:
        return QIcon(str(file))
    return QIcon()


# Names of the plugin sub-directories created in the QGIS project directory
FORGEO_DATA_DIR = "forgeo_data"
FORGEO_OUTPUT_DIR = "forgeo_results"

# Default names for piles, models and fault networks
DEFAULT_PILE_NAME = "Pile"
DEFAULT_MODEL_NAME = "Model"
DEFAULT_FAULTNET_NAME = "Fault network"


class ReservedFieldNames(StrEnum):
    UNIT = "Unit"  # FilterOn
    DIP = "Dip"
    DIP_DIRECTION = "DipDirection"
    REVERSE_POLARITY = "ReversePolarity"
    DIP_ONLY = "DipOnly"

    @classmethod
    def qgis_type_map(cls):
        return {
            cls.UNIT: QVariant.String,
            cls.DIP: QVariant.Double,
            cls.DIP_DIRECTION: QVariant.Double,
            cls.REVERSE_POLARITY: QVariant.Bool,
            cls.DIP_ONLY: QVariant.Bool,
        }


def qpixmap(name: str | None = None) -> QPixmap:
    """Return a QPixmap from an image name.

    Parameters
    ----------
    name: str | None
        None returns the main plugin icon (ICON). Otherwise, the first image
        file named `name` (any extension) is loaded, searched first under the
        main icon folder, then anywhere under the plugin folder (ROOT.parent).
    Returns
    -------
    QPixmap
        Built from "" (i.e. a null pixmap) when no file matches `name`.
    """
    if name is None:
        return QPixmap(str(ICON))
    pattern = f"**/{name}.*"
    # ICON is "" (not a Path) when no icon file was found under ROOT
    icon_dir = ICON.parent if isinstance(ICON, Path) else ROOT
    file = next(icon_dir.glob(pattern), "") or next(ROOT.parent.glob(pattern), "")
    return QPixmap(str(file))


def get_forgeo_data_dir():
    """Return the plugin data directory, creating it if needed.

    It is the FORGEO_DATA_DIR sub-directory of the current QGIS project
    directory.

    Returns
    -------
    pathlib.Path | None
        None when the project directory is unknown (e.g. unsaved project) or
        does not exist.
    """
    project_dir = QgsProject.instance().absolutePath()
    # absolutePath() is "" for an unsaved project, and Path("") is the current
    # working directory, which does exist: without this check the folder would
    # be created there, in a location that may not even be writable
    if not project_dir:
        return None
    path = Path(project_dir)
    if not path.exists():
        return None
    path = path / FORGEO_DATA_DIR
    # May still raise OSError (e.g. PermissionError) on a read-only project dir
    path.mkdir(exist_ok=True)
    return path


def get_forgeo_output_dir():
    """Return the plugin output directory, creating it if needed.

    It is the FORGEO_OUTPUT_DIR sub-directory of the current QGIS project
    directory.

    Returns
    -------
    pathlib.Path | None
        None when the project directory is unknown (e.g. unsaved project) or
        does not exist.
    """
    project_dir = QgsProject.instance().absolutePath()
    # absolutePath() is "" for an unsaved project, and Path("") is the current
    # working directory, which does exist: without this check the folder would
    # be created there, in a location that may not even be writable
    if not project_dir:
        return None
    path = Path(project_dir)
    if not path.exists():
        return None
    path = path / FORGEO_OUTPUT_DIR
    # May still raise OSError (e.g. PermissionError) on a read-only project dir
    path.mkdir(exist_ok=True)
    return path


def input_data_type(item):
    """Tell which kind of input data is attached to `item`.

    Returns "vectorData" when `item.item_data` holds observation or orientation
    data, "rasterData" otherwise, and None when `item.item_data` is None.
    """
    # FIXME Use flags : 0001 ! 0100 | ...
    data = item.item_data
    if data is not None:
        has_vector_data = data.has_observation_data() or data.has_orientation_data()
        return "vectorData" if has_vector_data else "rasterData"
    return None


def data_icon_widget(data_type, nb_observations, nb_orientations):
    """Build a widget showing the `data_type` image above the two data counts.

    Parameters
    ----------
    data_type: str
        Image name looked up with `qpixmap` (scaled to 50x25)
    nb_observations, nb_orientations
        Counts displayed side by side below the image
    Returns
    -------
    QWidget | None
        None when the two counts sum to zero
    """
    if nb_observations + nb_orientations == 0:
        return None
    counts = QLabel(
        f"{nb_observations}    {nb_orientations}",
        alignment=Qt.AlignmentFlag.AlignCenter,
    )
    image = QLabel()
    image.setPixmap(qpixmap(data_type).scaled(50, 25))
    vbox = QVBoxLayout()
    for label in (image, counts):  # Image on top, counts below
        vbox.addWidget(label)
    container = QWidget()
    container.setLayout(vbox)
    return container


def clearlayout(layout):
    """Empty `layout` (does nothing when it is None).

    Items are taken out one at a time; their widgets are scheduled for
    deletion with `deleteLater()` and nested layouts are emptied recursively.
    """
    # FIXME Why not simply calling deleteLater() and replacing by a new layout?
    if layout is None:
        return
    while layout.count():
        taken = layout.takeAt(0)
        if widget := taken.widget():
            widget.deleteLater()
        elif nested := taken.layout():
            clearlayout(nested)


def cleargridlayoutcolumn(layout, column_index):
    """Remove and delete every item of the grid `layout` whose column span
    covers `column_index` (items spanning several columns included).

    Parameters
    ----------
    layout: QGridLayout
    column_index: int
    """
    # FIXME Why not simply calling deleteLater() and replacing by a new layout?
    # A single pass is enough: each item has exactly one (starting) position and
    # grid positions do not change when other items are removed, so scanning all
    # items once per row (as before) collected exactly the same items
    items_to_remove = []
    for i in range(layout.count()):
        _, column, _, column_span = layout.getItemPosition(i)
        if column <= column_index < column + column_span:
            items_to_remove.append(layout.itemAt(i))
    # Remove all collected items
    delete_items(layout, items_to_remove)


def cleargridlayoutcell(layout, row_index, column_index):
    """Remove and delete the items of the grid `layout` whose position starts
    at cell (`row_index`, `column_index`).

    Items merely spanning over that cell from another start position are left
    untouched.
    """
    # FIXME Why not simply calling deleteLater() and replacing by a new layout?
    in_cell = []
    for idx in range(layout.count()):
        row, column = layout.getItemPosition(idx)[:2]
        if row == row_index and column == column_index:
            in_cell.append(layout.itemAt(idx))
    delete_items(layout, in_cell)


def delete_items(layout, items_to_remove):
    """Remove `items_to_remove` from `layout` and dispose of them.

    - widget items: the widget is removed from `layout` and scheduled for
      deletion with `deleteLater()`;
    - layout items: the nested layout is emptied with `clearlayout`, then
      removed from `layout`;
    - any other item (e.g. a spacer) is removed from `layout` as well, so that
      no item is left behind in the cleared area.
    """
    for item in items_to_remove:
        # Compare to None rather than relying on the truth value of Qt wrapper
        # objects (a wrapper defining __len__ is falsy when empty)
        widget = item.widget()
        if widget is not None:
            layout.removeWidget(widget)
            widget.deleteLater()
            continue
        sub_layout = item.layout()
        if sub_layout is not None:
            clearlayout(sub_layout)
        layout.removeItem(item)


def popup_save_changes(updated_element_name: str):
    """Pop-up a QDialog proposing to save the current changes, with two buttons
    "Yes" and "No".
    Returns True or False depending on the user's choice
    """
    dlg = QDialog(parent=iface.mainWindow())
    dlg.setWindowTitle(dlg.tr("Save changes"))
    layout = QVBoxLayout()
    dlg.setLayout(layout)
    # Translate the template first, then insert the name: passing an already
    # formatted string to tr() could never match a translation entry
    lbl_description = QLabel(
        dlg.tr("Do you want to save changes in {}").format(updated_element_name)
    )
    buttons = QDialogButtonBox(
        QDialogButtonBox.StandardButton.No | QDialogButtonBox.StandardButton.Yes
    )
    layout.addWidget(lbl_description)
    layout.addWidget(buttons)
    buttons.accepted.connect(dlg.accept)  # Yes
    buttons.rejected.connect(dlg.reject)  # No
    # Note: Call to exec() is deliberate, to block the interface while the user
    # has not made a choice
    accepted = dlg.exec() == QDialog.DialogCode.Accepted
    # The dialog is parented to the main window: without this, each call would
    # leave a hidden dialog alive as long as the main window
    dlg.deleteLater()
    return accepted


def save_as_png(widget, parent, name, width_reduc=0, width_trans=0):
    """Render `widget` into an image and save it as PNG where the user chooses.

    Parameters
    ----------
    widget: QWidget
        The widget to render
    parent: QWidget
        Parent of the file dialog (also used for translations)
    name: str
        Default file name (without extension) proposed in the file dialog
    width_reduc: int
        Number of pixels removed from the widget width for the image width
    width_trans: int
        Horizontal translation applied to the painter before rendering
    """
    img_size = QSize(widget.size().width() - width_reduc, widget.size().height())
    pixmap = QPixmap(img_size)
    # QPixmap(size) holds uninitialized data: fill it so that pixels the widget
    # does not paint (e.g. the strip uncovered by width_trans) are transparent
    pixmap.fill(Qt.GlobalColor.transparent)
    painter = QPainter(pixmap)
    painter.translate(width_trans, 0)
    widget.render(painter)
    painter.end()
    default_name = name + ".png"
    data_dir = get_forgeo_data_dir()
    # No usable project directory (e.g. unsaved project): only propose a name
    default_path = default_name if data_dir is None else str(data_dir / default_name)
    filename = QFileDialog.getSaveFileName(
        parent=parent,
        caption=parent.tr("Save as image"),
        dir=default_path,
        filter=parent.tr("Images (*.png)"),
    )[0]
    if filename:  # Empty when the user cancelled the dialog
        pixmap.save(filename, "png")


def save_as_xml(parent, elem):
    """Ask the user for a file name and dump `elem` into it as XML.

    Parameters
    ----------
    parent: QWidget
        Parent of the file dialog (also used for translations)
    elem
        Object serialized with `forgeo.io.xml.dump`; its `name` attribute gives
        the default file name
    """
    default_name = elem.name + ".xml"
    data_dir = get_forgeo_data_dir()
    # No usable project directory (e.g. unsaved project): only propose a name
    default_path = default_name if data_dir is None else str(data_dir / default_name)
    filename = QFileDialog.getSaveFileName(
        parent=parent,
        caption=parent.tr("Save as XML"),
        dir=default_path,
        filter=parent.tr("XML (*.xml)"),
    )[0]
    if filename:  # Empty when the user cancelled the dialog
        fxml.dump(elem, filename)


def raster_layer_to_description(raster):
    """Convert a QGIS raster layer into a forgeo `RasterDescription`.

    Only band 0 is read. Returns None when `raster` is None.
    """
    if raster is None:
        return None
    assert isinstance(raster, QgsRasterLayer)
    band = 0  # TODO Someday, handle multiband raster (do not impose band = 0)
    # as_numpy returns one 2D array per requested band, i.e. an array of shape
    # (len(bands), ny, nx) (see the assert below): keep the single band asked for
    values = raster.as_numpy(use_masking=True, bands=[band])[0]
    # FIXME use_masking=True is expected to give a masked array (np.ma), but it
    # does not always seem to (maybe only when the raster has no-data values?).
    # So for now, no-data values are ignored (masked elements are not set to NaN)
    assert values.shape == (raster.height(), raster.width())  # ny, nx
    # No axis reordering is applied: rows follow y, columns follow x
    extent = raster.extent()
    lower_corner = (extent.xMinimum(), extent.yMinimum())
    upper_corner = (extent.xMaximum(), extent.yMaximum())
    return RasterDescription.from_bbox(values.shape, lower_corner, upper_corner, values)


def _boolean_column(df, name, rows):
    """Return column `name` of `df`, restricted to the boolean mask `rows`, as
    a numpy boolean array.

    Missing values (NaN / None / pd.NA) are read as False. They appear e.g. when
    `pd.concat` merges DataFrames of which only some have this column, and a
    plain `to_numpy(dtype=bool)` would turn NaN into True. Other values are
    converted with their truth value.
    """
    values = df[name].to_numpy(dtype=object)[rows]
    return np.array([False if pd.isna(v) else bool(v) for v in values], dtype=bool)


def extract_item_data(item_filters, target_crs: QgsCoordinateReferenceSystem):
    """
    Parameters
    ----------
    item_filters: ItemDataSelection
    target_crs: QgsCoordinateReferenceSystem
        CRS in which to return the extracted geometries. Note: each filter in
        `item_filters` has its own input CRS
    Returns
    -------
    data: ItemData | None
        The extracted observation and orientation data, or None when no filter
        yields any valid feature. Features without a complete (dip, dip
        direction) pair are observations; the others are orientations, split
        between "orientations" and "orientations_only" according to the
        DipOnly flag. An empty orientation subset is not passed to ItemData.
    Raises
    ------
    AssertionError
        If orientation data lack the dip or the dip direction column
    """
    RFN = ReservedFieldNames
    df, geometry = apply_item_filters(item_filters, target_crs)
    if df is None:
        return None
    fields = list(df.columns)
    assert len(fields) > 0
    if len(fields) == 1:  # Only observations data (the unit column)
        return ItemData.from_observations(geometry)
    if RFN.DIP not in fields or RFN.DIP_DIRECTION not in fields:
        msg = "At least two fields required for dip data (dip and dir)"
        raise AssertionError(msg)

    orientations = df[[RFN.DIP, RFN.DIP_DIRECTION]].to_numpy(dtype=np.float64)
    # Rows without a complete (dip, dip direction) pair are mere observations
    is_orientations = ~np.any(np.isnan(orientations), axis=1)
    item_data = {  # Initialize with observation data, then update with orientations
        "observations": (
            None if np.all(is_orientations) else Locations(geometry[~is_orientations])
        )
    }

    # Process dip vs dip_only data
    locations = geometry[is_orientations]
    orientations = orientations[is_orientations]
    polarities = (
        _boolean_column(df, RFN.REVERSE_POLARITY, is_orientations)
        if RFN.REVERSE_POLARITY in fields
        else None
    )
    dip_only = (
        _boolean_column(df, RFN.DIP_ONLY, is_orientations)
        if RFN.DIP_ONLY in fields
        else np.zeros(len(locations), dtype=bool)
    )
    # An empty subset is skipped: its keyword is then not passed to ItemData
    for key, selection in (
        ("orientations", ~dip_only),
        ("orientations_only", dip_only),
    ):
        if np.any(selection):
            item_data[key] = _to_orientations(
                locations[selection],
                orientations[selection],
                None if polarities is None else polarities[selection],
            )
    return ItemData(**item_data)


def apply_item_filters(item_filters, target_crs: QgsCoordinateReferenceSystem):
    """Apply every filter of `item_filters` and stack their results.

    Parameters
    ----------
    item_filters: ItemDataSelection
    target_crs: QgsCoordinateReferenceSystem
        CRS in which to return the extracted geometries. Note: each filter in
        `item_filters` has its own input CRS
    Returns
    -------
    df: pandas.DataFrame | None
        All the features and attributes of the input sources (that have a valid
        geometry); None when no filter yields any feature
    geometry: numpy.ndarray | None
        Shape (len(df), 3). The (x,y,z) coordinates of each feature in df;
        None when no filter yields any feature
    """
    frames, coords = [], []
    for data_filter in item_filters:
        frame, xyz = apply_one_item_filter(data_filter, target_crs)
        if frame is None:
            assert xyz is None  # Safeguard: both are None together
            continue
        assert xyz is not None
        assert len(frame) == len(xyz)
        frames.append(frame)
        coords.append(xyz)
    if not frames:
        return None, None
    if len(frames) == 1:
        df, geometry = frames[0], coords[0]
    else:
        df, geometry = pd.concat(frames), np.concatenate(coords, axis=0)
    assert len(df) == len(geometry)
    return df, geometry


def _first_vertex(feature, transform):
    """Return the (x, y, z) coordinates of the first vertex of `feature`'s
    geometry, reprojected with `transform` when it is not None.
    Returns (nan, nan, nan) when the feature has no geometry or an empty one.
    """
    if not feature.hasGeometry():
        return 3 * (np.nan,)
    geometry = feature.geometry()
    if geometry.isEmpty():
        return 3 * (np.nan,)
    if transform is not None:  # Apply CRS transform
        geometry = QgsGeometry(geometry)  # Do not modify the original
        geometry.transform(transform, transformZ=False)  # TODO Unsure for Z
    v = next(geometry.vertices())
    # FIXME Forces to have a Z coord (presumably NaN for 2D geometries, which
    # are then discarded as invalid geometries by the caller)
    return (v.x(), v.y(), v.z())


def apply_one_item_filter(data_filter, target_crs: QgsCoordinateReferenceSystem):
    """
    Parameters
    ----------
    data_filter: QgisVectorDataFilter
    target_crs: QgsCoordinateReferenceSystem
        CRS in which to return the extracted geometries. Note: the "input" CRS is
        given by `data_filter`
    Returns
    -------
    df: pandas.DataFrame | None
        All the features and attributes of the input sources (that have a valid
        geometry), with columns renamed to ReservedFieldNames
    geometry: numpy.ndarray | None
        Shape (len(df), 3). The (x,y,z) coordinates of each feature in df
    Both are None when the layer is not in the project, or when no feature
    with a valid geometry matches the filter expression.
    """
    no_data = (None, None)
    # Columns of interest (unit, dip, dipdir, polarity, only)
    # Note: we keep unit even though we do not use it, to ensure we do
    # not produce an empty DataFrame for observations-only layers
    if data_filter.dip_fields is not None:
        columns = [data_filter.value, *data_filter.dip_fields.values()]
    else:
        columns = [data_filter.value] + [""] * 4
    layer = QgsProject.instance().mapLayer(data_filter.layer_id)
    if layer is None:
        return no_data  # Layer not (or no longer) in the project

    # Define the CRS transform
    transform = None
    if (source_crs := layer.crs()).authid() != target_crs.authid():
        transform = QgsCoordinateTransform(
            source_crs, target_crs, QgsProject.instance()
        )

    # Single pass over the matching features (the expression is evaluated once):
    # collect geometries and attribute values together, in the same order
    xyz, attributes = [], []
    for feature in layer.getFeatures(data_filter.expression):
        xyz.append(_first_vertex(feature, transform))
        attributes.append(feature.attributes())
    if len(xyz) == 0:
        return no_data  # No data matching filter in the input layer
    xyz = np.array(xyz, np.float64)
    valid_geometry = ~np.any(np.isnan(xyz), axis=1)
    xyz = xyz[valid_geometry]  # Filter invalid geometries
    if len(xyz) == 0:
        return no_data  # No valid geometry, cannot exploit the layer

    # Extract selected features

    RFN = ReservedFieldNames
    # Map layer field names -> reserved names (used to rename the columns)
    fields = {data_filter.value: RFN.UNIT}
    if data_filter.dip_fields is not None:
        for k, v in data_filter.dip_fields.items():
            if v:  # Skip unassigned field uses
                fields[v] = k
    df = pd.DataFrame(attributes, columns=layer.fields().names())
    if df.empty:
        return no_data  # No features matching name
    # Keep only columns of interest with valid geometries
    df = df[list(fields)]
    df = df[valid_geometry]
    assert len(df) == len(xyz)
    # Rename columns (to have the same names across all individual df)
    df = df.rename(columns=fields)
    # Check "observations only" case: no need to go further
    if len(columns) == 1:
        return df, xyz
    assert len(columns) >= 3  # Unit + dip + dip-dir

    # Process the different "unusable" values (NULL attributes)
    float_fields = [RFN.DIP, RFN.DIP_DIRECTION]
    boolean_fields = [
        name for name in (RFN.DIP_ONLY, RFN.REVERSE_POLARITY) if name in df.columns
    ]
    if len(fields) > 1:  # At least one dip-related field is assigned
        df[float_fields] = df[float_fields].map(
            _convert_null_qvariant_to_nan, na_action="ignore"
        )
        df[boolean_fields] = df[boolean_fields].map(
            _convert_null_qvariant_to_false, na_action="ignore"
        )
    return df, xyz


def get_data_layers(model_filters):
    """Return the map layers referenced by all the data filters of a model.

    Parameters
    ----------
    model_filters: Iterable[ItemDataSelection] | None
        Presumably one ItemDataSelection per model item (TODO confirm)
    Returns
    -------
    list
        One entry per data filter, in iteration order: the result of
        `QgsProject.mapLayer` for its `layer_id` (None when no such layer is in
        the project). Empty when `model_filters` is None.
    """
    if model_filters is None:
        return []
    # Use the static accessor directly: `QgsProject()` would first construct a
    # brand new, throw-away project just to call `instance()` on it
    project = QgsProject.instance()
    return [
        project.mapLayer(data_filter.layer_id)
        for item_filters in model_filters
        for data_filter in item_filters
    ]


# Signals of the data layer that notify the model layer of a change
_LAYER_UPDATE_SIGNALS = ("editingStopped", "willBeDeleted", "crsChanged", "crs3DChanged")
# Slots connected by `connect_layer_update`, keyed by
# (id(data_layer), id(model_layer)). The exact callable must be kept to be able
# to disconnect it later (a new closure never matches the connected one). The
# stored closure references both objects, so their ids cannot be reused while
# the entry exists.
_LAYER_UPDATE_SLOTS = {}


def connect_layer_update(data_layer, model_layer, connect=True):
    """Connect (or disconnect) the change signals of `data_layer` to
    `model_layer.layer_changed(data_layer)`.

    Parameters
    ----------
    data_layer
        Layer whose editingStopped, willBeDeleted, crsChanged and crs3DChanged
        signals are watched
    model_layer
        Object whose `layer_changed(data_layer)` method is called on change
    connect: bool
        True to connect (a no-op if this pair is already connected), False to
        disconnect (a no-op if this pair is not connected)
    """
    key = (id(data_layer), id(model_layer))
    if connect:
        if key in _LAYER_UPDATE_SLOTS:
            return  # Already connected: avoid duplicated notifications

        def _layer_changed():
            return model_layer.layer_changed(data_layer)

        for signal_name in _LAYER_UPDATE_SIGNALS:
            getattr(data_layer, signal_name).connect(_layer_changed)
        _LAYER_UPDATE_SLOTS[key] = _layer_changed
        return

    slot = _LAYER_UPDATE_SLOTS.pop(key, None)
    if slot is None:
        return  # Nothing connected for this pair
    for signal_name in _LAYER_UPDATE_SIGNALS:
        # Try each signal independently, so one failure does not prevent the
        # others from being disconnected
        try:
            getattr(data_layer, signal_name).disconnect(slot)
        except (TypeError, RuntimeError):
            # TypeError: not connected; RuntimeError: the underlying C++ layer
            # has already been deleted (its connections are gone anyway)
            pass


@dataclass
class QgisVectorDataFilter:
    """Stores parameters to run the `native:extractbyexpression` processing,
    that is, the result of filtering one vector layer using a QgsExpression
    """

    layer_id: str | None = None  # TODO store QgsVectorLayer (but serialize only its id)
    value: str | None = None  # Name of the layer field used as RFN.UNIT
    expression: str | None = None  # The complete QgsExpression
    # Map between field uses (ReservedFieldNames) and field names in the layer;
    # a falsy field name means the use is not assigned
    dip_fields: dict[ReservedFieldNames, str] | None = None


class ItemDataSelection:
    """Container of all the QgisVectorDataFilter associated to a given item.

    It is truthy when it holds at least one filter, and iterating over it
    yields the filters.
    """

    def __init__(self, filters):
        """filters: Iterable[QgisVectorDataFilter] | ItemDataSelection | None"""
        if isinstance(filters, type(self)):
            filters = filters.filters  # Copy: take the other selection's filters
        selected = list(filters) if filters else []
        assert all(isinstance(f, QgisVectorDataFilter) for f in selected)
        self.filters = selected

    # Note: unsure which dunder methods should be implemented...
    # For now, __getitem__ and __setitem__ are not useful
    def __bool__(self):
        return len(self.filters) > 0

    def __iter__(self):
        return iter(self.filters)

    def __len__(self):
        return len(self.filters)

    def extract_selection(self, target_crs):
        """Extract the ItemData selected by the filters, in `target_crs`."""
        return extract_item_data(self.filters, target_crs)


def _to_orientations(locs, ori, pol):
    """Bundle locations `locs`, (dip, dip direction) pairs `ori` and optional
    reverse-polarity flags `pol` (None when absent) into an Orientations.
    """
    locations = Locations(locs)
    measurements = DipDirMeasurements(ori, pol)
    return Orientations(locations, measurements)


def _convert_null_qvariant_to_nan(v):
    """Replace a QVariant value (e.g. a QGIS NULL attribute) by NaN; any other
    value is returned unchanged.
    """
    if isinstance(v, QVariant):
        return nan
    return v


def _convert_null_qvariant_to_false(v):
    """Replace a QVariant value (e.g. a QGIS NULL attribute) by False; any
    other value is returned unchanged.
    """
    if isinstance(v, QVariant):
        return False
    return v
