import cProfile
import datetime
import gc
import os.path
import time
from functools import partial
from math import ceil
from multiprocessing import Pool
import multiprocessing as mp
from typing import Union, Tuple, Callable, Optional
import ctypes

# import PIL.ImageCms
import pandas as pd
import pyproj
from PIL import Image
from PyQt5.QtCore import QDateTime
from osgeo import gdal
from osgeo import osr
from osgeo.osr import SpatialReference
from qgis.PyQt.QtCore import QCoreApplication, QDate
from qgis._core import QgsCoordinateTransform, QgsCoordinateReferenceSystem, QgsPointXY, QgsRectangle, \
    QgsProcessingParameterDateTime, QgsProcessingParameterRasterLayer
from qgis.core import QgsProcessing, QgsProcessingAlgorithm, QgsProcessingException, QgsRasterLayer
from qgis import processing
import numpy as np

from landsklim.lk import environment
from landsklim.lk.cache import qgis_project_cache
from landsklim.processing.landsklim_processing_regressor_algorithm import LandsklimProcessingRegressorAlgorithm
from landsklim.lk.utils import LandsklimUtils
from landsklim.lk.logger import Log

# Row (debug_i) and column (debug_j) of a single cell inspected when debugging the radiation model.
debug_i, debug_j = 3, 10

