# -*- coding: utf-8 -*-

"""
/***************************************************************************
 CatGrowing
                                 A QGIS plugin
 Region-growing algorithm for spatial analysis using categorical raster and vector data.
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-01-08
        copyright            : (C) 2025 Consejo Superior de Investigaciones Científicas (CSIC), Estación Experimental de Zonas Áridas (EEZA)
        email                : arranca@eeza.csic.es
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 3 of the License, or     *
 *   any later version.                                                    *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Developed by Alberto Ruiz-Rancaño and Gabriel del Barrio Escribano (EEZA - CSIC)'
__date__ = '2024-01-08'
__copyright__ = '(C) 2025 Consejo Superior de Investigaciones Científicas (CSIC), Estación Experimental de Zonas Áridas (EEZA)'

# This will get replaced with a git SHA1 when you do a git archive
__revision__ = '$Format:%H$'

# Standard and third-party imports
import os
import time
import numpy as np
from scipy.ndimage import binary_dilation
from osgeo import gdal

# QGIS imports
from qgis.PyQt.QtCore import QCoreApplication
from qgis.PyQt.QtWidgets import QMessageBox
from qgis.core import (
    Qgis,
    QgsProject,
    QgsVectorLayer,
    QgsVectorFileWriter,
    QgsProcessing,
    QgsFeatureSink,
    QgsProcessingAlgorithm,
    QgsProcessingParameterField,
    QgsProcessingParameterNumber,
    QgsProcessingParameterBoolean,
    QgsProcessingParameterFeatureSource,
    QgsProcessingParameterRasterLayer,
    QgsProcessingParameterVectorLayer,
    QgsProcessingParameterFeatureSink,
    QgsProcessingParameterRasterDestination,
    QgsProcessingException,
    QgsProcessingFeedback,
    QgsProcessingUtils
)
from qgis.utils import iface
from qgis import processing

# ------------------------------------------------------------------------------
# Utility Functions
# ------------------------------------------------------------------------------

def get_boundaries(array):
    """
    Compute the outer boundary of a binary image using morphological dilation.

    The image is dilated with a 3x3 structuring element (8-connectivity) and the
    original foreground is then removed, which leaves the one-pixel-wide ring of
    background pixels that touch the foreground.

    Parameters
    ----------
    array : numpy.ndarray
        Two-dimensional binary image. Any non-zero (or ``True``) pixel is treated
        as foreground.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array holding 1 on boundary pixels and 0 everywhere else.
    """
    foreground = np.asarray(array).astype(bool)
    struct = np.ones((3, 3), dtype=bool)
    dilated = binary_dilation(foreground, structure=struct)

    # Boolean logic is used instead of ``dilated - array``: NumPy refuses the
    # ``-`` operator between boolean arrays, and an arithmetic subtraction would
    # produce negative values for inputs holding anything larger than 1.
    return (dilated & ~foreground).astype(np.uint8)


def get_hist(array, categories):
    """
    Count how many pixels of an image belong to each of the given categories.

    Each category is matched exactly, so the result is correct even when the
    category codes are not consecutive integers (for example ``[1, 3, 7]``).
    Pixels whose value is not listed in ``categories`` (including 0 and NaN)
    are ignored.

    Parameters
    ----------
    array : numpy.ndarray
        Input image (any shape).
    categories : sequence
        Category values to count. The output follows this order.

    Returns
    -------
    numpy.ndarray
        ``float64`` vector with one frequency per category; empty when
        ``categories`` is empty.
    """
    values = np.asarray(array).ravel()

    # Equal-width histogram bins only line up with the categories when the codes
    # are consecutive, so every category is counted by exact comparison instead.
    counts = [np.count_nonzero(values == category) for category in categories]
    return np.array(counts, dtype=np.float64)


def cord_distance(f1, f2):
    """
    Compute the Orloci chord distance between two frequency vectors.

    The distance is::

        D = sqrt(2 * (1 - sum(y_i1 * y_i2) / sqrt(sum(y_i1 ** 2) * sum(y_i2 ** 2))))

    i.e. the length of the chord between both vectors once they are projected
    onto the unit sphere.

    Parameters
    ----------
    f1 : numpy.ndarray
        First frequency vector.
    f2 : numpy.ndarray
        Second frequency vector (same length as ``f1``).

    Returns
    -------
    float
        Orloci chord distance. NaN is returned when either vector has zero
        norm, because the distance is undefined in that case.
    """
    f1 = np.asarray(f1, dtype=np.float64)
    f2 = np.asarray(f2, dtype=np.float64)

    denom = np.sqrt((f1 * f1).sum() * (f2 * f2).sum())
    if not denom > 0:
        # Zero-norm (or NaN) input: report "undefined" explicitly instead of
        # relying on a 0/0 division and its runtime warning.
        return np.float64(np.nan)

    # Rounding can push the cosine marginally outside [-1, 1]; without the clamp
    # two identical profiles could yield sqrt(negative) = NaN instead of 0.
    cosine = np.clip((f1 * f2).sum() / denom, -1.0, 1.0)
    return np.sqrt(2.0 * (1.0 - cosine))


def orloci_coefficient(seed_map, land_condition_map, seed, boundaries, categories, kernel_size):
    """
    Compute the Orloci distance between a seed and every pixel on its boundary.

    Two histograms of the land condition map are compared for each boundary
    pixel: the histogram of the pixels covered by ``seed`` and the histogram of
    the pixels inside a ``kernel_size`` window placed on the boundary pixel.

    Parameters
    ----------
    seed_map : numpy.ndarray
        Current region map. Not used by the computation; kept so existing
        callers do not have to change.
    land_condition_map : numpy.ndarray
        Array holding the land condition category of each pixel.
    seed : numpy.ndarray
        Binary mask of the seed whose histogram is used as reference.
    boundaries : numpy.ndarray
        Binary array of boundary pixels to evaluate.
    categories : list
        List of land condition categories.
    kernel_size : int
        Side length, in pixels, of the window used for the local histogram.

    Returns
    -------
    numpy.ndarray
        Float array shaped like ``boundaries``. Evaluated boundary pixels hold
        ``distance + 10``; every other pixel holds 0.
    """
    orloci_mask = np.zeros_like(boundaries, dtype=float)

    # Coordinates (row, column) of the boundary pixels
    boundary_coords = np.argwhere(boundaries)
    if boundary_coords.shape[0] == 0:
        # Nothing to evaluate, so the seed histogram is not needed either
        return orloci_mask

    # Histogram of the seed area (pixels outside the seed become 0 and are
    # therefore not counted, since 0 is not a category)
    seed_histogram = get_hist(land_condition_map * seed, categories)

    # Window extent around the pixel. An odd size is centred on the pixel; an
    # even size reaches one pixel further towards the top-left than towards
    # the bottom-right.
    half_size = kernel_size // 2
    kernel_low, kernel_high = half_size, half_size + (kernel_size % 2)

    for row, col in boundary_coords:
        # Clamp the window start at 0: a negative start index would wrap around
        # to the opposite side of the raster and yield an empty patch, so pixels
        # close to the top/left edge would silently never be evaluated.
        row_start = max(row - kernel_low, 0)
        col_start = max(col - kernel_low, 0)
        local_patch = land_condition_map[row_start:row + kernel_high, col_start:col + kernel_high]

        if np.any(local_patch != 0):
            local_histogram = get_hist(local_patch, categories)
            value = cord_distance(seed_histogram, local_histogram)
            # A distance of 0 is a legitimate result, so evaluated pixels are
            # shifted by 10 to tell them apart from the untouched zeros.
            orloci_mask[row, col] = value + 10

    return orloci_mask

def grow_regions(existing_regions, boundaries, thresholded_mask):
    """
    Add the accepted boundary pixels to the current regions.

    Parameters
    ----------
    existing_regions : numpy.ndarray
        Regions grown so far.
    boundaries : numpy.ndarray
        Binary array marking the boundary pixels of those regions.
    thresholded_mask : numpy.ndarray
        Binary mask of the pixels that satisfy the Orloci criteria.

    Returns
    -------
    numpy.ndarray
        New array equal to ``existing_regions`` plus every pixel that is both a
        boundary pixel and set in ``thresholded_mask``.
    """
    # Only pixels flagged in both arrays are incorporated
    accepted_pixels = boundaries * thresholded_mask
    return existing_regions + accepted_pixels

def region_growth_algorithm(seed_map, land_condition_map, seed_id, kernel_size, threshold_value, feedback=None):
    """
    Grow one seed until no further pixel satisfies the Orloci criteria.

    On every iteration the boundary of the current region is computed, the
    Orloci distance between the original seed and the neighbourhood of each
    boundary pixel is evaluated, and the boundary pixels whose distance does not
    exceed ``threshold_value`` are added to the region. The process stops on the
    first iteration that adds no pixel.

    Parameters
    ----------
    seed_map : numpy.ndarray
        Array holding the seed identifier of each pixel.
    land_condition_map : numpy.ndarray
        Array representing the land condition map.
    seed_id : int
        Identifier of the seed to grow within the seed map.
    kernel_size : int
        Side length, in pixels, of the window used for the local histograms.
    threshold_value : float
        Largest Orloci distance for which a boundary pixel is incorporated.
    feedback : QgsProcessingFeedback, optional
        Object for reporting progress and handling user cancellation.

    Returns
    -------
    numpy.ndarray
        Integer array with the final grown region.

    Raises
    ------
    QgsProcessingException
        If the user cancels the process through ``feedback``.
    """
    # The reference histogram always comes from the original seed, not from the
    # region grown so far.
    seed_region = (seed_map == seed_id).astype(int)
    updated_regions = seed_region

    # Both the valid-land mask and the category list depend only on the land
    # condition map, so they are computed once instead of on every iteration.
    valid_land = land_condition_map > 0
    categories = np.unique(land_condition_map)
    categories = categories[categories > 0]

    while True:
        # Check for user cancellation
        if feedback is not None and feedback.isCanceled():
            raise QgsProcessingException('--- PROCESS CANCELLED BY USER ---')

        # Boundary of the current region, restricted to valid land condition pixels
        boundaries = get_boundaries(updated_regions)
        boundaries *= valid_land

        # Orloci values for the boundary pixels (evaluated pixels are offset by 10)
        orloci_values = orloci_coefficient(
            updated_regions, land_condition_map, seed_region, boundaries, categories, kernel_size
        )

        # Keep evaluated pixels whose distance is within the threshold
        thresholded_mask = (orloci_values >= 10) & (orloci_values <= 10 + threshold_value)

        # Stop as soon as an iteration adds nothing
        if not thresholded_mask.any():
            break

        updated_regions = grow_regions(updated_regions, boundaries, thresholded_mask)

    return updated_regions


def save_array_to_raster(gdal_data, filename, raster_array):
    """
    Save a numpy array as a single-band GeoTIFF aligned with an existing raster.

    Parameters
    ----------
    gdal_data : gdal.Dataset
        GDAL dataset used as template: its dimensions, band 1 data type,
        geotransform and projection are copied to the output.
    filename : str
        Path to the output raster file.
    raster_array : numpy.ndarray
        Numpy array containing the data to be saved as a raster.

    Returns
    -------
    None

    Raises
    ------
    RuntimeError
        If the output raster cannot be created.
    """
    # Template properties
    rows = gdal_data.RasterYSize
    cols = gdal_data.RasterXSize
    input_data_type = gdal_data.GetRasterBand(1).DataType

    # Create the output raster file
    out_ds = gdal.GetDriverByName("GTiff").Create(
        filename, cols, rows, 1, input_data_type
    )
    if out_ds is None:
        # Without GDAL exceptions enabled, Create() signals failure with None
        raise RuntimeError(f"Could not create output raster: {filename}")

    out_band = None
    try:
        # Georeferencing and NoData are written before the pixel data
        out_ds.SetGeoTransform(gdal_data.GetGeoTransform())
        out_ds.SetProjection(gdal_data.GetProjection())

        out_band = out_ds.GetRasterBand(1)
        out_band.SetNoDataValue(0)

        # Write the data array and make sure it reaches the file
        out_band.WriteArray(raster_array, 0, 0)
        out_band.FlushCache()
    finally:
        # Dropping the references closes the dataset, also when writing fails
        out_band = None
        out_ds = None



class CatGrowingAlgorithm(QgsProcessingAlgorithm):
    """
    Implements a growing algorithm based on a land condition map and vector seeds.
    This algorithm processes vector and raster layers to generate a rasterized output
    and a final grown region based on the Orloci criteria.

    Constants
    ---------
    INPUT_VECTOR : str
        Identifier for the input vector layer parameter.
    INPUT_RASTER : str
        Identifier for the input raster layer parameter.
    RASTERIZE_FIELD : str
        Identifier for the field used for rasterization.
    SIZE : str
        Identifier for the kernel size parameter.
    THRESHOLD : str
        Identifier for the threshold parameter.
    OUTPUT_RASTERIZE : str
        Identifier for the rasterized output parameter.
    OUTPUT : str
        Identifier for the final output parameter.
    """

    INPUT_VECTOR = 'INPUT_VECTOR'
    INPUT_RASTER = 'INPUT_RASTER'
    RASTERIZE_FIELD = 'RASTERIZE_FIELD'
    SIZE = 'SIZE'
    THRESHOLD = 'THRESHOLD'
    OUTPUT_RASTERIZE = 'OUTPUT_RASTERIZE'
    OUTPUT = 'OUTPUT'

    def initAlgorithm(self, config=None):
        """
        Define the inputs and outputs of the algorithm.

        Parameters
        ----------
        config : dict, optional
            Configuration dictionary for the algorithm (not used).
        """
        # Input vector layer (seeds)
        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.INPUT_VECTOR,
                self.tr('Input vector layer'),
                [QgsProcessing.TypeVectorAnyGeometry]
            )
        )

        # Input raster layer (categorical land condition map)
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.INPUT_RASTER,
                self.tr('Input raster layer')
            )
        )

        # Numeric field whose value is burned into the rasterized seeds
        self.addParameter(
            QgsProcessingParameterField(
                self.RASTERIZE_FIELD,
                self.tr('Rasterize field'),
                parentLayerParameterName=self.INPUT_VECTOR,
                allowMultiple=False,
                defaultValue=None,
                type=QgsProcessingParameterField.Numeric
            )
        )

        # Kernel size for the algorithm
        self.addParameter(
            QgsProcessingParameterNumber(
                self.SIZE,
                self.tr('Kernel size'),
                type=QgsProcessingParameterNumber.Integer,
                minValue=1,
                maxValue=25,
                defaultValue=3
            )
        )

        # Threshold value for the Orloci criteria
        self.addParameter(
            QgsProcessingParameterNumber(
                self.THRESHOLD,
                self.tr('Threshold: [0 - √2]'),
                type=QgsProcessingParameterNumber.Double,
                minValue=0.0,
                maxValue=np.sqrt(2),
                defaultValue=0.57317317
            )
        )

        # Rasterized output layer
        self.addParameter(
            QgsProcessingParameterRasterDestination(
                self.OUTPUT_RASTERIZE,
                self.tr('Rasterized input'),
                createByDefault=True
            )
        )

        # Final output layer
        self.addParameter(
            QgsProcessingParameterRasterDestination(
                self.OUTPUT,
                self.tr('Output'),
                createByDefault=False
            )
        )

    @staticmethod
    def _read_band_with_nan(path):
        """
        Open a raster with GDAL and read its first band, replacing NoData with NaN.

        Parameters
        ----------
        path : str
            Path (or GDAL source string) of the raster to read.

        Returns
        -------
        tuple
            ``(dataset, array)``: the open GDAL dataset and the band 1 values.
            When the band declares a NoData value the array holds NaN in its place.

        Raises
        ------
        QgsProcessingException
            If GDAL cannot open the raster.
        """
        dataset = gdal.Open(path)
        if dataset is None:
            raise QgsProcessingException(f'Could not open raster: {path}')

        band = dataset.GetRasterBand(1)
        data = np.array(band.ReadAsArray())

        nodata = band.GetNoDataValue()
        if nodata is not None:
            data = np.where(data == nodata, np.nan, data)

        return dataset, data

    def processAlgorithm(self, parameters, context, feedback):
        """
        Execute the processing algorithm.

        The seeds are rasterized onto the grid of the input raster, every seed is
        grown with ``region_growth_algorithm`` and the grown regions are summed
        into a single output raster.

        Parameters
        ----------
        parameters : dict
            Dictionary containing the input parameters.
        context : QgsProcessingContext
            Context for the processing algorithm.
        feedback : QgsProcessingFeedback
            Feedback object for progress reporting and user cancellation.

        Returns
        -------
        dict
            Dictionary containing the output file paths.

        Raises
        ------
        QgsProcessingException
            If an input cannot be read, the rasters are not aligned, or the user
            cancels the process.
        """
        # Retrieve input parameters
        source_vector = self.parameterAsSource(parameters, self.INPUT_VECTOR, context)
        if source_vector is None:
            raise QgsProcessingException(self.invalidSourceError(parameters, self.INPUT_VECTOR))

        source_raster = self.parameterAsRasterLayer(parameters, self.INPUT_RASTER, context)
        if source_raster is None:
            raise QgsProcessingException(self.invalidRasterError(parameters, self.INPUT_RASTER))

        rasterize_field = self.parameterAsString(parameters, self.RASTERIZE_FIELD, context)
        size_value = self.parameterAsInt(parameters, self.SIZE, context)
        threshold_value = self.parameterAsDouble(parameters, self.THRESHOLD, context)
        output_layer_path = self.parameterAsOutputLayer(parameters, self.OUTPUT_RASTERIZE, context)
        output_layer_path2 = self.parameterAsOutputLayer(parameters, self.OUTPUT, context)

        # Read the features once; they are needed for the ids and for rasterizing
        seed_features = list(source_vector.getFeatures())

        # Seed ids in order of appearance. Features sharing an id rasterize to
        # the same seed, so each id is kept once to avoid growing (and summing)
        # the same region several times.
        polygons_ids = []
        for feature in seed_features:
            seed_value = feature[rasterize_field]
            if seed_value not in polygons_ids:
                polygons_ids.append(seed_value)

        # Create a temporary memory layer for rasterization
        # NOTE(review): the layer is always created as "Polygon" although the
        # parameter accepts any geometry type - confirm non-polygon input is intended.
        mem_layer = QgsVectorLayer(f"Polygon?crs={source_vector.sourceCrs().authid()}", "temp_layer", "memory")
        mem_layer_data = mem_layer.dataProvider()
        mem_layer_data.addAttributes(source_vector.fields())
        mem_layer.updateFields()
        mem_layer_data.addFeatures(seed_features)

        # Rasterize the vector layer onto the grid of the input raster.
        # With UNITS = 1 (georeferenced units) WIDTH is the horizontal (X)
        # resolution and HEIGHT the vertical (Y) resolution.
        params = {
            'BURN': 0,
            'DATA_TYPE': 5,
            'EXTENT': source_raster.extent(),
            'EXTRA': '',
            'FIELD': rasterize_field,
            'HEIGHT': source_raster.rasterUnitsPerPixelY(),
            'INIT': None,
            'INPUT': mem_layer,
            'INVERT': False,
            'NODATA': 0,
            'OPTIONS': '',
            'OUTPUT': output_layer_path,
            'UNITS': 1,
            'USE_Z': False,
            'WIDTH': source_raster.rasterUnitsPerPixelX()
        }
        processing.run("gdal:rasterize", params, context=context, feedback=feedback)

        # Load rasterized vector and input raster as numpy arrays (NoData -> NaN)
        rasterize_vector, rasterize_vector_array = self._read_band_with_nan(output_layer_path)
        raster, raster_array = self._read_band_with_nan(source_raster.source())

        # Both arrays are combined pixel by pixel, so they must share one grid
        if rasterize_vector_array.shape != raster_array.shape:
            raise QgsProcessingException(
                f'Rasterized seeds {rasterize_vector_array.shape} and input raster '
                f'{raster_array.shape} do not have the same dimensions'
            )

        # Fall back to the burned values when the source provided no features.
        # NoData has already been turned into NaN, so both NaN and 0 are dropped.
        if not polygons_ids:
            burned_values = np.unique(rasterize_vector_array)
            usable = ~np.isnan(burned_values) & (burned_values != 0)
            polygons_ids = burned_values[usable].tolist()

        # Initialize global output array
        global_output = np.zeros_like(raster_array, dtype=np.int32)

        # Process each polygon ID
        total_polygons = len(polygons_ids)
        for i, polygon_id in enumerate(polygons_ids):
            if feedback.isCanceled():
                raise QgsProcessingException('--- PROCESS CANCELLED BY USER ---')

            feedback.setProgress(i / total_polygons * 100)
            feedback.setProgressText(f'Processing polygon ID: {polygon_id}')

            output_id = region_growth_algorithm(rasterize_vector_array, raster_array, polygon_id, size_value, threshold_value, feedback)
            # Regions are summed, so pixels reached from several seeds hold values above 1
            global_output += output_id

            # Update progress bar
            QCoreApplication.processEvents()

        feedback.setProgress(100)

        # Save the final output raster
        save_array_to_raster(raster, output_layer_path2, global_output)

        # Clean up datasets
        rasterize_vector = None
        raster = None

        # Rename the output only when it is actually scheduled for loading
        if context.willLoadLayerOnCompletion(output_layer_path2):
            layer_details = context.layerToLoadOnCompletionDetails(output_layer_path2)
            layer_details.name = f'Output - Size: {size_value} - Threshold: {threshold_value}'

        return {self.OUTPUT_RASTERIZE: output_layer_path, self.OUTPUT: output_layer_path2}

    def shortHelpString(self):
        """
        Provides a brief description of the plugin's functionality.

        Returns
        -------
        str
            A short string describing the algorithm and its purpose.
        """
        return self.tr(
            "<p>The CatGrowing plugin is a QGIS processing tool designed to perform a region-growing algorithm for spatial analysis. "
            "It starts with an initial set of seeds provided as a vector layer and analyzes their surroundings using a categorical raster layer. "
            "The algorithm identifies and expands regions by comparing the histogram of categories within the seed areas "
            "to those in their surroundings, using the Orloci Chord Distance as a dissimilarity metric. "
            "A user-defined threshold determines the inclusion of new areas into the growing regions. "
            "The plugin outputs a rasterized version of the input vector layer and a final raster representing the grown regions. "
            "This tool is particularly useful for spatial studies requiring categorical data analysis. "
            "For more details, visit: <a href='https://github.com/placeholder'>CatGrowing Documentation</a>.</p>"
            "<p><i>CatGrowing was produced in the frame of MEDCONECTA, a project supported by the Biodiversity Foundation of the Ministry "
            "for the Ecological Transition and the Demographic Challenge (MITECO) of the Government of Spain, within the framework of the Recovery, "
            "Transformation and Resilience Plan (PRTR), funded by the European Union - NextGenerationEU.</i></p>"
        )

    def name(self):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        # NOTE(review): this value contains spaces and an uppercase letter,
        # against the guidance above. It is left untouched because it is part
        # of the algorithm id that existing scripts and models may reference.
        return 'Categorical region growing'

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return self.tr(self.name())

    def group(self):
        """
        Returns the name of the group this algorithm belongs to. This string
        should be localised.
        """
        return self.tr(self.groupId())

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to. This
        string should be fixed for the algorithm, and must not be localised.
        The group id should be unique within each provider. Group id should
        contain lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return ''

    def tr(self, string):
        """
        Translate a string using the 'Processing' translation context.
        """
        return QCoreApplication.translate('Processing', string)

    def createInstance(self):
        """
        Return a new instance of this algorithm.
        """
        return CatGrowingAlgorithm()
