# -*- coding: utf-8 -*-

"""
/***************************************************************************
 Rockyfor3DInputRasters
                                 A QGIS plugin
 This plugin prepares the input rasters for Rockyfor3D.
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2025-04-24
        copyright            : (C) 2025 by ecorisQ
        email                : alexandra.erbach@bfh.ch
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Alexandra Erbach'
__date__ = '2025-04-24'
__copyright__ = '(C) 2025 by ecorisQ'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

import os
import inspect
from qgis.PyQt.QtGui import QIcon
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (
    QgsProcessing,
    QgsProcessingAlgorithm,
    QgsProcessingParameterRasterLayer,
    QgsProcessingParameterVectorLayer,
    QgsProcessingParameterField,
    QgsProcessingParameterFile,
    QgsProcessingParameterBoolean,
    QgsProcessingMultiStepFeedback,
    QgsRasterLayer,
    QgsVectorLayer,
    QgsProject,
    QgsProcessingException,
    QgsFeatureRequest,
    QgsExpression
)
import processing
import os
import shutil
import numpy as np


class Rockyfor3DInputRastersAlgorithm(QgsProcessingAlgorithm):


    def initAlgorithm(self, config=None):
        """
        Registers the input parameters of the algorithm: the DEM raster, the
        polygon/line vector layer, the numeric attributes to rasterize, the
        destination folder and the flag for loading the results into QGIS.

        The destination folder defaults to the folder of the current project
        file. The ``config`` argument is not used.
        """
        project_folder = os.path.dirname(QgsProject.instance().fileName())

        parameter_definitions = (
            QgsProcessingParameterRasterLayer(
                'dtm',
                'Digital Elevation Model DEM (will be copied to destination folder and converted to .asc if necessary)',
            ),
            QgsProcessingParameterVectorLayer(
                'poly',
                'Input Vector Data',
                types=[QgsProcessing.TypeVectorPolygon, QgsProcessing.TypeVectorLine],
            ),
            QgsProcessingParameterField(
                'fields',
                'Attributes for rasterization (default: all numeric fields except the ID field)',
                parentLayerParameterName='poly',
                allowMultiple=True,
                type=QgsProcessingParameterField.Numeric,
                defaultToAllFields=True,
            ),
            QgsProcessingParameterFile(
                'output_path',
                'Destination folder (will be created if not existent)',
                behavior=QgsProcessingParameterFile.Folder,
                defaultValue=project_folder,
            ),
            QgsProcessingParameterBoolean(
                'load_layers',
                'Open results in QGIS',
                defaultValue=True,
            ),
        )

        # Register in declaration order so the dialog lists them as before.
        for definition in parameter_definitions:
            self.addParameter(definition)

    def processAlgorithm(self, parameters, context, model_feedback):
        """
        Rasterizes the selected attributes of the vector layer onto the grid
        of the DEM and writes one .asc raster per attribute into the
        destination folder, next to a copy of the DEM named 'dem.asc'.

        :param parameters: parameter values keyed by the names declared in
            initAlgorithm ('dtm', 'poly', 'fields', 'output_path',
            'load_layers').
        :param context: processing context handed to every child algorithm.
        :param model_feedback: feedback object; it is wrapped in a multi-step
            feedback with one step for the DEM and one per attribute.
        :return: dict mapping each rasterized field name to the path of the
            .asc file written for it.
        :raises QgsProcessingException: if the DEM or the vector layer is
            invalid, the DEM has no valid CRS, the vector layer has no
            features, or the vector layer does not intersect the DEM extent.
        """
        results = {}
        field_warnings = []
        nodata_value = -9999

        # Expected value range and field type per known attribute name
        # (looked up with the lower-cased field name).
        field_constraints = {
            "rockdensity": {"min": 2000, "max": 3400, "type": "Integer"},
            "rockdensit": {"min": 2000, "max": 3400, "type": "Integer"},
            "blshape": {"min": 0, "max": 4, "type": "Integer"},
            "soiltype": {"min": 0, "max": 7, "type": "Integer"},
            "rg10": {"min": 0, "max": 100, "type": "Real"},
            "rg20": {"min": 0, "max": 100, "type": "Real"},
            "rg70": {"min": 0, "max": 100, "type": "Real"},
            "net_number": {"min": 0, "max": 999, "type": "Integer"},
            "net_energy": {"min": 0, "max": 20000, "type": "Integer"},
            "net_height": {"min": 0, "max": 15, "type": "Real"},
            "nrtrees": {"min": 0, "max": 10000, "type": "Integer"},
            "dbhmean": {"min": 0, "max": 250, "type": "Integer"},
            "dbhstd": {"min": 0, "max": 250, "type": "Integer"},
            "conif_perc": {"min": 0, "max": 100, "type": "Integer"}
        }

        # input layers
        dtm = self.parameterAsRasterLayer(parameters, 'dtm', context)
        if not dtm or not dtm.isValid():
            raise QgsProcessingException("❌ ERROR: DEM layer is not valid.")
        if not dtm.crs().isValid():
            raise QgsProcessingException("❌ ERROR: The DEM has no valid coordinate reference system assigned - please fix that.")
        dtm_path = dtm.source()
        dtm_extent = dtm.extent()

        layer = self.parameterAsVectorLayer(parameters, 'poly', context)
        if not layer or not layer.isValid():
            raise QgsProcessingException("❌ ERROR: Input vector layer is not valid.")

        # Only a count of exactly 0 means "empty"; a negative count is how an
        # unknown feature count is reported and must not abort the run.
        total_count = layer.featureCount()
        if total_count == 0:
            raise QgsProcessingException("❌ ERROR: The vector layer is empty and contains no features. Please check your input data.")

        # fields: resolve through the parameter API so that a missing value or
        # a ';'-separated string is handled, then fall back to every numeric
        # field as announced in the parameter description.
        fields = self.parameterAsFields(parameters, 'fields', context)
        if not fields:
            fields = [f.name() for f in layer.fields() if f.isNumeric()]
        fields = [field for field in fields if field not in ('fid', 'id')]

        # One step for the DEM plus one per field that is actually processed.
        feedback = QgsProcessingMultiStepFeedback(len(fields) + 1, model_feedback)

        ids_invalid_geom = [
            str(feature.id())
            for feature in layer.getFeatures()
            if not feature.geometry().isGeosValid()
        ]
        if ids_invalid_geom:
            feedback.pushWarning(f"❌ WARNING: Feature(s) ({', '.join(ids_invalid_geom)}) of vector layer with invalid geometry. Please verify if the output rasters are correct and otherwise fix your input data.")

        output_folder = self.parameterAsFile(parameters, 'output_path', context)
        load_layers = self.parameterAsBool(parameters, 'load_layers', context)
        os.makedirs(output_folder, exist_ok=True)

        # check if vector CRS matches DTM CRS and reproject if necessary
        if layer.crs() != dtm.crs():
            feedback.pushInfo("Vector layer CRS does not match the DEM CRS – the vector layer is reprojected.")
            reprojected = processing.run("native:reprojectlayer", {
                'INPUT': layer,
                'TARGET_CRS': dtm.crs().authid(),
                'OUTPUT': QgsProcessing.TEMPORARY_OUTPUT
            }, context=context, feedback=feedback, is_child_algorithm=True)
            vector_source = reprojected['OUTPUT']
            vector_for_check = context.getMapLayer(vector_source)
            if vector_for_check is None:
                raise QgsProcessingException("❌ ERROR: The reprojected vector layer could not be loaded.")
        else:
            vector_source = layer
            vector_for_check = layer

        # check if extents of vector layer and DTM overlap
        if not vector_for_check.extent().intersects(dtm_extent):
            raise QgsProcessingException("❌ ERROR: The vector layer does not intersect with the extent of the DEM. Please check your input data.")

        # Provide the DEM as 'dem.asc' in the destination folder. The source
        # file is never moved or renamed: .asc files are copied, every other
        # format (not only '.tif') is converted.
        feedback.setCurrentStep(0)
        dtm_asc_path = os.path.join(output_folder, 'dem.asc')
        if not dtm_path.lower().endswith('.asc'):
            processing.run('gdal:translate', {
                'INPUT': dtm,
                'NODATA': nodata_value,
                'OUTPUT': dtm_asc_path
            }, context=context, feedback=feedback, is_child_algorithm=True)
            feedback.pushInfo(f"DEM converted to ASCII format: {dtm_asc_path}")
        elif os.path.normcase(os.path.abspath(dtm_path)) != os.path.normcase(os.path.abspath(dtm_asc_path)):
            shutil.copy2(dtm_path, dtm_asc_path)
            feedback.pushInfo(f"DEM is already in ASCII format; copied from {dtm_path} to {dtm_asc_path}")

        target_extent = f"{dtm_extent.xMinimum()},{dtm_extent.xMaximum()},{dtm_extent.yMinimum()},{dtm_extent.yMaximum()}"

        # iterate over fields
        for i, field in enumerate(fields):
            feedback.setCurrentStep(i + 1)
            if feedback.isCanceled():
                break

            warnings = []
            field_key = field.lower()

            # determine data type and choose GDAL code accordingly
            qgs_field = layer.fields().field(field)
            type_name = qgs_field.typeName().lower()
            if type_name.startswith('real'):
                gdal_type = 5  # Float32
            elif type_name.startswith('integer'):
                gdal_type = 1  # Int16
            else:
                feedback.pushInfo(f"Data type of field '{field}' could not be determined - please verify.")
                gdal_type = 5

            # count NULL values; a field without any value is skipped before
            # any further processing is spent on it
            request = QgsFeatureRequest(QgsExpression(f'"{field}" IS NULL'))
            null_count = sum(1 for _ in layer.getFeatures(request))
            if null_count > 0 and null_count == total_count:
                field_warnings.append(f"⚠️ WARNING: Field '{field}' contains only NULL values and cannot be rasterized - please verify.")
                continue

            # replace remaining NULL values with nodata_value
            raster_input = vector_source
            if null_count > 0:
                calc_layer = processing.run('native:fieldcalculator', {
                    'INPUT': vector_source,
                    'FIELD_NAME': field,
                    'NEW_FIELD': False,
                    'FORMULA': f"coalesce(\"{field}\", {nodata_value})",
                    'OUTPUT': QgsProcessing.TEMPORARY_OUTPUT
                }, context=context, feedback=feedback, is_child_algorithm=True)
                raster_input = calc_layer['OUTPUT']
                warnings.append(f"{null_count} NULL values were found")

            # check on problems with field type and value range, continue with warnings at the end
            constraint = field_constraints.get(field_key)
            if constraint is not None:
                min_val = constraint["min"]
                max_val = constraint["max"]
                expected_type = constraint["type"]

                # case-insensitive, like the GDAL type detection above
                if not type_name.startswith(expected_type.lower()):
                    warnings.append(f"field type should be {expected_type} instead of {qgs_field.typeName()}")

                # For every attribute except rock density, 0 is accepted even
                # when it lies below the minimum.
                zero_allowed = not field_key.startswith("rockdensit")
                out_of_range = 0
                for feature in layer.getFeatures():
                    value = feature[field]
                    # NULL attributes are not None in PyQGIS and cannot be
                    # ordered against numbers, so only real numbers are tested.
                    if isinstance(value, bool) or not isinstance(value, (int, float)):
                        continue
                    too_low = value < min_val and not (zero_allowed and value == 0)
                    if too_low or value > max_val:
                        out_of_range += 1

                if out_of_range > 0:
                    warnings.append(f"{out_of_range} value(s) outside the expected range ({min_val}–{max_val})")

            if warnings:
                field_warnings.append(f"⚠️ WARNING: Field '{field}' – {', '.join(warnings)}. Please verify.")

            # rasterize field
            rasterized = processing.run('gdal:rasterize', {
                'INPUT': raster_input,
                'FIELD': field,
                'WIDTH': dtm.rasterUnitsPerPixelX(),
                'HEIGHT': dtm.rasterUnitsPerPixelY(),
                'EXTENT': target_extent,
                'DATA_TYPE': gdal_type,
                'NODATA': nodata_value,
                'INIT': nodata_value,
                'UNITS': 1,
                'EXTRA': '-at',
                'OUTPUT': QgsProcessing.TEMPORARY_OUTPUT
            }, context=context, feedback=feedback, is_child_algorithm=True)

            # warp/align to DTM grid
            # NOTE(review): the data type index is shifted by one here and in
            # the translate call below, presumably because these algorithms
            # list an extra leading "use input type" entry - confirm.
            # NOTE(review): TARGET_RESOLUTION receives a list and
            # REFERENCE_LAYER is passed as well; verify that gdal:warpreproject
            # interprets both as intended.
            warped = processing.run('gdal:warpreproject', {
                'DATA_TYPE': gdal_type + 1,
                'INPUT': rasterized['OUTPUT'],
                'SOURCE_CRS': dtm.crs().authid(),
                'TARGET_CRS': dtm.crs().authid(),
                'TARGET_RESOLUTION': [dtm.rasterUnitsPerPixelX(), dtm.rasterUnitsPerPixelY()],
                'TARGET_EXTENT': target_extent,
                'TARGET_EXTENT_CRS': dtm.crs().authid(),
                'RESAMPLING': 0,
                'REFERENCE_LAYER': dtm,
                'OUTPUT': QgsProcessing.TEMPORARY_OUTPUT
            }, context=context, feedback=feedback, is_child_algorithm=True)

            # translate to .asc; the truncated field name 'rockdensit' is
            # written under the full name 'rockdensity.asc'
            out_path = os.path.join(output_folder, f"{field}.asc")
            if field_key == "rockdensit":
                out_path = os.path.join(output_folder, "rockdensity.asc")
            processing.run('gdal:translate', {
                'DATA_TYPE': gdal_type + 1,
                'INPUT': warped['OUTPUT'],
                'NODATA': nodata_value,
                'EXTRA': '-co DECIMAL_PRECISION=2',
                'OUTPUT': out_path
            }, context=context, feedback=feedback, is_child_algorithm=True)

            # check if rockdensity cells in the outer 2 rows
            if field_key.startswith("rockdensit"):
                rock_raster = QgsRasterLayer(out_path, "rockdensity")
                if not rock_raster.isValid():
                    feedback.pushWarning("⚠️ WARNING: ROCKDENSITY raster is not valid and could not be loaded for edge check.")
                else:
                    cols = rock_raster.width()
                    rows = rock_raster.height()
                    block = rock_raster.dataProvider().block(1, rock_raster.extent(), cols, rows)
                    if self._has_values_near_border(block, rows, cols):
                        field_warnings.append("⚠️ WARNING: ROCKDENSITY raster contains values in the two outer rows or columns of the raster. Those will not be taken into account in the simulation!")

            # NOTE(review): processing algorithms may be executed in a
            # background thread; adding layers to the project from here is
            # presumably not thread-safe - consider
            # context.addLayerToLoadOnCompletion() instead.
            if load_layers:
                rl = QgsRasterLayer(out_path, field)
                if rl.isValid():
                    QgsProject.instance().addMapLayer(rl)

            results[field] = out_path

        # All field warnings are reported together at the end of the run.
        for info in field_warnings:
            feedback.pushWarning(info)

        return results

    @staticmethod
    def _has_values_near_border(block, rows, cols, border=2):
        """
        Returns True if any cell within ``border`` rows or columns of the
        raster edge holds a value of at least 1.

        Only the border cells are read, so the cost grows with the raster
        perimeter instead of its area. ``>= 1`` keeps the result of the
        previous test, which truncated each value to an integer and then
        checked ``> 0``.
        """
        border_rows = set(range(min(border, rows))) | set(range(max(rows - border, 0), rows))
        border_cols = set(range(min(border, cols))) | set(range(max(cols - border, 0), cols))

        if any(block.value(row, col) >= 1 for row in border_rows for col in range(cols)):
            return True
        return any(block.value(row, col) >= 1 for col in border_cols for row in range(rows))

    def name(self):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return 'rockyfor3Drasters'

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return 'Create Rockyfor3D Input Rasters'

    def group(self):
        """
        Returns the display name of the group this algorithm belongs to,
        obtained by passing the group id through the translation helper.
        """
        group_id = self.groupId()
        return self.tr(group_id)

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to. This
        string should be fixed for the algorithm, and must not be localised.
        The group id should be unique within each provider. Group id should
        contain lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return ''

    def tr(self, string):
        """
        Returns the translation of ``string`` looked up in the 'Processing'
        translation context.
        """
        translation_context = 'Processing'
        return QCoreApplication.translate(translation_context, string)
        
    def icon(self):
        """
        Returns the algorithm icon, built from 'Logo_IEA_Q.png' located in
        the same folder as this source file.
        """
        source_file = inspect.getfile(inspect.currentframe())
        plugin_folder, _ = os.path.split(source_file)
        return QIcon(os.path.join(plugin_folder, 'Logo_IEA_Q.png'))

    def createInstance(self):
        """
        Returns a new instance of this algorithm class.
        """
        new_algorithm = Rockyfor3DInputRastersAlgorithm()
        return new_algorithm
