import os
from typing import Optional, Dict, Any

import numpy as np
from osgeo import gdal

from qgis.PyQt.QtCore import QMetaType
from qgis.core import (
    Qgis,
    QgsProcessingContext,
    QgsProcessingFeedback,
    QgsProcessingAlgorithm,
    QgsProcessingParameterRasterLayer,
    QgsProcessingParameterVectorLayer,
    QgsProcessingParameterVectorDestination,
    QgsProcessingException, QgsEditorWidgetSetup,
    QgsProcessingUtils,
    QgsProcessingParameterDefinition,
    QgsFeatureSink,
    QgsFeature,
    QgsField, QgsFields,
    QgsCoordinateTransform,
    QgsProject,
    QgsGeometry,
    QgsVectorFileWriter
)
from qgis.core import QgsProcessing, QgsVectorLayer, QgsFeatureRequest, QgsMapLayer
from ..core import create_profile_field
from ..core.spectralprofile import (
    prepareProfileValueDict,
    encodeProfileValueDict
)
from ...fieldvalueconverter import GenericFieldValueConverter
from ...qgsrasterlayerproperties import QgsRasterLayerSpectralProperties


class ExtractSpectralProfiles(QgsProcessingAlgorithm):
    """
    Extracts spectral profiles from a raster layer at locations defined by vector features.

    For every feature of the input vector layer that intersects the raster extent, the pixel
    under the feature's sampling location (the point itself for single-point geometries,
    the centroid otherwise) is read from all raster bands. The band values are written as an
    encoded spectral profile to the output layer, together with the raster source, the pixel
    indices and copies of the input feature attributes.
    """

    P_INPUT_RASTER = 'INPUT_RASTER'
    P_INPUT_VECTOR = 'INPUT_VECTOR'
    P_OUTPUT = 'OUTPUT'
    # NOTE(review): no parameter with this identifier is registered in initAlgorithm();
    # input attributes are currently always copied.
    COPY_ATTRIBUTES = 'COPY_FIELDS'

    # Names of the fields the algorithm writes to the output layer
    F_SOURCE = 'source'
    F_PROFILE = 'profile'
    F_PX_X = 'px_x'
    F_PX_Y = 'px_y'

    def __init__(self):
        super().__init__()

        # Output fields and results of the last processAlgorithm() run.
        # They are read again in postProcessAlgorithm().
        self._dstFields: Optional[QgsFields] = None
        self._results: dict = {}

    def createInstance(self):
        """Return a new instance of this algorithm."""
        return ExtractSpectralProfiles()

    def name(self):
        """Return the unique algorithm identifier."""
        return 'extractspectralprofiles'

    def displayName(self):
        """Return the user-visible algorithm name."""
        return 'Extract spectral profiles from raster layer'

    def group(self) -> str:
        """Return the user-visible group name."""
        return 'Spectral Library'

    def groupId(self) -> str:
        """Return the group identifier."""
        return 'spectrallibrary'

    def shortHelpString(self) -> str:
        """Return the HTML help text, including one help entry per parameter."""

        alg_desc = ('Extracts spectral profiles from a raster layer for each vector geometry.\n\n'
                    'For single-point geometries, the pixel value at that location is extracted. '
                    'For other geometries (including multi-points), the centroid is used.\n\n'
                    'The output is a vector layer with spectral profile data stored in a profile field.')

        D = {
            'ALG_DESC': alg_desc,
            'ALG_CREATOR': 'benjamin.jakimow@geo.hu-berlin.de',
        }
        for p in self.parameterDefinitions():
            p: QgsProcessingParameterDefinition
            infos = [f'<i>Identifier <code>{p.name()}</code></i>']
            if i := p.help():
                infos.append(i)
            D[p.name()] = '<br>'.join(infos)

        return QgsProcessingUtils.formatHelpMapAsHtml(D, self)

    def initAlgorithm(self, configuration=None):
        """
        Register the input raster, the input vector and the output destination parameters.

        :param configuration: optional dict whose values (keyed by parameter name)
            are used as parameter defaults
        """
        # The default is None, so fall back to an empty mapping before calling .get()
        configuration = configuration or {}

        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.P_INPUT_RASTER,
                'Input raster layer (spectral data)',
                optional=False,
                defaultValue=configuration.get(self.P_INPUT_RASTER),
            )
        )

        self.addParameter(
            QgsProcessingParameterVectorLayer(
                self.P_INPUT_VECTOR,
                'Input vector layer (sample locations)',
                optional=False,
                defaultValue=configuration.get(self.P_INPUT_VECTOR),
            )
        )

        p = QgsProcessingParameterVectorDestination(
            self.P_OUTPUT,
            'Spectral Library',
            defaultValue=configuration.get(self.P_OUTPUT),
        )
        p.setHelp('Output vector layer with spectral profiles')

        self.addParameter(p)

    def processAlgorithm(self, parameters: dict, context: QgsProcessingContext,
                         feedback: QgsProcessingFeedback) -> Dict[str, Any]:
        """
        Sample the raster at each vector feature and write the profiles to the output sink.

        :param parameters: algorithm parameter values
        :param context: processing context; provides the coordinate transform context
        :param feedback: receives progress, log messages and cancellation requests
        :return: dict mapping P_OUTPUT to the destination identifier of the output layer
        :raises QgsProcessingException: if an input layer is invalid, the raster cannot be
            opened with GDAL, or the output sink cannot be created
        """
        raster_layer = self.parameterAsRasterLayer(parameters, self.P_INPUT_RASTER, context)
        vector_layer = self.parameterAsVectorLayer(parameters, self.P_INPUT_VECTOR, context)

        if not raster_layer or not raster_layer.isValid():
            raise QgsProcessingException('Invalid input raster layer')

        if not vector_layer or not vector_layer.isValid():
            raise QgsProcessingException('Invalid input vector layer')

        feedback.pushInfo(f'Raster layer: {raster_layer.name()}')
        feedback.pushInfo(f'Vector layer: {vector_layer.name()}')
        feedback.pushInfo(f'Number of bands: {raster_layer.bandCount()}')
        feedback.pushInfo(f'Number of features: {vector_layer.featureCount()}')

        # Spectral metadata of the raster: wavelengths, wavelength unit and bad band list
        spectral_props = QgsRasterLayerSpectralProperties.fromRasterLayer(raster_layer)

        xValues = spectral_props.wavelengths() if spectral_props else None
        xUnit = spectral_props.wavelengthUnits()[0] if spectral_props and spectral_props.wavelengthUnits() else None
        bbl = spectral_props.badBands(default=1) if spectral_props else None

        # A bad band list that marks every band as 1 is dropped from the profile.
        # bbl is None when the raster has no spectral properties.
        if bbl is not None and all(v == 1 for v in bbl):
            bbl = None

        if xValues and any(xValues):
            defined = [w for w in xValues if w]
            feedback.pushInfo(f'Wavelength range: {min(defined)} - {max(defined)} {xUnit}')

        ds = gdal.Open(raster_layer.source())
        if not ds:
            raise QgsProcessingException('Could not open raster with GDAL')

        # Per-band nodata value (None if the band defines none)
        no_data = [ds.GetRasterBand(b + 1).GetNoDataValue() for b in range(ds.RasterCount)]

        geotransform = ds.GetGeoTransform()
        if geotransform[2] != 0 or geotransform[4] != 0:
            feedback.pushWarning('Raster geotransform contains rotation terms that are ignored '
                                 'when computing pixel positions')

        feedback.setProgress(10)

        raster_crs = raster_layer.crs()
        vector_crs = vector_layer.crs()

        transform = None
        if raster_crs != vector_crs:
            # Use the context's transform context (not QgsProject.instance()).
            # Processing algorithms may run outside the main thread.
            transform = QgsCoordinateTransform(vector_crs, raster_crs, context.transformContext())
            feedback.pushInfo(f'Transforming coordinates from {vector_crs.authid()} to {raster_crs.authid()}')
            # Raster extent expressed in the vector CRS, used to pre-filter the features
            request_extent = transform.transformBoundingBox(raster_layer.extent(), Qgis.TransformDirection.Reverse)
        else:
            request_extent = raster_layer.extent()

        output_fields = self._createOutputFields(vector_layer)

        # Adapt the field types to what the output driver supports
        driver = QgsVectorFileWriter.driverForExtension(os.path.splitext(self._outputPath(parameters))[1])
        output_fields = GenericFieldValueConverter.compatibleTargetFields(output_fields, driver)

        writer, destId = self.parameterAsSink(parameters,
                                              self.P_OUTPUT,
                                              context, output_fields,
                                              vector_layer.wkbType(),
                                              vector_layer.crs())

        if not isinstance(writer, QgsFeatureSink):
            raise QgsProcessingException(f'Unable to create output file: {parameters.get(self.P_OUTPUT)}')

        feedback.setProgress(20)

        # Values that do not change between features
        profile_field = output_fields.field(self.F_PROFILE)
        raster_source = raster_layer.source()
        reserved_names = {self.F_PROFILE, self.F_SOURCE, self.F_PX_X, self.F_PX_Y}
        output_names = set(output_fields.names())
        copy_names = [n for n in vector_layer.fields().names()
                      if n not in reserved_names and n in output_names]

        # featureCount() can be -1 when the count is unknown; then progress stays at 20
        total_features = vector_layer.featureCount()
        progress_step = 70.0 / total_features if total_features > 0 else 0.0
        features_processed = 0
        features_skipped = 0

        request = QgsFeatureRequest()
        request.setFilterRect(request_extent)

        for i, feature in enumerate(vector_layer.getFeatures(request)):

            if feedback.isCanceled():
                break

            # Updated at loop start so that skipped features also advance the progress
            feedback.setProgress(min(90, 20 + int(i * progress_step)))

            if not feature.hasGeometry():
                features_skipped += 1
                continue

            geom = feature.geometry()

            # Transform a copy of the geometry to the raster CRS if needed
            if transform:
                geom_transformed = QgsGeometry(geom)
                geom_transformed.transform(transform)
            else:
                geom_transformed = geom

            sample_geom = self._samplingGeometry(geom_transformed)
            if sample_geom.isNull() or sample_geom.isEmpty():
                features_skipped += 1
                feedback.pushWarning(f'Feature {feature.id()} has no valid sampling location - skipped')
                continue
            point = sample_geom.asPoint()

            px, py = self._geoToPixel(geotransform, point.x(), point.y())

            if px < 0 or py < 0 or px >= ds.RasterXSize or py >= ds.RasterYSize:
                features_skipped += 1
                feedback.pushWarning(f'Feature {feature.id()} outside raster bounds - skipped')
                continue

            data = ds.ReadAsArray(px, py, 1, 1)
            if data is None:
                features_skipped += 1
                feedback.pushWarning(f'Feature {feature.id()}: unable to read pixel ({px}, {py}) - skipped')
                continue

            # ReadAsArray returns (bands, 1, 1) for multi-band rasters but (1, 1) for
            # single-band rasters. Reshape so that there is one row per band.
            yValues = np.asarray(data).reshape(ds.RasterCount, -1).mean(axis=1).tolist()
            yValues = self._maskNoData(yValues, no_data)

            profile_dict = prepareProfileValueDict(
                y=yValues,
                x=xValues,
                xUnit=xUnit,
                bbl=bbl
            )

            out_feature = QgsFeature(output_fields)
            out_feature.setId(feature.id())
            # Keep the original geometry: the sink is created in the vector layer's CRS
            out_feature.setGeometry(geom)

            out_feature.setAttribute(self.F_PROFILE, encodeProfileValueDict(profile_dict, profile_field))
            out_feature.setAttribute(self.F_PX_X, px)
            out_feature.setAttribute(self.F_PX_Y, py)
            out_feature.setAttribute(self.F_SOURCE, raster_source)

            for name in copy_names:
                out_feature.setAttribute(name, feature.attribute(name))

            if not writer.addFeature(out_feature):
                features_skipped += 1
                feedback.pushWarning(f'Feature {feature.id()} could not be written: {writer.lastError()}')
                continue

            features_processed += 1

        feedback.pushInfo(f'Extracted profiles: {features_processed}, skipped features: {features_skipped}')

        # Clean up
        if hasattr(writer, 'finalize'):
            writer.finalize()
        del writer
        del ds

        results = {self.P_OUTPUT: destId}
        self._dstFields = output_fields
        self._results = results
        return results

    def _createOutputFields(self, vector_layer: QgsVectorLayer) -> QgsFields:
        """
        Build the output fields and return them.

        The fields are the profile, source and pixel index fields, followed by copies of
        the input layer's fields (with their editor widget setups). Input fields whose
        name is already taken, or whose lower-cased name is 'fid', are skipped.
        """
        output_fields = QgsFields()
        output_fields.append(create_profile_field(self.F_PROFILE))
        output_fields.append(QgsField(self.F_SOURCE, QMetaType.QString))
        output_fields.append(QgsField(self.F_PX_X, QMetaType.Int))
        output_fields.append(QgsField(self.F_PX_Y, QMetaType.Int))

        for f in vector_layer.fields():
            # NOTE(review): 'fid' is excluded presumably because file-based outputs such as
            # GeoPackage manage their own fid column - confirm
            if f.name() not in output_fields.names() and f.name().lower() != 'fid':
                f2 = QgsField(f)
                f2.setEditorWidgetSetup(f.editorWidgetSetup())
                output_fields.append(f2)
        return output_fields

    def _outputPath(self, parameters: dict) -> str:
        """
        Return the output path. It is used only to derive the output driver from its extension.

        Missing, empty or TEMPORARY_OUTPUT destinations are mapped to 'dummy.gpkg'.
        """
        output_path = parameters.get(self.P_OUTPUT)
        if output_path is not None and not isinstance(output_path, str):
            # presumably a QgsProcessingOutputLayerDefinition with a static sink value - confirm
            output_path = output_path.toVariant()['sink']['val']
        if not output_path or output_path == QgsProcessing.TEMPORARY_OUTPUT:
            # NOTE(review): assumes temporary outputs are written as GeoPackage
            output_path = 'dummy.gpkg'
        return output_path

    @staticmethod
    def _samplingGeometry(geom: QgsGeometry) -> QgsGeometry:
        """
        Return the geometry whose point is sampled from the raster.

        A single-part point geometry is returned unchanged. For all other geometries,
        including multi-points, the centroid is returned; asPoint() would raise on multi-part input.
        """
        if geom.type() == Qgis.GeometryType.Point and not geom.isMultipart():
            return geom
        return geom.centroid()

    @staticmethod
    def _geoToPixel(geotransform, x: float, y: float):
        """
        Convert map coordinates (in the raster CRS) to integer (column, row) pixel indices.

        floor() is used instead of int(), which truncates toward zero. With int(), positions
        up to one pixel left of or above the raster would map to index 0 instead of -1.
        Only the offset and pixel-size terms of the GDAL geotransform are used; rotation
        terms (geotransform[2], geotransform[4]) are ignored.
        """
        px = int(np.floor((x - geotransform[0]) / geotransform[1]))
        py = int(np.floor((y - geotransform[3]) / geotransform[5]))
        return px, py

    @staticmethod
    def _maskNoData(values, no_data):
        """
        Replace values that equal their band's nodata value with None.

        A NaN nodata value masks NaN values, which an equality test would never match.
        Bands without a nodata value (None) are left unchanged.
        """
        masked = []
        for v, nd in zip(values, no_data):
            if nd is not None and (v == nd or (np.isnan(nd) and np.isnan(v))):
                masked.append(None)
            else:
                masked.append(v)
        return masked

    def postProcessAlgorithm(self, context: QgsProcessingContext, feedback: QgsProcessingFeedback) -> Dict[str, Any]:
        """
        Load the written output layer and apply the output fields' editor widget setups.

        The setups are then saved as the layer's default style.

        :return: dict mapping P_OUTPUT to the loaded QgsVectorLayer if it could be loaded,
            otherwise to the value stored by processAlgorithm()
        """
        vl = self._results.get(self.P_OUTPUT)
        if isinstance(vl, str):
            vl = QgsProcessingUtils.mapLayerFromString(vl, context,
                                                       allowLoadingNewLayers=True,
                                                       typeHint=QgsProcessingUtils.LayerHint.Vector)

        if isinstance(vl, QgsVectorLayer) and vl.isValid():
            # _dstFields is None if processAlgorithm() did not complete
            if self._dstFields is not None:
                for f in self._dstFields:
                    idx = vl.fields().lookupField(f.name())
                    if idx > -1:
                        setup = f.editorWidgetSetup()
                        if isinstance(setup, QgsEditorWidgetSetup):
                            vl.setEditorWidgetSetup(idx, setup)
            vl.saveDefaultStyle(QgsMapLayer.StyleCategory.AllStyleCategories)
            feedback.pushInfo(f'Created {vl.publicSource(True)}\nPost-processing finished.')
        else:
            feedback.pushWarning(f'Unable to reload {vl} as vectorlayer and set profile fields')

        feedback.setProgress(100)
        return {self.P_OUTPUT: vl}
