from forgeo.rigs import all_intersections
from qgis.core import (
    Qgis,
    QgsFeature,
    QgsFeatureSink,
    QgsField,
    QgsFields,
    QgsLineString,
    QgsMultiLineString,
    QgsProcessing,
    QgsProcessingException,
    QgsProcessingParameterFeatureSink,
    QgsWkbTypes,
)
from qgis.PyQt.QtCore import QVariant

from ._utils import colors_as_strings, custom_symbol_renderer
from .gmprocessing import GmProcessingAlgorithm


class CustomSink:
    """Feature sink wrapper that also records the symbology of what it stores.

    Creates the sink for one feature-sink output parameter of a processing
    algorithm (geometry type ``MultiLineStringZ``) and keeps, for every key
    passed to :meth:`add`, the first ``(color, name)`` pair registered.
    """

    def __init__(self, processing, name, nature, parameters, context, fields, crs):
        """Create the underlying sink through ``processing.parameterAsSink``.

        :param processing: the processing algorithm owning the output parameter.
        :param name: name of the feature-sink output parameter.
        :param nature: label of the sink content (e.g. ``"faults"``), kept as
            ``self.nature`` for later naming/styling.
        :param parameters: the algorithm parameter values of the current run.
        :param context: the processing context of the current run.
        :param fields: attribute fields of the created sink.
        :param crs: coordinate reference system of the created sink.
        :raises QgsProcessingException: if the sink could not be created.
        """
        sink, sink_id = processing.parameterAsSink(
            parameters,
            name,
            context=context,
            fields=fields,
            geometryType=QgsWkbTypes.MultiLineStringZ,
            crs=crs,
        )
        if sink is None:
            # invalidSinkError is provided by the processing algorithm, not by
            # this wrapper (calling it on self raised AttributeError).
            raise QgsProcessingException(
                processing.invalidSinkError(parameters, name)
            )
        self.sink = sink
        self.sink_id = sink_id
        self.nature = nature
        # key -> (color, name); only the first registration of a key is kept.
        self.symbology = {}

    def add(self, feature, key, name, color):
        """Add *feature* to the sink and register ``(color, name)`` for *key*.

        If *key* was already registered, its existing symbology is kept.
        """
        self.symbology.setdefault(key, (color, name))
        self.sink.addFeature(feature, QgsFeatureSink.FastInsert)


