import xml.etree.ElementTree as ET
from dataclasses import dataclass

from forgeo.io.xml import Serializer

from ....utils import ReservedFieldNames as RFN


@dataclass
class QgisLayerPointer:
    """Reference to a QGIS layer, stored as two strings: its source and its id.

    Serialized as the ``source`` and ``id`` attributes of a
    ``<QgisLayerPointer>`` XML element.
    """

    source: str  # layer data source; presumably QgsMapLayer.source() — confirm
    id: str  # layer identifier; presumably QgsMapLayer.id() — confirm


@dataclass
class QgisVectorDataFilter:
    """Stores parameters to run the `native:extractbyexpression` processing,
    that is, the result of filtering one vector layer using a QgsExpression.
    """

    layer: QgisLayerPointer  # the vector layer being filtered
    filter: str = None  # The QgsExpression, if any
    # Map between field uses and field names. Entries whose field name is
    # falsy are not written when serialized to XML. After an XML round trip
    # the keys are plain tag strings.
    # NOTE(review): these only compare equal to RFN members if RFN is a
    # str-based enum — confirm.
    dip_fields: dict[RFN, str] = None


@dataclass
class QgisVectorDataSelection:
    """Stores the parameters of the PointsFilterWidget:
    - a list of QgisVectorDataFilter, each pairing an input layer with an
      optional filter expression
    - and the resulting layer, if any
    """

    inputs: list[QgisVectorDataFilter]
    result: QgisLayerPointer = None  # None when there is no resulting layer


class UnknownXmlTagException(Exception):
    """Raised when a serializer meets an XML child tag it does not handle."""


class QgisLayerPointerSerializer(
    Serializer, target=QgisLayerPointer, tag="QgisLayerPointer"
):
    """(De)serialize a QgisLayerPointer as XML attributes.

    The whole pointer lives in the attributes of its element, i.e.
    ``<QgisLayerPointer source="..." id="..."/>``.
    """

    @classmethod
    def dump_element(cls, layer_pointer, e):
        """Write the pointer's ``source`` then ``id`` as attributes of *e*."""
        e.set("source", layer_pointer.source)
        e.set("id", layer_pointer.id)

    @classmethod
    def load_element(cls, e):
        """Rebuild a QgisLayerPointer from the attributes of *e*.

        Raises KeyError if either the ``source`` or ``id`` attribute is missing.
        """
        attributes = e.attrib
        source = attributes["source"]
        layer_id = attributes["id"]
        return QgisLayerPointer(source, layer_id)


class QgisVectorDataFilterSerializer(
    Serializer, target=QgisVectorDataFilter, tag="QgisVectorDataFilter"
):
    """(De)serialize a QgisVectorDataFilter as a sequence of child elements.

    Children written by ``dump_element``, in order (optional ones omitted
    when the corresponding attribute is falsy)::

        <QgisLayerPointer .../>
        <Filter>expression</Filter>
        <DipMeasurement>
            <field_use>field_name</field_use>
            ...
        </DipMeasurement>
    """

    @classmethod
    def dump_element(cls, data_filter, e):
        """Append the layer pointer, then the optional filter and dip fields to *e*."""
        e.append(QgisLayerPointerSerializer.dump(data_filter.layer))

        expression = data_filter.filter
        if expression:
            filter_node = ET.Element("Filter")
            filter_node.text = expression
            e.append(filter_node)

        mapping = data_filter.dip_fields
        if not mapping:
            return
        # NOTE: a mapping whose values are all falsy still yields an empty
        # <DipMeasurement/>, which loads back as {}.
        dip_node = ET.Element("DipMeasurement")
        for field_use, field_name in mapping.items():
            if not field_name:  # Avoids dumping "ReversePolarity" = ""
                continue
            field_node = ET.Element(field_use)
            field_node.text = field_name
            dip_node.append(field_node)
        e.append(dip_node)

    @classmethod
    def load_element(cls, e):
        """Rebuild a QgisVectorDataFilter from the children of *e*.

        A child that is absent leaves the matching field as ``None``; when a
        tag repeats, the last occurrence wins. ``dip_fields`` keys come back
        as the plain tag strings.

        Raises:
            UnknownXmlTagException: if *e* has a child with an unexpected tag.
        """
        layer = None
        expression = None
        dip_fields = None
        for child in e:
            # Keep this local named `tag`: the `{tag = }` f-string below
            # prints the variable name as part of the message.
            tag = child.tag
            if tag == QgisLayerPointerSerializer.tag:
                layer = QgisLayerPointerSerializer.load(child)
            elif tag == "Filter":
                expression = child.text
            elif tag == "DipMeasurement":
                dip_fields = {sub.tag: sub.text for sub in child}
            else:
                msg = f"In {cls.__name__}: {tag = }"
                raise UnknownXmlTagException(msg)
        return QgisVectorDataFilter(layer, expression, dip_fields)


