from scipy.io import loadmat

from qgis.core import (
    Qgis,
    QgsFeature,
    QgsFeatureSink,
    QgsField,
    QgsFields,
    QgsGeometry,
    QgsPointXY,
    QgsProcessingAlgorithm,
    QgsProcessingException,
    QgsProcessingParameterCrs,
    QgsProcessingParameterDefinition,
    QgsProcessingParameterFeatureSink,
    QgsProcessingParameterFile,
    QgsWkbTypes,
)
from qgis.PyQt.QtCore import QCoreApplication, QMetaType

from processing_swan_provider.core.processing_parameters import (
    QgsProcessingParameterMatField,
)


class MatToVectorAlgorithm(QgsProcessingAlgorithm):
    """
    This algorithm converts SWAN .mat files to vector format
    for use in QGIS.

    One point feature is written per element of the selected X/Y
    variables (both flattened). Every other numeric variable in the file
    becomes a Double attribute, sampled at the same flattened index.
    """

    # Constants used to refer to parameters and outputs
    INPUT = "INPUT"
    OUTPUT = "OUTPUT"
    X_FIELD = "X_FIELD"
    Y_FIELD = "Y_FIELD"
    CRS = "CRS"

    # Output attribute names holding the point coordinates. A .mat variable
    # with one of these exact names would collide with them, so it is skipped.
    COORDINATE_FIELDS = ("X", "Y")

    # numpy dtype kinds that float() converts safely: bool, signed int,
    # unsigned int and floating point. Complex, string, object and struct
    # arrays are excluded.
    NUMERIC_KINDS = "biuf"

    def createInstance(self):
        return MatToVectorAlgorithm()

    def tr(self, message: str) -> str:
        """
        Get the translation for a string using Qt translation API.
        """
        return QCoreApplication.translate(self.__class__.__name__, message)

    def name(self):
        """
        Returns the algorithm name.
        """
        return "mat_to_vector"

    def displayName(self):
        """
        Returns the translated algorithm name.
        """
        return self.tr("SWAN .mat to Vector")

    def group(self):
        """
        Returns the name of the group this algorithm belongs to.
        """
        return self.tr("SWAN Postprocessing")

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to.
        """
        return "swan_postprocessing"

    def shortHelpString(self):
        """
        Returns a localised short helper string for the algorithm.
        """
        return self.tr("Converts SWAN .mat files to vector format for use in QGIS.")

    def initAlgorithm(self, config=None):
        """
        Define the inputs and outputs of the algorithm.
        """
        # Add the input parameters
        self.addParameter(
            QgsProcessingParameterFile(self.INPUT, self.tr("Input SWAN .mat file"), extension="mat")
        )

        # Add parameter for CRS
        self.addParameter(
            QgsProcessingParameterCrs(
                self.CRS,
                self.tr("Coordinate Reference System"),
                defaultValue="EPSG:2154",  # RGF93 / Lambert-93
            )
        )

        # Add parameters for X and Y fields
        param = QgsProcessingParameterMatField(
            self.X_FIELD,
            self.tr("SWAN .mat X field"),
            defaultValue="Xp",
            parentMatFileParameterName=self.INPUT,
        )
        param.setFlags(param.flags() | QgsProcessingParameterDefinition.Flag.FlagAdvanced)
        self.addParameter(param)

        param = QgsProcessingParameterMatField(
            self.Y_FIELD,
            self.tr("SWAN .mat Y field"),
            defaultValue="Yp",
            parentMatFileParameterName=self.INPUT,
        )
        param.setFlags(param.flags() | QgsProcessingParameterDefinition.Flag.FlagAdvanced)
        self.addParameter(param)

        # Add the output parameter
        self.addParameter(
            QgsProcessingParameterFeatureSink(
                self.OUTPUT,
                self.tr("Output vector layer"),
                type=Qgis.ProcessingSourceType.TypeVectorPoint,
            )
        )

    def processAlgorithm(self, parameters, context, feedback):
        """
        Run the algorithm.

        :returns: ``{OUTPUT: dest_id}`` identifying the written layer.
        :raises QgsProcessingException: if the file has no user variables,
            lacks the selected X/Y variables, has X/Y variables that are
            non-numeric or of different lengths, or if the output sink cannot
            be created. Any other error (e.g. an unreadable file) is re-raised
            as a QgsProcessingException.
        """
        # Get parameters
        input_file = self.parameterAsFile(parameters, self.INPUT, context)
        x_field = self.parameterAsString(parameters, self.X_FIELD, context)
        y_field = self.parameterAsString(parameters, self.Y_FIELD, context)
        crs = self.parameterAsCrs(parameters, self.CRS, context)

        try:
            variables = self._load_variables(input_file)
            if not variables:
                raise QgsProcessingException(self.tr("No valid data found in the .mat file."))

            xs, ys = self._coordinates(variables, x_field, y_field)
            columns = self._attribute_columns(variables, (x_field, y_field), feedback)

            # Field order (X, Y, then one per column) must match the order
            # used by _write_features when it fills attributes.
            fields = QgsFields()
            for field_name in (*self.COORDINATE_FIELDS, *columns):
                fields.append(QgsField(field_name, QMetaType.Type.Double))

            (sink, dest_id) = self.parameterAsSink(
                parameters, self.OUTPUT, context, fields, QgsWkbTypes.Type.Point, crs
            )
            if sink is None:
                raise QgsProcessingException(self.invalidSinkError(parameters, self.OUTPUT))

            self._write_features(sink, fields, xs, ys, list(columns.values()), feedback)

            # Return the output file path
            return {self.OUTPUT: dest_id}

        except QgsProcessingException:
            # Already a user-facing processing error; don't wrap it twice.
            raise
        except Exception as e:
            # Report unexpected failures (unreadable/malformed file, ...)
            # through the Processing framework rather than as a raw traceback.
            raise QgsProcessingException(str(e)) from e

    @staticmethod
    def _load_variables(input_file):
        """
        Load a .mat file and return its user variables as a dict.

        ``loadmat`` also returns header entries whose keys start with
        ``"__"`` (e.g. ``__header__``); those are dropped.
        """
        return {
            key: value
            for key, value in loadmat(input_file).items()
            if not key.startswith("__")
        }

    def _coordinates(self, variables, x_field, y_field):
        """
        Return the flattened X and Y coordinate arrays.

        :raises QgsProcessingException: if either variable is missing or is
            not numeric, or if the two hold different numbers of values.
        """
        if x_field not in variables or y_field not in variables:
            raise QgsProcessingException(
                self.tr("X field '{}' or Y field '{}' not found in the .mat file.").format(
                    x_field, y_field
                )
            )

        xs = variables[x_field].ravel()
        ys = variables[y_field].ravel()

        if xs.dtype.kind not in self.NUMERIC_KINDS or ys.dtype.kind not in self.NUMERIC_KINDS:
            raise QgsProcessingException(
                self.tr("X field '{}' or Y field '{}' is not numeric.").format(x_field, y_field)
            )
        if len(xs) != len(ys):
            # Without this check, indexing the shorter array fails mid-write.
            raise QgsProcessingException(
                self.tr("X field '{}' has {} values but Y field '{}' has {}.").format(
                    x_field, len(xs), y_field, len(ys)
                )
            )
        return xs, ys

    def _attribute_columns(self, variables, coordinate_names, feedback):
        """
        Return ``{name: flattened array}`` for each variable exported as an attribute.

        The coordinate variables themselves are excluded. A variable whose
        name collides with an output coordinate field, or whose values are
        not real numbers, is skipped with a warning instead of aborting the
        whole conversion.
        """
        columns = {}
        for var_name, value in variables.items():
            if var_name in coordinate_names:
                continue  # already used as point geometry / X, Y attributes
            if var_name in self.COORDINATE_FIELDS:
                feedback.pushWarning(
                    self.tr("Skipping variable '{}': its name clashes with a coordinate field.").format(
                        var_name
                    )
                )
                continue
            if value.dtype.kind not in self.NUMERIC_KINDS:
                feedback.pushWarning(
                    self.tr("Skipping variable '{}': it is not numeric.").format(var_name)
                )
                continue
            columns[var_name] = value.ravel()
        return columns

    @staticmethod
    def _write_features(sink, fields, xs, ys, columns, feedback):
        """
        Add one point feature per (x, y) pair to *sink*.

        Attribute values are taken from *columns* at the same index. A column
        shorter than the coordinates yields NULL for the remaining points.
        Stops early, without error, when the user cancels.
        """
        total = len(xs)
        for i, (x, y) in enumerate(zip(xs, ys)):
            if feedback.isCanceled():
                break

            # Update progress (total > 0 whenever the loop body runs)
            feedback.setProgress(int(100 * i / total))

            x, y = float(x), float(y)
            feature = QgsFeature(fields)
            feature.setGeometry(QgsGeometry.fromPointXY(QgsPointXY(x, y)))
            # Attribute order matches ``fields``: X, Y, then each column.
            feature.setAttributes(
                [x, y] + [float(col[i]) if i < len(col) else None for col in columns]
            )

            # Add feature to sink
            sink.addFeature(feature, QgsFeatureSink.Flag.FastInsert)