class FeatureExtracter(GmProcessingAlgorithm, name="", display_name=""):
    """Base algorithm extracting model contact lines into two line outputs.

    Contacts for which ``model["is_fault"]`` is true go to the faults output;
    all other contacts go to the formations output.
    """

    OUTPUT_FAULTS = "OUTPUT_FAULTS"
    OUTPUT_FORMATIONS = "OUTPUT_FORMATIONS"

    def __init_subclass__(cls, **kwargs):
        # Forward the keyword arguments given in a subclass's class statement
        # (e.g. ``name=...``/``display_name=...``) to the base class.
        super().__init_subclass__(**kwargs)

    def __init__(self, info="geology"):
        """Create the extracter.

        :param info: name of the single integer attribute field written on the
            output features; also handed to ``custom_symbol_renderer``
            (presumably as the field to categorize on -- confirm).
        """
        super().__init__()
        self.info = info
        self.fields = QgsFields([QgsField(info, QVariant.Int)])
        # Sinks are created by init_sinks() once the run parameters are known.
        self.faults_sink = None
        self.formations_sink = None

    # We add a feature sink in which to store our processed features (this
    # usually takes the form of a newly created vector layer when the
    # algorithm is run in QGIS).
    def add_parameter_sinks(self):
        """Declare the faults and formations line-layer output parameters."""
        self.addParameter(
            QgsProcessingParameterFeatureSink(
                name=self.OUTPUT_FAULTS,
                description=self.tr("Faults layer"),
                type=QgsProcessing.TypeVectorLine,
            )
        )
        self.addParameter(
            QgsProcessingParameterFeatureSink(
                name=self.OUTPUT_FORMATIONS,
                description=self.tr("Formation layer"),
                type=QgsProcessing.TypeVectorLine,
            )
        )

    def init_sinks(self, parameters, context, crs):
        """Create the faults and formations sinks for the current run.

        Must be called at most once per instance (checked by assertions).

        :raises QgsProcessingException: if a sink cannot be created.
        """

        def make_sink(parameter, nature):
            return CustomSink(
                self, parameter, nature, parameters, context, self.fields, crs
            )

        assert self.faults_sink is None
        self.faults_sink = make_sink(self.OUTPUT_FAULTS, "faults")
        assert self.formations_sink is None
        self.formations_sink = make_sink(self.OUTPUT_FORMATIONS, "formations")

    @property
    def sinks(self):
        """The ``(faults_sink, formations_sink)`` pair."""
        return self.faults_sink, self.formations_sink

    def extract_features(self, model, mesh):
        """Compute the contact polylines of *model* on *mesh* and store them.

        Each contact with at least one polyline becomes one multi-linestring
        feature, added to the faults sink when ``model["is_fault"](contact)``
        is true and to the formations sink otherwise.
        """
        names = model["names"]
        is_fault = model["is_fault"]
        contacts = all_intersections(
            *mesh, **model, return_contact_polylines=True
        ).contacts
        vertices = contacts.vertices

        faults_sink, formations_sink = self.sinks
        colormap = colors_as_strings(model["colors"])

        for contact, contact_lines in contacts.lines.items():
            if not contact_lines:
                continue
            contact_feature = QgsFeature(self.fields)
            contact_feature.setGeometry(
                QgsMultiLineString(
                    [QgsLineString(vertices[line]) for line in contact_lines]
                )
            )
            # NOTE(review): the field is declared as QVariant.Int but a string
            # is stored; presumably QGIS converts it on insertion -- confirm.
            contact_feature.setAttributes([str(contact)])
            target = faults_sink if is_fault(contact) else formations_sink
            target.add(contact_feature, contact, names[contact], colormap[contact])

    def finalize(self, layer_name, context, support_name=None):
        """Name and style the output layers and build the results map.

        :param layer_name: base name of the output layers; the sink nature
            (and `` on <support_name>`` when given) is appended to it.
        :param context: the processing context of the current run.
        :param support_name: optional name of the extraction support.
        :returns: dict mapping each output parameter name to the destination
            id of its sink.
        """
        for sink in self.sinks:
            # Querying load-on-completion details of a layer that is not going
            # to be loaded would register it as a layer to load.
            if context.willLoadLayerOnCompletion(sink.sink_id):
                details = context.layerToLoadOnCompletionDetails(sink.sink_id)
                details.name = f"{layer_name} {sink.nature}"
                if support_name:
                    details.name += f" on {support_name}"
                details.forceName = True
            section = context.getMapLayer(sink.sink_id)
            if section is None:
                # The output layer is not available from the context, so there
                # is nothing to style here.
                continue
            eprop = section.elevationProperties()
            eprop.setBinding(Qgis.AltitudeBinding.Vertex)
            eprop.setClamping(Qgis.AltitudeClamping.Absolute)
            # postprocess the output layer
            linewidth = {"faults": 4, "formations": 2}[sink.nature]
            section.setRenderer(
                custom_symbol_renderer(
                    self.info, sink.symbology, width=linewidth, width_unit="Pixel"
                )
            )
        # Processing results reference sink outputs by their destination id.
        return {
            self.OUTPUT_FAULTS: self.faults_sink.sink_id,
            self.OUTPUT_FORMATIONS: self.formations_sink.sink_id,
        }

    def extract_features_once(
        self, mesh, parameters, context, feedback, support_name=None
    ):
        """Run a full extraction on *mesh* and return the results map.

        Creates the sinks (CRS taken from the model layer), extracts the
        features of the model without topography, finalizes the outputs and
        reports 100% progress.
        """
        input_layer = self.model_layer(parameters)
        crs = input_layer.crs()
        self.init_sinks(parameters, context, crs)
        model = self.extraction_parameters(
            parameters, with_topography=False, faults_only=False
        )
        self.extract_features(model, mesh)
        result = self.finalize(input_layer.name(), context, support_name)
        feedback.setProgress(100)
        return result
