"""
#-----------------------------------------------------------
# Copyright (C) 2025 Tanja Kempen, Mathias Gröbe
#-----------------------------------------------------------
# Licensed under the terms of GNU GPL 2
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.

#---------------------------------------------------------------------
"""

from typing import Any, Optional

from qgis.core import (
    QgsProcessingAlgorithm,
    QgsProcessingContext,
    QgsProcessingException,
    QgsProcessingParameterPointCloudLayer,
    QgsProcessingParameterRasterDestination,
    QgsRasterLayer,  
    QgsProcessingMultiStepFeedback,  
    QgsProcessingUtils,
)
from qgis.PyQt.QtGui import QIcon
from scipy.ndimage import median_filter
from rasterio.transform import from_origin
from rasterio.crs import CRS
import numpy as np
import laspy
import rasterio
import itertools
import os
import subprocess

# Raster cell size passed to PDAL's writers.gdal as "resolution" for every
# pipeline run by the algorithm, i.e. expressed in the coordinate units of the
# input point cloud.
PIXEL_SIZE = 0.25  # Example pixel size, adjust as needed

# File names of the PDAL pipeline definitions. The algorithm resolves them
# relative to the directory containing this module (os.path.dirname(__file__)).
DTM_PIPELINE = "dtm_pipeline.json"
CHM_PIPELINE = "chm_pipeline.json"
LOW_VEGETATION_PIPELINE = "low_vegetation_pipeline.json"
HIGH_VEGETATION_PIPELINE = "high_vegetation_pipeline.json"