class RadiationProcessingAlgorithm(LandsklimProcessingRegressorAlgorithm):
    """
    Processing algorithm computing radiation from a DEM.

    Two back-ends are used by :meth:`processAlgorithm`:

    - when ``environment.TEST_MODE`` is set, :meth:`radiation` computes a daily ground energy with an
      hourly model driven by slope and orientation rasters;
    - otherwise, GRASS ``r.sun.insoltime`` is run on a smoothed DEM (see :meth:`grass_radiation`).
    """

    INPUT_SLOPE = "INPUT_SLOPE"
    INPUT_ORIENTATION = "INPUT_ORIENTATION"

    def __init__(self):
        super().__init__({'OUTPUT': 'Radiation raster'})
        self.__geotransform = None
        # Slope / orientation arrays. processAlgorithm only loads them from the input layers when they
        # are still None, so they can be provided beforehand.
        self._orientation_raster: Optional[np.ndarray] = None
        self._slope_raster: Optional[np.ndarray] = None

        # Date used by the hourly model; overwritten by processAlgorithm from the INPUT_DATE parameter.
        self.__year = 2024
        self.__month = 1
        self.__day = 1

        # Degrees -> radians and radians -> degrees conversion factors.
        self.__a4 = np.pi / 180
        self.__a6 = 180 / np.pi

    def log(self, message):
        """
        Log ``message`` with ``Log.info``, only when ``environment.TEST_MODE`` is enabled.
        """
        if environment.TEST_MODE:
            Log.info(message)

    def createInstance(self):
        return RadiationProcessingAlgorithm()

    def initAlgorithm(self, config=None):
        """
        Declare the base class parameters plus the ``INPUT_DATE`` parameter.
        """
        super().initAlgorithm(config)

        self.addParameter(
            QgsProcessingParameterDateTime(
                'INPUT_DATE',
                self.tr('Date (used for solar irradiance)')
            )
        )

    def name(self) -> str:
        """
        Unique name of the algorithm
        """
        return 'radiation'

    def displayName(self) -> str:
        """
        Displayed name of the algorithm
        """
        return self.tr('Radiation')

    def shortHelpString(self) -> str:
        return self.tr('Compute windowed radiation from DEM')

    def add_dependencies(self):
        """
        Declare the slope and orientation raster inputs.

        NOTE(review): not called in this class; presumably invoked by the base class while
        declaring parameters — confirm.
        """
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                RadiationProcessingAlgorithm.INPUT_SLOPE,
                self.tr('Slope', "LandsklimProcessingRegressorAlgorithm")
            )
        )
        self.addParameter(
            QgsProcessingParameterRasterLayer(
                RadiationProcessingAlgorithm.INPUT_ORIENTATION,
                self.tr('Orientation', "LandsklimProcessingRegressorAlgorithm")
            )
        )

    def pixel2coord(self, x, y):
        """
        Convert pixel indices to map coordinates with the stored geotransform.

        With a GDAL-style geotransform this gives the top-left corner of the pixel, not its centre.

        :param x: Column index (scalar or np.ndarray)
        :param y: Row index (scalar or np.ndarray)
        :returns: ``(xp, yp)`` coordinates in the raster CRS
        """
        xoff, a, b, yoff, d, e = self.__geotransform
        xp = a * x + b * y + xoff
        yp = d * x + e * y + yoff
        return xp, yp

    def processAlgorithm(self, parameters, context, feedback):
        """
        Called when a processing algorithm is run.

        Reads the DEM, slope, orientation, window and date parameters. It then computes the radiation
        raster, using :meth:`radiation` in test mode and :meth:`grass_radiation` otherwise, and writes
        it to ``OUTPUT``.

        :returns: ``{'OUTPUT': <path of the written radiation raster>}``
        """
        # Load input raster and its metadata
        input_raster: QgsRasterLayer = self.parameterAsRasterLayer(parameters, 'INPUT', context)
        input_slope: QgsRasterLayer = self.parameterAsRasterLayer(parameters, RadiationProcessingAlgorithm.INPUT_SLOPE, context)
        input_orientation: QgsRasterLayer = self.parameterAsRasterLayer(parameters, RadiationProcessingAlgorithm.INPUT_ORIENTATION, context)
        no_data, self.__geotransform = self.get_raster_metadata(parameters, context, input_raster)

        input_window = self.parameterAsInt(parameters, 'INPUT_WINDOW', context)

        out_srs: SpatialReference = self.get_spatial_reference(input_raster)

        # Path of the layer is given. If a temporary layer is selected, layer is created in qgis temp dir
        out_radiation = self.parameterAsOutputLayer(parameters, 'OUTPUT', context)

        date: QDateTime = self.parameterAsDateTime(parameters, 'INPUT_DATE', context)
        qdate = date.date()
        self.__year = qdate.year()
        self.__month = qdate.month()
        self.__day = qdate.day()

        np_input = LandsklimUtils.raster_to_array(input_raster)

        neighbors_count_raster: np.ndarray = self.get_neighboors_count_raster(np_input, self.create_kernel(input_window), no_data)

        # NOTE(review): this transform is not read anywhere in this class.
        self.__transform = QgsCoordinateTransform(input_raster.crs(), QgsCoordinateReferenceSystem("EPSG:4326"), qgis_project_cache())

        # Only load slope/orientation when they were not provided beforehand.
        if self._orientation_raster is None:
            self._orientation_raster = LandsklimUtils.raster_to_array(input_orientation)
            self._slope_raster = LandsklimUtils.raster_to_array(input_slope)

        output_no_data = -9999

        if environment.TEST_MODE:
            np_radiation = self.radiation(np_input, no_data, input_raster.crs().authid(), output_no_data)
            # Valid cells with fewer than 3 valid neighbours get a radiation of 0.
            np_radiation[(np_radiation != output_no_data) & (neighbors_count_raster < 3)] = 0
        else:
            smoothed_source = processing.run("landsklim:altitude",
                                             {
                                                 'INPUT': input_raster,
                                                 'INPUT_WINDOW': input_window,
                                                 'INPUT_CUSTOM_NO_DATA': None,
                                                 'OUTPUT': QgsProcessing.TEMPORARY_OUTPUT
                                             })['OUTPUT']
            np_radiation = self.grass_radiation(smoothed_source, input_orientation.source(), input_slope.source(), qdate.dayOfYear(), np_input.shape)

        self.write_raster(out_radiation, np_radiation, out_srs, self.__geotransform, output_no_data)

        return {'OUTPUT': out_radiation}

    def radiation(self, input_raster: np.ndarray, no_data: Optional[Union[int, float]], authid: str, output_no_data: Union[int, float]) -> np.ndarray:
        """
        Compute radiation raster

        :param input_raster: DEM
        :type input_raster: np.ndarray

        :param no_data: DEM's NO_DATA
        :type no_data: Optional[Union[int, float]]

        :param authid: CRS of the DEM
        :type authid: str

        :param output_no_data: NO_DATA of radiation raster
        :type output_no_data: Union[int, float]

        :returns: Radiation raster with the shape of ``input_raster``. Cells equal to ``no_data`` are set to ``output_no_data``.
        :rtype: np.ndarray
        """
        time_total = time.perf_counter()
        no_data_mask: np.ndarray = input_raster == no_data

        # January/February adjustment of the date passed to LandsklimUtils.sun_height.
        # Local copies are used so that the instance date is left untouched.
        year, month = self.__year, self.__month
        if month < 3:
            year = year - 1
            month = month + 13

        # NOTE: modifies the orientation raster in place. 365 presumably is a sentinel orientation
        # value (e.g. flat cells) — TODO confirm. Such cells are treated as oriented at 90.
        self._orientation_raster[self._orientation_raster == 365] = 90

        # Reproject the valid cells to WGS84. always_xy is left False, as with the former
        # pyproj.transform call, so the output follows the EPSG:4326 axis order (lat, lon).
        y_valid, x_valid = np.nonzero(~no_data_mask)
        x_crs, y_crs = self.pixel2coord(x_valid, y_valid)
        transformer = pyproj.Transformer.from_crs(pyproj.CRS(authid), pyproj.CRS('EPSG:4326'))
        lat_array, lon_array = transformer.transform(x_crs, y_crs)

        radiations: np.ndarray = np.ones(lat_array.size) * output_no_data

        # Valid cells are processed in chunks to bound memory: each cell expands to 24 hourly rows.
        chunk_size = 262144  # 262144 cells * 24 hours = 6291456 rows per chunk
        chunks_count = ceil(lat_array.size / chunk_size)
        width = input_raster.shape[1]

        for i in range(chunks_count):
            time_chunk = time.perf_counter()
            start_i = i * chunk_size
            end_i = (i + 1) * chunk_size  # Slicing past the end just yields a shorter last chunk
            chunk_lat = lat_array[start_i:end_i]
            chunk_lon = lon_array[start_i:end_i]
            cells_count = chunk_lat.size

            # 'table' columns are [lat, lon, hour, sun height], with 24 consecutive rows
            # (hours 0..23) per valid cell
            table = np.zeros((cells_count * 24, 4))
            table[:, 0] = np.repeat(chunk_lat, 24)
            table[:, 1] = np.repeat(chunk_lon, 24)
            table[:, 2] = np.tile(np.arange(0, 24), cells_count)
            time_hsol = time.perf_counter()
            table[:, 3] = LandsklimUtils.sun_height(table, self.__day, month, year)
            y_valid_array = np.repeat(y_valid[start_i:end_i], 24)
            x_valid_array = np.repeat(x_valid[start_i:end_i], 24)
            time_hsol = time.perf_counter() - time_hsol

            time_table = time.perf_counter()
            # 'i' is the row-major flat index of the cell. np.nonzero yields cells in row-major order,
            # so 'i' is ascending, which sun_energy relies on.
            df = pd.DataFrame(
                {
                    'i': (y_valid_array * width + x_valid_array),
                    'hs': table[:, 3],
                    'h': table[:, 2].astype(int),
                    'slope': self._slope_raster[y_valid_array, x_valid_array],
                    'orientation': self._orientation_raster[y_valid_array, x_valid_array]
                }
            )
            time_table = time.perf_counter() - time_table

            rglobs: np.ndarray = self.sun_energy(df)
            radiations[start_i:end_i] = rglobs
            time_chunk = time.perf_counter() - time_chunk

            self.log("[rad][hsol] {0:.3f}s".format(time_hsol))
            self.log("[rad][table] {0:.3f}s\n".format(time_table))
            Log.info("[rad][chunk {0}/{1}] {2:.2f}s".format(i+1, chunks_count, time_chunk))

        # Scatter the per-cell values back into the raster (same row-major order as np.nonzero).
        radiation_raster: np.ndarray = np.ones(input_raster.size) * output_no_data
        radiation_raster[~no_data_mask.ravel()] = radiations
        radiation_raster = radiation_raster.reshape(input_raster.shape)
        self.log("\n[rad][total] {0:.3f}".format(time.perf_counter() - time_total))

        return radiation_raster

    def sun_energy(self, table: pd.DataFrame) -> np.ndarray:
        """
        Daily ground energy computed from the table.

        Rows must come in blocks of 24 consecutive rows (one per hour) per location, with locations
        sorted by ascending ``i``. The per-location aggregates from ``groupby('i')``, which sorts by
        ``i``, are mapped back to row positions with ``row // 24``.

        :param table: Table with hourly data about sun

            - ``table.i`` : Position index (relative to the DEM)
            - ``table.h`` : Hour
            - ``table.hs`` : Sun height at this time (handled as radians: converted to degrees with 180/pi)
            - ``table.slope`` : Slope (degrees: converted to radians with pi/180)
            - ``table.orientation`` : Orientation

        :type table: pd.DataFrame

        :returns: Daily ground energy, one value per location, ordered by ascending ``i``
        :rtype: np.ndarray of size ``len(table)/24``
        """
        # TODO: when every "hs" value of a location is negative, filtering must not drop all rows of that
        # location, because one output value per location is required. Currently those rows produce NaN
        # energies, which are replaced by 0 below. These useless computations could be avoided.

        table_h, table_hs = table['h'].values, table['hs'].values

        # Sun height in degrees
        table_he = table_hs * self.__a6

        time_ph = time.perf_counter()
        # Sunrise hour: for each location, the last hour before noon where the sun is below the horizon
        sun_is_below = ((table_h < 12) & (table_hs < 0))
        ph = table[sun_is_below].groupby('i').agg({'h': 'max'})
        ph = ph.rename(columns={'h': 'ph'})
        self.log("[rad][ph] {0:.3f}s".format(time.perf_counter() - time_ph))

        time_dh = time.perf_counter()
        # Sunset hour: for each location, the last hour from noon on where the sun is above the horizon
        sun_is_above = (table_h >= 12) & (table_hs > 0)
        dh = table[sun_is_above].groupby('i').agg({'h': 'max'})
        dh = dh.rename(columns={'h': 'dh'})
        self.log("[rad][dh] {0:.3f}s".format(time.perf_counter() - time_dh))

        time_merge = time.perf_counter()
        # Broadcast 'ph + 1' to the 24 rows of every location that has a sunrise hour; 0 elsewhere
        # (if the sun never rises or sets, the aggregation has no entry, so defaults are used).
        i_sun_is_below = np.unique(np.nonzero(sun_is_below)[0] // 24) * 24
        i_sun_is_below_hours = (i_sun_is_below.reshape(-1, 1) + np.arange(24)).ravel()
        table_ph = np.zeros(len(table), dtype='int8')
        table_ph[i_sun_is_below_hours] = np.repeat(ph.values + 1, 24)
        self.log("[rad][merge_1] {0:.3f}s".format(time.perf_counter() - time_merge))

        time_merge = time.perf_counter()
        # Broadcast 'dh' to the 24 rows of every location that has a sunset hour; 23 elsewhere.
        # The final +1 turns it into an exclusive upper bound.
        i_sun_is_above = np.unique(np.nonzero(sun_is_above)[0] // 24) * 24
        i_sun_is_above_hours = (i_sun_is_above.reshape(-1, 1) + np.arange(24)).ravel()
        dh_by_hour = np.ones(len(table), dtype='int8') * 23
        dh_by_hour[i_sun_is_above_hours] = np.repeat(dh.values, 24)
        table_dh = dh_by_hour + 1
        self.log("[rad][merge_2] {0:.3f}s".format(time.perf_counter() - time_merge))

        time_filtering = time.perf_counter()
        # Keep hours in [ph, dh). Since ph <= 12 and dh >= 13, hour 12 always survives, so every
        # location keeps at least one row.
        mask = (table_ph <= table_h) & (table_h < table_dh)
        table = table.loc[mask].copy()
        self.log("[rad][filtering] {0:.3f}s".format(time.perf_counter() - time_filtering))

        time_calculs = time.perf_counter()
        hs = table["hs"].values
        he = table_he[mask]
        ty = table["orientation"].values
        pnt = table["slope"].values
        pnt_rad = pnt * self.__a4
        h = table["h"].values
        a5 = ((1000 * .2) - 70) * np.sin(np.power(hs, .4))
        rd = (1300 - 30 * 3) * np.exp(-3 / (12.6 * np.sin((he + 2) * self.__a4)))
        gh = a5 + (rd * np.sin(hs))
        cos_pnt = np.cos(pnt_rad)
        d2 = .5 * (a5 * (1 + cos_pnt) + .2 * gh * (1 - cos_pnt))
        # Angular difference between the orientation and the hour angle (15 degrees per hour), folded into [0, 180]
        ax = ty - (h * 15)
        az = np.abs(ax.astype(int))
        az = np.where(az > 180, 360 - az, az)
        d3 = .5 * (a5 + .2 * gh)
        cc = np.abs((rd * np.sin(((pnt * np.cos(ax * self.__a4)) + he) * self.__a4)).astype(int))
        rglobs = (d2 - (d3 / 4.0) * (az / 180.0)) + cc

        # Negative sun heights make np.power(hs, .4) NaN: count them as no energy.
        rglobs[np.isnan(rglobs)] = 0
        table["rglobs"] = rglobs
        rglobs = table.groupby('i').agg({'rglobs': 'sum'})['rglobs'].values / 100

        self.log("[rad][calculs] {0:.3f}s".format(time.perf_counter() - time_calculs))

        return rglobs

    def grass_radiation(self, altitude_raster: str, orientation_raster: str, slope_raster: str, day_of_year: int, output_shape: Tuple) -> np.ndarray:
        """
        Compute the global irradiance with GRASS ``r.sun.insoltime``.

        :param altitude_raster: DEM given to ``r.sun.insoltime`` as ``elevation``
        :param orientation_raster: Source of the orientation raster, given as ``aspect``
        :param slope_raster: Source of the slope raster, given as ``slope``
        :param day_of_year: Day of the year
        :param output_shape: Shape ``(rows, columns)`` of the returned array
        :returns: ``glob_rad`` output resampled (nearest neighbour) to ``output_shape``, with NaN replaced by -9999
        :rtype: np.ndarray
        """
        output = processing.run("grass7:r.sun.insoltime", {
            'elevation': altitude_raster,
            'aspect': orientation_raster,
            'aspect_value': 270,
            'slope': slope_raster,
            'slope_value': 0, 'linke': None, 'albedo': None, 'albedo_value': 0.2, 'lat': None, 'long': None,
            'coeff_bh': None, 'coeff_dh': None, 'horizon_basemap': None, 'horizon_step': None, 'day': day_of_year, 'step': 0.5,
            'declination': None, 'distance_step': 1, 'npartitions': 1, 'civil_time': None, '-p': True, '-m': False,
            'insol_time': QgsProcessing.TEMPORARY_OUTPUT,
            'beam_rad': QgsProcessing.TEMPORARY_OUTPUT,
            'diff_rad': QgsProcessing.TEMPORARY_OUTPUT,
            'glob_rad': QgsProcessing.TEMPORARY_OUTPUT,
            'GRASS_REGION_PARAMETER': None,
            'GRASS_REGION_CELLSIZE_PARAMETER': 0,
            'GRASS_RASTER_FORMAT_OPT': '',
            'GRASS_RASTER_FORMAT_META': ''
        })

        global_irradiance: QgsRasterLayer = output['glob_rad']
        np_irradiance = LandsklimUtils.source_to_array(global_irradiance)
        np_irradiance[np.isnan(np_irradiance)] = -9999

        # Pillow >= 9.1 exposes resampling filters in Image.Resampling; older versions only have Image.NEAREST.
        try:
            mode = Image.Resampling.NEAREST
        except AttributeError:
            mode = Image.NEAREST

        # r.sun.insoltime can return a raster with a different shape because GRASS seems to handle only one
        # pixel size for both horizontal and vertical axes.
        # PIL sizes are (width, height): transposing before and after the resize lets output_shape
        # (rows, columns) be passed as-is.
        np_irradiance = np.array(Image.fromarray(np_irradiance.T).resize(output_shape, mode)).T

        return np_irradiance