class QGISVectorDataSelectionSerializer(
    Serializer, target=QgisVectorDataSelection, tag="QgisVectorDataSelection"
):
    """(De)serialize a QgisVectorDataSelection.

    Children written by ``dump_element`` (each section omitted when the
    corresponding attribute is ``None``)::

        <Inputs>
            <QgisVectorDataFilter>...</QgisVectorDataFilter>
            ...
        </Inputs>
        <Result>
            <QgisLayerPointer .../>
        </Result>

    NOTE(review): the "QGIS" casing differs from the sibling "Qgis..."
    serializers; the name is kept as-is for compatibility with callers.
    """

    @classmethod
    def dump_element(cls, data_selection, e):
        """Append the ``<Inputs>`` and ``<Result>`` sections to *e*.

        A ``None`` attribute omits its section entirely; an empty inputs
        list still produces an empty ``<Inputs/>``.
        """
        inputs = data_selection.inputs
        if inputs is not None:
            node = ET.Element("Inputs")
            for data_filter in inputs:
                node.append(QgisVectorDataFilterSerializer.dump(data_filter))
            e.append(node)
        result = data_selection.result
        if result is not None:
            node = ET.Element("Result")
            node.append(QgisLayerPointerSerializer.dump(result))
            e.append(node)

    @classmethod
    def load_element(cls, e):
        """Rebuild a QgisVectorDataSelection from the children of *e*.

        A missing section leaves the matching attribute as ``None``.

        Raises:
            ValueError: if ``<Result>`` does not hold exactly one element.
            UnknownXmlTagException: if *e* has a child with an unexpected tag.
        """
        inputs = result = None
        for child in e:
            tag = child.tag
            if tag == "Inputs":
                inputs = [QgisVectorDataFilterSerializer.load(c) for c in child]
            elif tag == "Result":
                children = list(child)
                # Explicit check instead of `assert`: asserts are stripped
                # under `python -O`, and this validates file content.
                if len(children) != 1:
                    msg = (
                        f"In {cls.__name__}: <Result> must contain exactly one "
                        f"element, found {len(children)}"
                    )
                    raise ValueError(msg)
                result = QgisLayerPointerSerializer.load(children[0])
            else:
                msg = f"In {cls.__name__}: {tag = }"
                raise UnknownXmlTagException(msg)
        return QgisVectorDataSelection(inputs, result)


class ModelFilters(list):
    """A plain ``list`` of QgisVectorDataSelection objects.

    Kept as a distinct subclass presumably so that a serializer can be
    registered for it (``FiltersSerializer`` uses ``target=ModelFilters``).
    """


class FiltersSerializer(Serializer, target=ModelFilters, tag="ModelFilters"):
    """(De)serialize ModelFilters: one child element per selection, in order."""

    @classmethod
    def dump_element(cls, modelfilters, e):
        """Append one serialized QgisVectorDataSelection to *e* per entry."""
        dump_selection = QGISVectorDataSelectionSerializer.dump
        for selection in modelfilters:
            e.append(dump_selection(selection))

    @classmethod
    def load_element(cls, e):
        """Return a ModelFilters holding every child of *e*, deserialized in order."""
        selections = [QGISVectorDataSelectionSerializer.load(child) for child in e]
        return ModelFilters(selections)
