import numpy as np
from scipy.io import loadmat

from qgis.core import (
    Qgis,
    QgsFeature,
    QgsFeatureSink,
    QgsField,
    QgsFields,
    QgsGeometry,
    QgsPointXY,
    QgsProcessingAlgorithm,
    QgsProcessingException,
    QgsProcessingParameterDefinition,
    QgsProcessingParameterFeatureSink,
    QgsProcessingParameterFeatureSource,
    QgsWkbTypes,
)
from qgis.PyQt.QtCore import QCoreApplication, QMetaType

from processing_swan_provider.core.processing_parameters import (
    QgsProcessingParameterMatField,
    QgsProcessingParameterMultipleMatFile,
)


class ExtractValuesFromMatAlgorithm(QgsProcessingAlgorithm):
    """
    This algorithm extracts values from SWAN .mat files
    for use in QGIS.

    For every feature of the input layer and every selected .mat file one
    output point feature is written. It carries the attributes of the input
    feature, the name of the .mat file and, for each selected field, the value
    stored at the grid node nearest to the point. The grid coordinates are
    read from the first .mat file only; every other file is indexed with the
    node positions found there.
    """

    # Constants used to refer to parameters and outputs
    INPUT = "INPUT"
    INPUT_POINT_LAYER = "INPUT_POINT_LAYER"
    OUTPUT = "OUTPUT"
    FIELDS = "FIELDS"
    X_FIELD = "X_FIELD"
    Y_FIELD = "Y_FIELD"
    CRS = "CRS"

    def createInstance(self):
        """
        Returns a new instance of the algorithm.
        """
        return ExtractValuesFromMatAlgorithm()

    def tr(self, message: str) -> str:
        """
        Get the translation for a string using Qt translation API.

        Pass the untranslated template only and apply ``str.format`` to the
        result; an already interpolated string can never match a catalogue
        entry.
        """
        return QCoreApplication.translate(self.__class__.__name__, message)

    def name(self):
        """
        Returns the algorithm name.
        """
        return "extract_values_from_mat"

    def displayName(self):
        """
        Returns the translated algorithm name.
        """
        return self.tr("Extract Values from SWAN .mat")

    def group(self):
        """
        Returns the name of the group this algorithm belongs to.
        """
        return self.tr("SWAN Postprocessing")

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to.
        """
        return "swan_postprocessing"

    def shortHelpString(self):
        """
        Returns a localised short helper string for the algorithm.
        """
        return self.tr(
            "Extracts values from SWAN .mat files for use in QGIS.\n"
            "Point layer is supposed to be in the same CRS as the .mat files."
        )

    def initAlgorithm(self, config=None):
        """
        Define the inputs and outputs of the algorithm.
        """
        # Add the input parameters
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT_POINT_LAYER,
                self.tr("Input point layer"),
            )
        )

        self.addParameter(
            QgsProcessingParameterMultipleMatFile(
                self.INPUT,
                self.tr("Input SWAN .mat files"),
            )
        )

        self.addParameter(
            QgsProcessingParameterMatField(
                self.FIELDS,
                self.tr("SWAN.mat Fields to extract"),
                parentMatFileParameterName=self.INPUT,
                allowMultipleValues=True,
                selectAllValuesByDefault=True,
            )
        )

        # Add parameters for X and Y fields (advanced: defaults are "Xp"/"Yp")
        param = QgsProcessingParameterMatField(
            self.X_FIELD,
            self.tr("SWAN.mat X field"),
            defaultValue="Xp",
            parentMatFileParameterName=self.INPUT,
        )
        param.setFlags(param.flags() | QgsProcessingParameterDefinition.Flag.FlagAdvanced)
        self.addParameter(param)

        param = QgsProcessingParameterMatField(
            self.Y_FIELD,
            self.tr("SWAN.mat Y field"),
            defaultValue="Yp",
            parentMatFileParameterName=self.INPUT,
        )
        param.setFlags(param.flags() | QgsProcessingParameterDefinition.Flag.FlagAdvanced)
        self.addParameter(param)

        # Add the output parameter
        self.addParameter(
            QgsProcessingParameterFeatureSink(
                self.OUTPUT,
                self.tr("Points with extracted values"),
                type=Qgis.ProcessingSourceType.TypeVectorPoint,
            )
        )

    def processAlgorithm(self, parameters, context, feedback):
        """
        Run the algorithm.

        Returns a dict mapping ``OUTPUT`` to the sink id, or an empty dict when
        the run was cancelled while the .mat files were being loaded.

        Raises QgsProcessingException when the point source or the sink is
        invalid, when no file or no field is selected, or when a .mat file
        cannot be loaded or lacks a requested field.
        """
        point_layer = self.parameterAsSource(parameters, self.INPUT_POINT_LAYER, context)
        if point_layer is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.INPUT_POINT_LAYER)
            )

        input_files = self.parameterAsStrings(parameters, self.INPUT, context)
        fields = self.parameterAsStrings(parameters, self.FIELDS, context)
        x_field = self.parameterAsString(parameters, self.X_FIELD, context)
        y_field = self.parameterAsString(parameters, self.Y_FIELD, context)

        # Validate inputs
        if not input_files:
            raise QgsProcessingException(self.tr("No input .mat files specified"))
        if not fields:
            raise QgsProcessingException(self.tr("No fields selected for extraction"))

        # Load and validate .mat files
        mat_data_list, x_coords, y_coords = self._load_mat_files(
            input_files, fields, x_field, y_field, feedback
        )
        if feedback.isCanceled():
            # Nothing usable was loaded; do not create an output or claim success.
            return {}

        # Store input files for reference in processing
        self._input_files = input_files

        # Create output sink
        sink, dest_id = self._create_output_sink(
            parameters, context, point_layer, input_files, fields
        )

        # Process points and extract values
        self._process_points(point_layer, mat_data_list, fields, x_coords, y_coords, sink, feedback)

        if not feedback.isCanceled():
            feedback.pushInfo(self.tr("Extraction completed successfully"))
        return {self.OUTPUT: dest_id}

    def _load_mat_files(self, input_files, fields, x_field, y_field, feedback):
        """
        Load and validate .mat files.

        Returns a tuple ``(mat_data_list, x_coords, y_coords)``: the loaded
        file contents in input order and the flattened coordinate arrays of
        the first file. Returns ``([], None, None)`` when cancelled.

        Raises QgsProcessingException when a file cannot be read, when a
        coordinate or value field is missing, or when the coordinate fields
        are not numeric or differ in size.
        """
        total_files = len(input_files)
        feedback.pushInfo(self.tr("Processing {} .mat file(s)").format(total_files))
        feedback.pushInfo(self.tr("Extracting fields: {}").format(", ".join(fields)))

        mat_data_list = []
        x_coords = None
        y_coords = None

        for i, mat_file in enumerate(input_files):
            feedback.pushInfo(
                self.tr("Loading .mat file {}/{}: {}").format(i + 1, total_files, mat_file)
            )

            if feedback.isCanceled():
                return [], None, None

            # Only the read itself is wrapped, so that the validation errors
            # raised below reach the user with their own message.
            try:
                mat_data = loadmat(mat_file)
            except Exception as e:
                raise QgsProcessingException(
                    self.tr("Error loading {}: {}").format(mat_file, e)
                ) from e

            # Get coordinates from first file
            if i == 0:
                self._validate_coordinate_fields(mat_data, x_field, y_field, mat_file)
                try:
                    x_coords = np.asarray(mat_data[x_field], dtype=float).flatten()
                    y_coords = np.asarray(mat_data[y_field], dtype=float).flatten()
                except (TypeError, ValueError) as e:
                    raise QgsProcessingException(
                        self.tr("Coordinate fields '{}' and '{}' in {} are not numeric").format(
                            x_field, y_field, mat_file
                        )
                    ) from e
                if x_coords.size != y_coords.size:
                    raise QgsProcessingException(
                        self.tr("X field '{}' and Y field '{}' differ in size in {}").format(
                            x_field, y_field, mat_file
                        )
                    )
                feedback.pushInfo(self.tr("Grid dimensions: {} points").format(len(x_coords)))

            # Validate that all required fields exist
            self._validate_fields(mat_data, fields, mat_file)
            mat_data_list.append(mat_data)

        return mat_data_list, x_coords, y_coords

    def _validate_coordinate_fields(self, mat_data, x_field, y_field, mat_file):
        """Raise QgsProcessingException unless both coordinate fields exist in mat data."""
        if x_field not in mat_data:
            raise QgsProcessingException(
                self.tr("X field '{}' not found in {}").format(x_field, mat_file)
            )
        if y_field not in mat_data:
            raise QgsProcessingException(
                self.tr("Y field '{}' not found in {}").format(y_field, mat_file)
            )

    def _validate_fields(self, mat_data, fields, mat_file):
        """Raise QgsProcessingException for the first requested field missing in mat data."""
        for field in fields:
            if field not in mat_data:
                raise QgsProcessingException(
                    self.tr("Field '{}' not found in {}").format(field, mat_file)
                )

    @staticmethod
    def _unique_field_name(name: str, used_names: set) -> str:
        """
        Return ``name``, suffixed with ``_2``, ``_3``, ... if already taken.

        ``used_names`` holds the lower-cased names in use; the returned name is
        added to it so that successive calls never hand out the same name.
        """
        candidate = name
        suffix = 2
        while candidate.lower() in used_names:
            candidate = f"{name}_{suffix}"
            suffix += 1
        used_names.add(candidate.lower())
        return candidate

    def _create_output_sink(self, parameters, context, point_layer, input_files, fields):
        """
        Create the output sink with appropriate fields.

        The schema is: the fields of the point layer, a string field for the
        .mat file name, then one double field per extracted field. Names that
        collide with an earlier field are suffixed, because a rejected
        duplicate would shift every positional attribute written afterwards.

        ``input_files`` is accepted for interface stability and is not used.

        Returns ``(sink, dest_id)``; raises QgsProcessingException when the
        sink cannot be created.
        """
        output_fields = QgsFields()
        used_names = set()

        # Add original point layer fields
        for field in point_layer.fields():
            output_fields.append(field)
            used_names.add(field.name().lower())

        # Add mat file name field
        output_fields.append(
            QgsField(self._unique_field_name("mat_file", used_names), QMetaType.Type.QString)
        )

        # Add fields for extracted values
        for field in fields:
            output_fields.append(
                QgsField(self._unique_field_name(field, used_names), QMetaType.Type.Double)
            )

        # Create output sink
        (sink, dest_id) = self.parameterAsSink(
            parameters,
            self.OUTPUT,
            context,
            output_fields,
            QgsWkbTypes.Type.Point,
            point_layer.sourceCrs(),
        )
        if sink is None:
            raise QgsProcessingException(self.invalidSinkError(parameters, self.OUTPUT))

        return sink, dest_id

    @staticmethod
    def _mat_label(mat_file_path: str) -> str:
        """Return the file name of a path without directories and without a trailing '.mat'."""
        # Both separators are split on so that Windows paths work on any platform.
        file_name = mat_file_path.split("/")[-1].split("\\")[-1]
        if file_name.lower().endswith(".mat"):
            return file_name[: -len(".mat")]
        return file_name

    @staticmethod
    def _nearest_index(point, x_coords, y_coords) -> int | None:
        """
        Return the flat index of the grid node closest to ``point``.

        Nodes with NaN coordinates are ignored. Returns None when there is no
        node with valid coordinates (empty or all-NaN grid).
        """
        # Squared distances order the nodes exactly like real distances do.
        squared = (x_coords - point.x()) ** 2 + (y_coords - point.y()) ** 2
        if np.isnan(squared).all():
            return None
        return int(np.nanargmin(squared))

    def _process_points(
        self, point_layer, mat_data_list, fields, x_coords, y_coords, sink, feedback
    ):
        """
        Process each point and extract values from .mat files.

        Writes one feature per (input feature, .mat file) pair to ``sink``.
        Features without geometry are skipped silently, features whose
        geometry is not a single point are skipped with a warning. The file
        names are taken from ``self._input_files``, which must be in the same
        order as ``mat_data_list``.
        """
        total_points = point_layer.featureCount()
        file_count = len(mat_data_list)
        total_operations = total_points * file_count
        feedback.pushInfo(
            self.tr("Processing {} points for {} .mat files").format(total_points, file_count)
        )

        # Everything that does not depend on the current point is prepared once.
        mat_labels = [self._mat_label(path) for path in self._input_files]
        source_field_names = [field.name() for field in point_layer.fields()]
        flat_data_list = [
            {field: np.ravel(mat_data[field]) for field in fields} for mat_data in mat_data_list
        ]

        for current, point_feature in enumerate(point_layer.getFeatures()):
            if feedback.isCanceled():
                break

            # Get point coordinates
            point_geom = point_feature.geometry()
            if point_geom.isEmpty():
                continue

            try:
                point = point_geom.asPoint()
            except (TypeError, ValueError):
                feedback.pushWarning(
                    self.tr("Skipping feature {}: geometry is not a single point").format(
                        point_feature.id()
                    )
                )
                continue

            # All files share the grid of the first one, so one lookup per point suffices.
            nearest_idx = self._nearest_index(point, x_coords, y_coords)
            output_geom = QgsGeometry.fromPointXY(QgsPointXY(point.x(), point.y()))
            base_attributes = [point_feature.attribute(name) for name in source_field_names]

            # Create one feature per mat file
            for mat_idx, flat_data in enumerate(flat_data_list):
                if feedback.isCanceled():
                    break

                # featureCount() may be unknown (negative); report progress only if usable.
                if total_operations > 0:
                    completed = current * file_count + mat_idx + 1
                    feedback.setProgress(int(completed * 100 / total_operations))

                output_feature = QgsFeature()
                output_feature.setGeometry(output_geom)

                # Original attributes, then the file name, then the extracted values
                attributes = list(base_attributes)
                attributes.append(mat_labels[mat_idx])
                attributes.extend(
                    self._extract_nearest_value(
                        flat_data, field, point, x_coords, y_coords, nearest_idx
                    )
                    for field in fields
                )

                # Set attributes for the output feature
                output_feature.setAttributes(attributes)
                # Add the feature to the sink
                sink.addFeature(output_feature, QgsFeatureSink.Flag.FastInsert)

    def _extract_nearest_value(
        self, mat_data, field, point, x_coords, y_coords, nearest_idx=None
    ) -> float | None:
        """
        Extract the nearest value for a given field from mat data.

        ``nearest_idx`` is the flat index of the nearest grid node; when it is
        None it is computed from ``point`` and the coordinate arrays.

        Returns the value as float, or None when no nearest node exists, when
        the index lies beyond the field data, or when the value cannot be
        converted to float.
        """
        if nearest_idx is None:
            nearest_idx = self._nearest_index(point, x_coords, y_coords)
            if nearest_idx is None:
                return None

        # Extract the value at the nearest index
        try:
            field_data = np.ravel(mat_data[field])
            if nearest_idx < len(field_data):
                return float(field_data[nearest_idx])
            return None
        except (IndexError, ValueError, TypeError):
            return None
