# -*- coding: utf-8 -*-

"""

/***************************************************************************
 MovementAnalysis
                                 A QGIS plugin
 Toolbox for raster based movement analysis: least-cost path, cost surface, accessibility.
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-05-17
        copyright            : (C) 2024 by Zoran Čučković
        email                : cuckovic.zoran@gmail.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Zoran Čučković'
__date__ = '2024-05-17'
__copyright__ = '(C) 2024 by Zoran Čučković'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from qgis.PyQt.QtCore import QCoreApplication

from qgis.core import (QgsProcessing,
                       QgsFeatureSink,
                       QgsProcessingException,
                       QgsProcessingAlgorithm,
                       
                       QgsProcessingParameterFeatureSource,
                       QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterRasterDestination,
                        QgsProcessingParameterBoolean,
                      QgsProcessingParameterNumber,
                       QgsProcessingParameterEnum)


from osgeo import gdal
import numpy as np

IMPORT_ERROR = False
try: from skimage import graph
except ImportError : IMPORT_ERROR = True  

from .modules import Raster as rst 
from .modules  import Points as pts



    
def pix_line_length (pixels):
    """Return the length, in pixel units, of a chain of 8-connected pixels.

    Parameters:
        pixels : sequence of pixel coordinates in path order (e.g. the list
                 of index tuples returned by skimage's MCP.traceback)

    Each straight step counts 1, each diagonal step sqrt(2) and a repeated
    pixel 0. Fewer than two pixels give a length of 0.

    Raises:
        ValueError ("Broken pixel line !") if two successive pixels are not
        neighbours, in either direction along any axis.
    """
    pixels = np.asarray(pixels)

    # np.diff of an empty sequence is 1-D and would break the axis=1 sums
    if len(pixels) < 2:
        return 0.0

    # Offsets between successive pixels
    deltas = np.diff(pixels, axis=0)

    # Check the magnitude: a jump of -2 breaks the line just as +2 does
    if np.any(np.abs(deltas) > 1):
        raise ValueError("Broken pixel line !")

    # Offsets are in {-1, 0, 1} per axis, so the Euclidean step length is
    # 0 (repeated pixel), 1 (straight) or sqrt(2) (diagonal)
    return float(np.sqrt((deltas ** 2).sum(axis=1)).sum())


def get_window(p1, p2, shape, padding=0, min_ratio=None):
    """
    Return slices selecting a rectangular window of a 2D array that contains
    the segment from p1 to p2, with optional padding and aspect-ratio
    correction.

    Points are in (x, y) = (col, row) pixel order, the order find_costs()
    reverses before handing points to skimage. The returned slices are in
    NUMPY order (row_slice, col_slice), so ``array[get_window(...)]`` works
    directly.

    Parameters:
        p1, p2    : (x, y) = (col, row) pixel coordinates
        shape     : (rows, cols) of the array, used to clip the window
        padding   : extra pixels added on every side of the window
        min_ratio : if set and the window's short/long side ratio is below
                    it, the shorter side is widened to
                    ceil(longer side * min_ratio) around its centre
                    (so a value > 1 always widens it)

    Returns:
        Tuple of slices: (row_slice, col_slice)
    """
    p1 = np.asarray(p1)
    p2 = np.asarray(p2)

    # Index 0 is x (columns), index 1 is y (rows); +1 because slice stops
    # are exclusive
    min_xy = np.minimum(p1, p2) - padding
    max_xy = np.maximum(p1, p2) + 1 + padding

    # Enforce minimum aspect ratio if requested
    if min_ratio:
        extent_0 = max_xy[0] - min_xy[0]
        extent_1 = max_xy[1] - min_xy[1]

        if extent_0 == 0: extent_0 = 1
        if extent_1 == 0: extent_1 = 1

        current_ratio = min(extent_1 / extent_0, extent_0 / extent_1)

        if current_ratio < min_ratio:
            # Enlarge the shorter side around the window centre
            center = (min_xy + max_xy) // 2

            if extent_1 < extent_0:
                desired = int(np.ceil(extent_0 * min_ratio))
                half = desired // 2
                min_xy[1] = center[1] - half
                max_xy[1] = center[1] + half + (desired % 2)
            else:
                desired = int(np.ceil(extent_1 * min_ratio))
                half = desired // 2
                min_xy[0] = center[0] - half
                max_xy[0] = center[0] + half + (desired % 2)

    # Clip to array bounds: shape is (rows, cols), so reverse it to match
    # the (x, y) order of the coordinates
    upper = [shape[1], shape[0]]
    min_xy = np.clip(min_xy, [0, 0], upper)
    max_xy = np.clip(max_xy, [0, 0], upper)

    return (slice(int(min_xy[1]), int(max_xy[1])),
            slice(int(min_xy[0]), int(max_xy[0])))

def find_costs (costs, start_point, end_point = None):
    """Accumulate least costs over a friction surface from one start cell.

    start_point and end_point are reversed (``[::-1]``) before being handed
    to skimage.graph, which indexes the array in (row, col) order. Points are
    therefore expected in (x, y) = (col, row) pixel order.

    Parameters:
        costs       : 2D friction array; cells that are not strictly positive
                      (zero, negative, NaN) are made impassable
        start_point : origin cell of the accumulation
        end_point   : optional cell whose least-cost path back to the start
                      is traced

    Returns:
        (acc, tcb, path): the accumulated-cost array and the traceback array
        from MCP.find_costs, and the result of MCP.traceback for end_point
        (None when end_point is not given).
    """
    # IMPORTANT : skimage.graph ignores np.inf cells, so every cell that
    # fails `> 0` (including NaN) is turned into np.inf
    router = graph.MCP_Geometric(np.where(costs > 0, costs, np.inf))

    # [::-1] turns (x, y) into the (row, col) index order skimage works with
    accumulated, predecessors = router.find_costs(starts=[start_point[::-1]])

    path = None if end_point is None else router.traceback(end_point[::-1])

    return accumulated, predecessors, path
 

class CostCorridor(QgsProcessingAlgorithm):
    """Least-cost corridors between departure and destination points.

    For each departure point, accumulated costs over the friction (or speed)
    surface are computed from the departure and from each of its network
    destinations. Their sum is thresholded at the MARGIN percentile, sampled
    in a window around the pair. The retained cells are cumulated into the
    output raster by addition, minimum or maximum.
    """

    # Constants used to refer to parameters and outputs. They will be
    # used when calling the algorithm from another algorithm, or when
    # calling from the QGIS console.
    FRICTION_SURF = 'FRICTION_SURF'
    DEPARTURES = 'DEPARTURES'
    DESTINATIONS = 'DESTINATIONS'
    OUTPUT = 'OUTPUT'
    RADIUS = 'RADIUS'
    MARGIN = 'MARGIN'
    RELATIVE_COST = 'RELATIVE_COST'
    COMBINE_MODE = 'COMBINE_MODE'
    SPEED = 'SPEED'

    # Parameter ideas from the original script header, not used by this class:
    #   Centripetal=boolean True
    #   Peripheric=boolean True
    #   Anisotropic=boolean False
    #   Anisotropic_DEM=raster

    # The selected index + 1 is passed to Raster.set_buffer as the cumulation
    # mode (see processAlgorithm), so the order of this list matters.
    COMBINE_MODES = ['addition', 'minimum', 'maximum']

    def initAlgorithm(self, config=None):
        """Declare the algorithm parameters.

        Raises:
            Exception if scikit-image could not be imported.
        """
        if IMPORT_ERROR:
            raise Exception(
                "Scikit Image not installed ! Cannot proceed further.")

        self.addParameter(
            QgsProcessingParameterRasterLayer(
                self.FRICTION_SURF,
                self.tr('Friction or speed surface')))

        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.DEPARTURES,
                self.tr('Departure points'),
                [QgsProcessing.TypeVectorPoint]))

        self.addParameter(
            QgsProcessingParameterFeatureSource(
                self.DESTINATIONS,
                self.tr('Destination points'),
                [QgsProcessing.TypeVectorPoint]))

        self.addParameter(QgsProcessingParameterNumber(
            self.RADIUS,
            self.tr('Maximum distance'),
            QgsProcessingParameterNumber.Double, defaultValue=10000))

        # Used as the q argument of np.percentile, which only accepts 0..100
        self.addParameter(QgsProcessingParameterNumber(
            self.MARGIN,
            self.tr('Margin (percentile)'),
            QgsProcessingParameterNumber.Double,
            defaultValue=5, minValue=0, maxValue=100))

        self.addParameter(QgsProcessingParameterBoolean(
            self.SPEED,
            self.tr('Using speed (higher values are desirable)')))

        self.addParameter(QgsProcessingParameterBoolean(
            self.RELATIVE_COST,
            self.tr('Distance-relative cost (instead of total cost)')))

        self.addParameter(QgsProcessingParameterEnum(
            self.COMBINE_MODE,
            self.tr('Cumulation mode'),
            self.COMBINE_MODES,
            defaultValue=1))

        self.addParameter(
            QgsProcessingParameterRasterDestination(
                self.OUTPUT,
                self.tr("Output file")))

    def processAlgorithm(self, parameters, context, feedback):
        """Build the cumulated least-cost corridor raster.

        Returns:
            {OUTPUT: output_path}, or an empty dict when cancelled.

        Raises:
            QgsProcessingException when no departure or destination point
            falls within the analysed area.
        """
        friction = self.parameterAsRasterLayer(parameters, self.FRICTION_SURF, context)
        departures = self.parameterAsSource(parameters, self.DEPARTURES, context)
        destinations = self.parameterAsSource(parameters, self.DESTINATIONS, context)

        # RADIUS and MARGIN are Double parameters: parameterAsInt would
        # round them (e.g. a 2.5 percentile margin)
        radius = self.parameterAsDouble(parameters, self.RADIUS, context)
        margin = self.parameterAsDouble(parameters, self.MARGIN, context)
        speed = self.parameterAsBool(parameters, self.SPEED, context)
        relative = self.parameterAsBool(parameters, self.RELATIVE_COST, context)

        # Enum index + 1 -> 1 = addition, 2 = minimum, 3 = maximum
        # (cumul mode 0 for the Raster class is "single")
        # TODO : use constants from the Raster class
        cumul_mode = self.parameterAsEnum(parameters, self.COMBINE_MODE, context) + 1

        output_path = self.parameterAsOutputLayer(parameters, self.OUTPUT, context)

        dem = rst.Raster(friction.source(), output_path)

        # if radius is larger than raster extent : adjust
        radius = min(radius, max(dem.size) * dem.pix)

        # larger window for safer calculation
        # (calculation over the entire dataset would be ideal, but too slow)
        dem.set_master_window(radius * 1.3)

        # position of the current departure inside its window
        center = np.array([dem.radius_pix, dem.radius_pix])

        dep_pts = pts.Points(departures)
        dest_pts = pts.Points(destinations)

        dep_pts.take(dem.extent, dem.pix)
        dest_pts.take(dem.extent, dem.pix)

        # Swap for efficiency (this is valid only for ISOTROPIC approach)
        if dep_pts.count > dest_pts.count:
            dep_pts, dest_pts = dest_pts, dep_pts

        if dep_pts.count == 0 or dest_pts.count == 0:
            err = "  \n ******* \n ERROR! \n No destination/departure points in the chosen area!"
            feedback.reportError(err, fatalError=True)
            raise QgsProcessingException(err)

        # Cell size is not passed to skimage (unit sampling; SciKit calls the
        # geometric correction "anisotropic"). Path lengths are scaled by the
        # pixel size only for the distance-relative option.
        dem.set_buffer(mode=cumul_mode, live_memory=True)

        dep_pts.network(dest_pts, override_radius_pix=radius / dem.pix,
                        reciprocity=True)

        # Cells outside a corridor: NaN for minimum/maximum, 0 for addition
        fill_value = np.nan if cumul_mode in (2, 3) else 0

        n_departures = dep_pts.count
        for i, ob in enumerate(dep_pts.pt.values()):

            # Processing expects a dict back, even on cancellation
            if feedback.isCanceled():
                return {}
            # count every departure, including those skipped below
            feedback.setProgress(int(i / n_departures * 100))

            if not ob['targets']:
                continue  # skip if no destinations

            pix_dep = ob['pix_coord']

            # window around the departure (radius is set on master window)
            data = dem.open_window(pix_dep)

            if speed:
                data = 1 / data  # reverse for speed calculation

            lcd, _, _ = find_costs(data, center)

            for tg in ob['targets'].values():

                if feedback.isCanceled():
                    return {}

                # reposition the destination within the departure's window
                pix_dest = center + np.subtract(tg['pix_coord'], pix_dep)

                # for distance-relative option :
                # find the path to the center point (=departure)
                # ? could use a smaller window here, but dangerous for a maze-like scenario
                lcd2, _, path2 = find_costs(
                    data, pix_dest, end_point=center if relative else None)

                temp = lcd + lcd2

                # bottom percentile works for speed also, as it was inversed to 1/speed
                # use a window between the two points (instead of the entire array).
                # get_window returns a tuple of slices: index with it directly
                # (temp[*...] is a syntax error before Python 3.11)
                window = get_window(center, pix_dest, shape=temp.shape,
                                    padding=10, min_ratio=1.5)
                mask = temp <= np.percentile(temp[window], margin)

                if relative:
                    # NOTE(review): dem.pix is used everywhere else; confirm
                    # that Raster exposes pix_x
                    length = pix_line_length(path2) * dem.pix_x

                    if length == 0:
                        # departure and destination share a pixel: the
                        # distance-relative cost would divide by zero
                        continue

                    # for speed this gives us the average speed (distance/cost, because 1/cost)
                    # otherwise usual weighting by distance
                    temp = length / temp if speed else temp / length

                temp[~mask] = fill_value

                dem.add_to_buffer(temp)

        dem.result[dem.result == 0] = np.nan

        dem.write_output(output_path)

        return {self.OUTPUT: output_path}

    def name(self):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.

        NOTE: this name does not follow the lowercase/no-space guidance. It is
        kept as-is because it forms the algorithm id that existing models and
        scripts refer to.
        """
        return 'Least cost corridor'

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return self.tr(self.name())

    def group(self):
        """
        Returns the name of the group this algorithm belongs to. This string
        should be localised.
        """
        return self.tr(self.groupId())

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to. This
        string should be fixed for the algorithm, and must not be localised.
        The group id should be unique within each provider. Group id should
        contain lowercase alphanumeric characters only and no spaces or other
        formatting characters.
        """
        return 'Movement analysis'

    def tr(self, string):
        """Translate string in the 'Processing' context."""
        return QCoreApplication.translate('Processing', string)

    def createInstance(self):
        """Return a new instance of this algorithm (required by Processing)."""
        return CostCorridor()