class TrailscanPreProcessingAlgorithm(QgsProcessingAlgorithm):
    """
    Preparation of point cloud data for TrailScan analysis.

    This algorithm runs several PDAL pipelines on a point cloud and derives
    raster products from them: a Digital Terrain Model (DTM), a Canopy Height
    Model (CHM), a Micro Relief Model (MRM) and a Vegetation Density Index
    (VDI). The four products are stacked, percentile-normalized to 0-1 and
    written as a single 4-band raster.
    """

    # Identifiers of the input parameter and the output used by Processing.
    POINTCLOUD = "POINTCLOUD"
    OUTPUT_NORMALIZED = "OUTPUT_NORMALIZED"

    def name(self) -> str:
        """
        Returns the algorithm name
        """
        return "preprocessing"

    def displayName(self) -> str:
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return "01 Preprocessing Point Cloud"

    def group(self) -> str:
        """
        Returns the name of the group this algorithm belongs to. This string
        should be localised.
        """
        return ""

    def groupId(self) -> str:
        """
        Returns the unique ID of the group this algorithm belongs to.
        """
        return ""

    def shortHelpString(self) -> str:
        """
        Returns a localised short helper string for the algorithm.
        """

        help_string = (
            "The ALS point cloud (.laz or .las format) is processed with the TrailScan preprocessing tool.\n\n"
            "The point cloud is converted into a 4-band georeferenced raster image:\n"
            "- Band 1: Digital Terrain Model (DTM)\n"
            "- Band 2: Canopy Height Model (CHM)\n"
            "- Band 3: Micro Relief Model (MRM)\n"
            "- Band 4: Vegetation Density Index (VDI)\n\n"
            "Each band's values are normalized to the range 0–1.\n"
            "The output file is therefore named \"Normalized\"."
        )

        return help_string

    def icon(self):
        """Returns the TrailScan logo shipped next to this module."""
        return QIcon(os.path.join(os.path.dirname(__file__), 'TrailScan_Logo.svg'))

    def initAlgorithm(self, config: Optional[dict[str, Any]] = None):
        """
        Define the inputs and output of the algorithm: one point cloud layer
        in, one raster destination ("Normalized") out.
        """

        self.addParameter(
            QgsProcessingParameterPointCloudLayer(
                name=self.POINTCLOUD,
                description="Input point cloud"
            )
        )

        self.addParameter(
            QgsProcessingParameterRasterDestination(
                name=self.OUTPUT_NORMALIZED,
                description="Normalized"
            )
        )

    def calculate_extent_and_transform(self, input_laz, resolution):
        """Compute a north-up raster grid covering every point of a LAS/LAZ file.

        Args:
            input_laz: Path to the .las/.laz file (the whole file is read with laspy).
            resolution: Pixel size in the point cloud's coordinate units.

        Returns:
            tuple: ``(transform, width, height)``. ``transform`` places the grid
            origin at the upper-left corner ``(x_min, y_max)``.
        """
        las = laspy.read(input_laz)
        x_min, x_max = np.min(las.x), np.max(las.x)
        y_min, y_max = np.min(las.y), np.max(las.y)

        # floor(span / resolution) + 1 cells are needed so that points lying
        # exactly on the x_max / y_min edge still fall inside the grid; ceil()
        # yields one cell too few when the span is an exact multiple of the
        # resolution and zero cells for a degenerate (single-coordinate) extent.
        width = int(np.floor((x_max - x_min) / resolution)) + 1
        height = int(np.floor((y_max - y_min) / resolution)) + 1

        transform = from_origin(x_min, y_max, resolution, resolution)

        return transform, width, height

    def create_single_raster(self, data_array, transform, output_path, crs, nodata_value=0):
        """Create a single-band GeoTIFF from a 2-D data array.

        Args:
            data_array: Input 2-D array (rows, columns); its dtype is used for the file.
            transform: Affine transform of the raster grid.
            output_path: Output file path.
            crs: Coordinate reference system as a WKT string.
            nodata_value: NoData value recorded in the file.
        """
        height, width = data_array.shape
        crs_def = CRS.from_wkt(crs)

        with rasterio.open(
            output_path, "w",
            driver="GTiff", height=height, width=width,
            count=1, dtype=data_array.dtype, crs=crs_def, transform=transform,
            nodata=nodata_value
        ) as dst:
            dst.write(data_array, 1)

    def normalize_percentile(self, data, low=1, high=99, nodata_value=0):
        """Normalize data band-wise with a percentile-cut stretch to 0-1.

        Args:
            data (np.ndarray): Array whose LAST axis indexes the bands.
            low: The low percentile cut value (default: 1).
            high: The high percentile cut value (default: 99).
            nodata_value: Value marking invalid cells; such cells are excluded
                from the percentiles and written back unchanged (default: 0).

        Returns:
            np.ndarray: Array of the same shape with valid cells clipped to [0, 1].
        """
        datamask = data != nodata_value
        band_count = data.shape[-1]

        # Defaults for bands without any valid cell: every cell of such a band
        # is reset to nodata below, so the stretch bounds are irrelevant there
        # (np.percentile would fail on the empty selection).
        pmin = np.zeros(band_count, dtype=np.float64)
        pmax = np.ones(band_count, dtype=np.float64)
        for band in range(band_count):
            valid = data[..., band][datamask[..., band]]
            if valid.size:
                pmin[band], pmax[band] = np.percentile(valid, [low, high])

        # Normalize and clip; the epsilon avoids division by zero for flat bands.
        normalized_data = np.clip((data - pmin) / (pmax - pmin + 1E-10), 0, 1)

        # Restore NoData cells.
        normalized_data[~datamask] = nodata_value

        return normalized_data

    def create_multiband_raster(self, data_arrays, transform, output_path, crs, nodata_value=0):
        """Create a multi-band GeoTIFF from several equally sized 2-D arrays.

        Args:
            data_arrays: Non-empty sequence of 2-D arrays; array i becomes band i+1.
                The first array's shape and dtype define the file.
            transform: Affine transform of the raster grid.
            output_path: Output file path.
            crs: Coordinate reference system as a WKT string.
            nodata_value: NoData value recorded in the file.

        Raises:
            ValueError: if ``data_arrays`` is empty.
        """
        if len(data_arrays) == 0:
            raise ValueError("data_arrays must contain at least one array")

        num_bands = len(data_arrays)
        height, width = data_arrays[0].shape
        crs_def = CRS.from_wkt(crs)

        with rasterio.open(
            output_path, "w",
            driver="GTiff", height=height, width=width,
            count=num_bands, dtype=data_arrays[0].dtype, crs=crs_def, transform=transform,
            nodata=nodata_value
        ) as dst:
            for band_index, data_array in enumerate(data_arrays, start=1):
                dst.write(data_array, band_index)

    def _check_pdal(self, feedback, creationflags):
        """Ensure the PDAL CLI is on PATH and offers the filters.expression driver.

        Raises:
            QgsProcessingException: if PDAL is missing, fails, or lacks the driver.
        """
        try:
            subprocess.run(
                ["pdal", "--version"],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                creationflags=creationflags,
            )
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            raise QgsProcessingException("PDAL is not installed or not found in PATH. Please install PDAL to continue.") from e

        try:
            result = subprocess.run(
                ["pdal", "--drivers"],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                creationflags=creationflags,
            )
        except subprocess.CalledProcessError as e:
            feedback.reportError(e.stderr or str(e))
            raise QgsProcessingException("Failed to check PDAL drivers. See log for details.") from e

        if "filters.expression" not in result.stdout:
            raise QgsProcessingException("PDAL 'filters.expression' driver is not available. Please ensure your PDAL installation includes this filter.")

    def _run_pdal_pipeline(self, pipeline_file, input_laz, outfile, label, feedback, creationflags):
        """Run one bundled PDAL pipeline that writes a raster via writers.gdal.

        Args:
            pipeline_file: Pipeline JSON file name, resolved next to this module.
            input_laz: Input point cloud path given to readers.las.
            outfile: Raster path given to writers.gdal.
            label: Product name used in error messages (e.g. "DTM", "low vegetation").
            feedback: Feedback object receiving PDAL's stderr on failure.
            creationflags: subprocess creation flags (suppresses console on Windows).

        Raises:
            QgsProcessingException: if PDAL fails or the raster was not written.
        """
        pipeline_path = os.path.join(os.path.dirname(__file__), pipeline_file)
        try:
            subprocess.run(
                [
                    "pdal", "pipeline", pipeline_path,
                    f"--readers.las.filename={input_laz}",
                    f"--writers.gdal.filename={outfile}",
                    f"--writers.gdal.resolution={PIXEL_SIZE}",
                ],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                creationflags=creationflags,
            )
        except subprocess.CalledProcessError as e:
            feedback.reportError(e.stderr or str(e))
            raise QgsProcessingException(f"PDAL {label} pipeline failed. See log for details.") from e
        if not os.path.exists(outfile):
            capitalized = label[:1].upper() + label[1:]
            raise QgsProcessingException(f"{capitalized} output file was not created: {outfile}")

    def _read_band_on_grid(self, path, label, transform, height, width, feedback):
        """Read band 1 of a raster and check it lies on the reference (DTM) grid.

        Args:
            path: Raster file to read.
            label: Product name used in messages.
            transform: Affine transform of the reference grid.
            height: Row count of the reference grid.
            width: Column count of the reference grid.
            feedback: Feedback object receiving a non-fatal origin warning.

        Returns:
            np.ndarray: The band as a 2-D array of shape ``(height, width)``.

        Raises:
            QgsProcessingException: if the raster size differs from the grid.
        """
        with rasterio.open(path) as src:
            array = src.read(1)
            src_transform = src.transform

        # The bands are combined cell by cell; a different size would otherwise
        # surface as an opaque numpy broadcasting error much later.
        if array.shape != (height, width):
            raise QgsProcessingException(
                f"{label} raster size {array.shape[1]}x{array.shape[0]} does not match "
                f"the DTM grid {width}x{height}; the bands cannot be combined."
            )
        # Same size but shifted origin means the bands would be misregistered.
        if not src_transform.almost_equals(transform):
            feedback.reportError(
                f"{label} raster origin differs from the DTM grid; bands may be misaligned."
            )
        return array

    def processAlgorithm(
        self,
        parameters: dict[str, Any],
        context: QgsProcessingContext,
        feedback: QgsProcessingMultiStepFeedback,
    ) -> dict[str, Any]:
        """
        Run the PDAL pipelines, derive MRM and VDI, and write the normalized raster.

        Args:
            parameters: Parameter values supplied by the Processing framework.
            context: The processing context.
            feedback: Feedback object from the framework; it is wrapped in a
                QgsProcessingMultiStepFeedback to report per-stage progress.

        Returns:
            dict: ``{OUTPUT_NORMALIZED: <path of the 4-band raster>}``, or an
            empty dict if the run was cancelled.

        Raises:
            QgsProcessingException: if the input layer or its CRS is invalid,
                PDAL (or its filters.expression driver) is unavailable, a PDAL
                pipeline fails, or an intermediate raster does not match the
                DTM grid size.
        """
        count_max = 7
        # QgsProcessingMultiStepFeedback steps are zero-based: steps 0..6 cover
        # the seven stages below, setCurrentStep(count_max) marks completion.
        counter = itertools.count(0)
        feedback = QgsProcessingMultiStepFeedback(count_max, feedback)

        def advance() -> bool:
            """Move to the next progress step; return True if the user cancelled."""
            feedback.setCurrentStep(next(counter))
            return feedback.isCanceled()

        sourceCloud = self.parameterAsPointCloudLayer(parameters, self.POINTCLOUD, context)
        # Validate before dereferencing the layer (dataProvider() on None would
        # raise AttributeError instead of a proper processing error).
        if sourceCloud is None:
            raise QgsProcessingException(
                self.invalidSourceError(parameters, self.POINTCLOUD)
            )

        input_laz = sourceCloud.dataProvider().dataSourceUri()
        # Normalize provider URI for PDAL CLI and local reads
        if isinstance(input_laz, str) and input_laz.lower().startswith("pdal://"):
            input_laz = input_laz[len("pdal://"):]

        vdi_outfile = QgsProcessingUtils.generateTempFilename("VDI.tif", context=context)
        dtm_outfile = QgsProcessingUtils.generateTempFilename("DTM.tif", context=context)
        mrm_outfile = QgsProcessingUtils.generateTempFilename("MRM.tif", context=context)
        chm_outfile = QgsProcessingUtils.generateTempFilename("CHM.tif", context=context)
        low_vegetation_outfile = QgsProcessingUtils.generateTempFilename("LowVegetation.tif", context=context)
        high_vegetation_outfile = QgsProcessingUtils.generateTempFilename("HighVegetation.tif", context=context)
        output_raster = self.parameterAsOutputLayer(parameters, self.OUTPUT_NORMALIZED, context)

        # Define CREATE_NO_WINDOW only on Windows to suppress the console window
        creationflags = 0
        if os.name == "nt":
            creationflags = getattr(subprocess, "CREATE_NO_WINDOW", 0)

        crs = sourceCloud.crs().horizontalCrs()
        if not crs.isValid():
            raise QgsProcessingException("Invalid CRS in input point cloud")
        crs_wkt = crs.toWkt()

        self._check_pdal(feedback, creationflags)

        # Ensure output directories exist
        for outfile in (dtm_outfile, mrm_outfile, chm_outfile, vdi_outfile,
                        low_vegetation_outfile, high_vegetation_outfile, output_raster):
            out_dir = os.path.dirname(outfile)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

        # --- Stage 0: DTM ------------------------------------------------------
        if advance():
            return {}
        feedback.pushInfo("Creating DTM...")
        self._run_pdal_pipeline(DTM_PIPELINE, input_laz, dtm_outfile, "DTM", feedback, creationflags)

        with rasterio.open(dtm_outfile) as dtm_src:
            dtm_array = dtm_src.read(1)
            nodata_value = dtm_src.nodata if dtm_src.nodata is not None else 0
            # The DTM defines the authoritative grid for all other bands.
            transform = dtm_src.transform
            height = dtm_src.height
            width = dtm_src.width

        # --- Stage 1: CHM ------------------------------------------------------
        if advance():
            return {}
        feedback.pushInfo("Creating CHM...")
        self._run_pdal_pipeline(CHM_PIPELINE, input_laz, chm_outfile, "CHM", feedback, creationflags)
        chm_array = self._read_band_on_grid(chm_outfile, "CHM", transform, height, width, feedback)

        # --- Stage 2: MRM (DTM minus its median-smoothed version) --------------
        if advance():
            return {}
        feedback.pushInfo("Calculating MRM...")
        dtm_smoothed_array = median_filter(dtm_array, size=10)
        mrm_array = np.clip(dtm_array - dtm_smoothed_array, -1, 1)
        # Keep DTM gaps as NoData; otherwise they turn into -1..1 values and are
        # treated as valid relief by the normalization below.
        mrm_array[dtm_array == nodata_value] = nodata_value

        self.create_single_raster(mrm_array, transform, mrm_outfile, crs_wkt, nodata_value=nodata_value)

        # --- Stage 3: low vegetation -------------------------------------------
        if advance():
            return {}
        feedback.pushInfo("Calculating low vegetation...")
        self._run_pdal_pipeline(LOW_VEGETATION_PIPELINE, input_laz, low_vegetation_outfile,
                                "low vegetation", feedback, creationflags)

        # --- Stage 4: high vegetation ------------------------------------------
        if advance():
            return {}
        feedback.pushInfo("Calculating high vegetation...")
        self._run_pdal_pipeline(HIGH_VEGETATION_PIPELINE, input_laz, high_vegetation_outfile,
                                "high vegetation", feedback, creationflags)

        # --- Stage 5: VDI = low / high vegetation ------------------------------
        if advance():
            return {}
        feedback.pushInfo("Calculating VDI...")

        low_veg_array = self._read_band_on_grid(
            low_vegetation_outfile, "Low vegetation", transform, height, width, feedback)
        high_veg_array = self._read_band_on_grid(
            high_vegetation_outfile, "High vegetation", transform, height, width, feedback)

        # Cells without high vegetation (division by zero) get 0 here.
        with np.errstate(divide='ignore', invalid='ignore'):
            vdi_array = np.divide(
                low_veg_array.astype(np.float32),
                high_veg_array.astype(np.float32),
                out=np.zeros_like(low_veg_array, dtype=np.float32),
                where=high_veg_array != 0,
            )

        # NOTE(review): zero ratios are replaced by 0.1, presumably so they are
        # not mistaken for NoData (0 by default) — confirm intent.
        vdi_array = np.where(vdi_array == 0, 0.1, vdi_array)

        self.create_single_raster(vdi_array, transform, vdi_outfile, crs_wkt, nodata_value=nodata_value)

        # --- Stage 6: stack and normalize --------------------------------------
        if advance():
            return {}
        feedback.pushInfo("Creating normalized raster by combining the results...")

        combined_array = np.stack([dtm_array, chm_array, mrm_array, vdi_array], axis=2)
        normalized_array = self.normalize_percentile(combined_array, nodata_value=nodata_value)

        self.create_multiband_raster(
            [normalized_array[:, :, band] for band in range(normalized_array.shape[2])],
            transform,
            output_raster,
            crs_wkt,
            nodata_value=nodata_value
        )

        feedback.setCurrentStep(count_max)

        return {self.OUTPUT_NORMALIZED: output_raster}

    def createInstance(self):
        """Returns a new instance of this algorithm (required by Processing)."""
        return self.__class__()
