import os
import time
from math import ceil
from typing import List, Union, Optional, Type, Iterable

from PyQt5.QtCore import QVariant, QDate
from PyQt5.QtWidgets import QFileDialog, QApplication, QWidget
from osgeo import gdal, gdal_array
import numpy as np
import pandas as pd
from qgis._core import QgsRasterLayer, QgsFeatureRequest, QgsVectorFileWriter, QgsProject, QgsVectorLayer, QgsField, \
    QgsMessageLog, Qgis, QgsApplication
from scipy.stats import pearsonr

from landsklim.lk.cache import qgis_project_cache
from landsklim.lk.landsklim_constants import DATASET_COLUMN_X, DATASET_COLUMN_Y
from landsklim.lk.logger import Log


class LandsklimUtils:
    """
    Provides utility functions to work with raster and vector data
    """

    def __new__(cls):
        raise TypeError("LandsklimUtils can't be instantiated")

    @staticmethod
    def landsklim_version() -> str:
        """
        Read the Landsklim version from the metadata.txt file

        The file is looked up in the parent directory of this module and decoded as UTF-8
        (undecodable bytes are replaced instead of aborting the lookup).
        Keys and values are stripped of surrounding whitespace, so ``version=1.0`` and
        ``version = 1.0`` are both recognised. If several ``version`` entries exist,
        the last one wins. If there is none, ``"unknown"`` is returned.

        :returns: Landsklim version
        :rtype: str

        :raises OSError: If metadata.txt cannot be opened
        """
        metadata_path: str = os.path.join(os.path.dirname(__file__), os.path.pardir, 'metadata.txt')
        version: str = "unknown"
        with open(metadata_path, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                # partition() keeps any '=' contained in the value and never raises on lines without '='
                key, separator, value = line.partition("=")
                if separator and key.strip() == "version":
                    version = value.strip()
        return version



    @staticmethod
    def netcdf_file_to_array(source: str, band: int = 1):
        """
        Read a NetCDF file through GDAL and return its content as a numpy array

        :param source: Path of the NetCDF file, opened through the ``NETCDF:`` GDAL prefix
        :type source: str

        :param band: Not used by the current implementation: the read is done at dataset level
        :type band: int

        :returns: Dataset values converted to ``np.float64``
        :rtype: np.ndarray
        """
        dataset = gdal.Open("NETCDF:{0}".format(source))
        width, height = dataset.RasterXSize, dataset.RasterYSize
        values = np.array(dataset.ReadAsArray(0, 0, width, height), dtype=np.float64)
        # Drop the only reference to the dataset once the values are copied
        del dataset
        return values

    @staticmethod
    def gdal_to_array(gdal_dataset: gdal.Dataset, band: int = 1) -> np.array:
        """
        Extract a band of an opened GDAL dataset into a numpy array

        :param gdal_dataset: Opened GDAL dataset
        :type gdal_dataset: gdal.Dataset

        :param band: Band to extract (1-based, as expected by ``GetRasterBand``)
        :type band: int

        :returns: Band values converted to ``np.float64``
        :rtype: np.ndarray
        """
        # Honour the requested band instead of always reading band 1 (consistent with source_to_array)
        return np.array(gdal_dataset.GetRasterBand(band).ReadAsArray(), dtype=np.float64)

    @staticmethod
    def source_to_array(raster_source: str, band: int = 1) -> np.array:
        """
        Open a raster with GDAL and extract one of its bands into a numpy array

        :param raster_source: Raster source, as accepted by ``gdal.Open``
        :type raster_source: str

        :param band: Band to extract
        :type band: int

        :returns: Band values converted to ``np.float64``
        :rtype: np.ndarray
        """
        dataset = gdal.Open(raster_source)
        raster_band = dataset.GetRasterBand(band)
        values = np.array(raster_band.ReadAsArray(), dtype=np.float64)
        # Drop the local reference to the dataset once the values are copied
        del dataset
        return values

    @staticmethod
    def raster_to_array(raster_layer: QgsRasterLayer, band: int = 1) -> np.array:
        """
        Read one band of a QgsRasterLayer as a numpy array

        The layer itself is not read: its ``source()`` is handed to
        :meth:`LandsklimUtils.source_to_array`.

        :param raster_layer: Input raster
        :type raster_layer: QgsRasterLayer

        :param band: Band to extract
        :type band: int

        :returns: Raster values as numpy array
        :rtype: np.ndarray
        """
        layer_source: str = raster_layer.source()
        return LandsklimUtils.source_to_array(layer_source, band)

    @staticmethod
    def force_local_analysis_models(phase_regression: "PhaseMultipleRegression", models_definition_filename: str):
        """
        Overwrite the per-polygon models of a regression phase with parameters read from a file

        Debugging helper: it belongs to test_landsklim_interpolation, but is kept here because
        it is sometimes easier to call it from QGIS to see what happens.

        The file is read from ``../tests/sources`` relative to this module. Each tab-separated
        line holds: polygon index, N feature names, the intercept, then N coefficients.
        Lines with a single token are skipped.

        :param phase_regression: Phase whose polygon models are overwritten
        :type phase_regression: PhaseMultipleRegression

        :param models_definition_filename: Name of the definition file inside ``tests/sources``
        :type models_definition_filename: str
        """
        definitions_path: str = os.path.join(os.path.dirname(__file__), '..', 'tests', 'sources', models_definition_filename)
        with open(definitions_path, 'r') as definitions:
            for row in definitions:
                fields: List[str] = row.split('\t')
                if len(fields) <= 1:
                    continue
                polygon: int = int(fields[0])
                # Besides the polygon index and the intercept, half of the tokens are names and half are coefficients
                n_features: int = (len(fields) - 2) // 2
                intercept_index: int = 1 + n_features
                features_names: List[str] = list(fields[1:intercept_index])
                intercept: float = float(fields[intercept_index])
                coefs: List[float] = [float(raw_coef) for raw_coef in fields[intercept_index + 1:]]
                Log.info("Polygon {0} : Features Names : {1}, Intercept : {2}, Coefs : {3}".format(polygon, features_names, intercept, coefs))
                polygon_model = phase_regression.get_model()[polygon]
                polygon_dataset = phase_regression._filter_dataset(phase_regression.get_dataset(), polygon)
                polygon_model.force_model_params(features_names, intercept, np.array(coefs), polygon_dataset)

    @staticmethod
    def adjusted_r2(r2: float, n: int, p: int) -> float:
        """
        Formula for adjusted R-squared

        :param r2: Base R-squared
        :type r2: float

        :param n: Sample size
        :type n: int

        :param p: Number of explanatory variables
        :type p: int

        :returns: Adjusted R-squared
        :rtype: float
        """

        return 1 - (1 - r2) * (n) / (n - p - 1)  # LISDQS version
        # return 1 - (1 - r2) * (n - 1) / (n - p - 1)  # Official formula
        # LISDQS use (n) instead of (n - 1) as a denominator. The reason is not yet found

    @staticmethod
    def r2(y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """
        Formula for Pearson r-squared

        :param y_true: Input array
        :type y_true: (N,) array_like

        :param y_pred: Input array
        :type y_pred: (N,) array_like

        :returns: R-squared
        :rtype: float
        """
        return np.power(pearsonr(y_true, y_pred)[0], 2)
        # Formula is good but is already implemented in sklearn : r2_score(y, y_hat)
        # y_bar = np.mean(y)
        # return 1 - (np.sum(np.square(y - y_hat)) / np.sum(np.square(y - y_bar)))

    @staticmethod
    def rmse(y_true: np.ndarray, y_pred: np.ndarray, predictors: int) -> float:
        """
        Formula used in LISDQS to compute the root-mean-square error

        :param y_true: Input array
        :type y_true: (N,) array_like

        :param y_pred: Input array
        :type y_pred: (N,) array_like

        :param predictors: int
        :type predictors: Number of predictors on the model.

        :returns: RMSE
        :rtype: float
        """
        return np.sqrt(np.sum(np.power(y_true - y_pred, 2))/(len(y_true)-predictors-1))

    @staticmethod
    def unbiased_estimate_standard_deviation(array: np.ndarray) -> float:
        """
        Formula for the unbiased estimate of the standard deviation

        :param array: Input array
        :type array: (N,) array_like

        :returns: std
        :rtype: float
        """
        array_mean = array.mean()
        return np.sqrt(np.sum(np.power(array_mean-array, 2)) / (len(array) - 1))

    @staticmethod
    def rse(residuals: np.ndarray, n: int, model_features: int) -> float:
        """
        Formula for the Residual standard error.

        Measure the standard deviation of residuals on a regression model

        :param residuals: Residuals
        :type residuals: np.ndarray

        :param n: Number of samples
        :type n: int

        :param model_features: Number of features
        :type model_features: int
        """

        df = n-model_features-1  # predictors + intercept
        return np.sqrt(np.sum(np.power(residuals, 2)) / df)

    """
    @staticmethod
    def write_point_geopackage(output_shapefile: str, source: "QgsVectorLayer", data: np.ndarray, data_label: str, no_data: Union[int, float]):
        provider = "GPKG"  # source.dataProvider().storageType()
        Log.info("[write_point_shapefile]")
        if not os.path.exists(output_shapefile):
            features_id = [feature.id() for feature in source.getFeatures()]
            # create a new layer with all features
            new_layer = source.materialize(QgsFeatureRequest().setFilterFids(features_id))
            save_options = QgsVectorFileWriter.SaveVectorOptions()
            save_options.driverName = provider
            save_options.fileEncoding = source.dataProvider().encoding()
            transform_context = QgsProject.instance().transformContext()
            error = QgsVectorFileWriter.writeAsVectorFormatV3(new_layer, output_shapefile, transform_context, save_options)
            if error[0] == QgsVectorFileWriter.NoError:
                print("Save Success !")
            else:
                print("Error !")
                print(error)

        layer: QgsVectorLayer = QgsVectorLayer(output_shapefile, '')
        success = layer.dataProvider().addAttributes([QgsField(data_label, QVariant.Double, "double")])
        layer.updateFields()
        res_field_index = layer.fields().indexFromName(data_label)
        layer.startEditing()
        for i, feature in enumerate(layer.getFeatures()):  # type: int, QgsFeature
            attrs = {res_field_index: float(round(data[i], 10))}
            layer.changeAttributeValues(feature.id(), attrs)
        layer.commitChanges()
    """

    @staticmethod
    def sun_height(table: np.ndarray, day: int, month: int, year: int) -> np.ndarray:
        """
        Compute the sun elevation angle for a particular day.

        :param table: Numpy Array.
            - ``table[:, 0]`` : Latitude (degrees)
            - ``table[:, 1]`` : Longitude (not read by this implementation)
            - ``table[:, 2]`` : Hour
        :type table: np.ndarray

        :param day: Day
        :type day: int

        :param month: Month
        :type month: int

        :param year: Year
        :type year: int

        :returns: Sun heights, in radians (one value per row of ``table``)
        :rtype: np.ndarray
        """

        day_of_year: int = QDate(year, month, day).dayOfYear()
        # Declination follows a yearly cosine of amplitude 23.45 degrees
        year_angle = np.radians((360/365)*(day_of_year+10))
        declination: float = np.radians(-23.45) * np.cos(year_angle)
        # 15 degrees of hour angle per hour away from 12h
        hour_angle: np.ndarray = np.radians(15 * (table[:, 2]-12))
        latitude = np.radians(table[:, 0])
        sin_elevation = np.sin(declination)*np.sin(latitude) + np.cos(declination)*np.cos(latitude)*np.cos(hour_angle)
        return np.arcsin(sin_elevation)

    @staticmethod
    def sun_height_old(table: np.ndarray, day: int, month: int, year: int) -> np.ndarray:  # pragma: no cover
        """
        Previous implementation of the sun height computation for a particular day.

        Hourly values are computed for the 24 hours of the day and tiled
        ``table.shape[0] // 24`` times, so ``table`` is presumably made of consecutive
        blocks of 24 hourly rows (TODO confirm against callers).

        :param table: Numpy Array.
            - ``table[:, 0]`` : Latitude (degrees)
            - ``table[:, 1]`` : Longitude (degrees)
            - ``table[:, 2]`` : Hour (not read here: hours come from ``np.arange(24)``)
        :type table: np.ndarray

        :param day: Day
        :type day: int

        :param month: Month
        :type month: int

        :param year: Year
        :type year: int

        :returns: Sun heights, in radians
        :rtype: np.ndarray
        """

        """if month < 3:
            year = year - 1
            month = month + 13"""

        # Degrees -> radians and radians -> degrees conversion factors
        a4 = np.pi / 180
        a6 = 180 / np.pi

        # The 24 hours of the day
        h = np.arange(24)
        # Day number derived from the date, truncated at each step
        # (presumably a Julian-day style formula - TODO confirm the reference)
        jd = np.trunc((365.25 * year))
        jd = np.trunc(jd + (30.6001 * (month + 1)))
        jd = np.trunc(jd + day + 1720996.5 - (year // 100) + 1)
        jd = np.trunc(jd + ((year // 100) // 4))
        # Offset from day number 2415020, expressed in units of 36525 days
        to = jd - 2415020
        to = to / 36525
        # Time term at hour 0, in hours (presumably sidereal time - TODO confirm)
        so = 2400.051262 * to
        so = so + 6.6460656

        # With k = h * 100 both correction terms below evaluate to zero, so tu equals h
        k = h * 100
        tu = h + (k // 100 - h) / 0.6 + (100 * h - k) // 36
        # From here jd holds, for each hour, the time in the same 36525-day units as `to`
        jd = to + tu / 24 / 36525
        ts = so + tu * 1.0027389
        # Wrap into [0, 24)
        ts = ts % 24

        # Angles in degrees, linear in jd (presumably mean anomaly and solar longitude - TODO confirm)
        m2 = 358.47583 + 35999.04975 * jd
        ls = 279.69668 + 36000.76892 * jd
        ls = ls + (1.919 - 0.004789 * jd) * np.sin(m2 * a4) + 0.020094 * np.sin((2 * m2) * a4)
        ls_rad = ls * a4
        in_ = 23.45229 - 0.01301 * jd
        in_rad = in_ * a4
        # arctan(x / sqrt(1 - x^2)) is arcsin(x): de is first a sine, then the matching angle
        de = np.sin(in_rad) * np.sin(ls_rad)
        de = np.arctan(de / np.sqrt(1 - de * de))

        # Back to degrees
        de = de * a6

        al = (np.arctan(np.cos(in_rad) * np.tan(ls_rad))) * a6
        cos_ls = np.cos(ls_rad)
        # arctan only covers half a turn: add 180 degrees where cos(ls) is negative
        al[cos_ls < 0] = al[cos_ls < 0] + 180
        # Hours -> degrees (15 degrees per hour), minus al
        gh = ts * 15 - al
        #print("[gh = {0} * 15 - {1}] = {2}".format(ts[0], al[0], gh[0]))
        # Wrap into [0, 360)
        gh = gh % 360

        # Repeat the 24 hourly values to match the rows of the table (de is converted back to radians)
        de = np.tile(de * a4, table.shape[0] // 24)
        gh = np.tile(gh, table.shape[0] // 24)

        lat_rad = table[:, 0] * a4

        # Sine of the sun height; the longitude of each row is subtracted from gh
        hf = np.sin(de) * np.sin(lat_rad) + np.cos(de) * np.cos(lat_rad) * np.cos((gh - table[:, 1]) * a4)
        # Same arcsin-through-arctan form as above
        hs = np.arctan(hf / np.sqrt(1 - hf * hf))
        return hs

    @staticmethod
    def distance_matrix(points_1: Union[np.ndarray, pd.DataFrame], points_2: Union[np.ndarray, pd.DataFrame]) -> np.ndarray:
        """
        Compute matrix of euclidean distances between two sets of points.

        :param points_1: First set of points
        :type points_1: Union[pd.DataFrame, np.ndarray]

        :param points_2: Second set of points
        :type points_2: Union[pd.DataFrame, np.ndarray]

        :returns: Matrix of dimensions (n, m) with the euclidean distance between each set of points.
        :rtype: np.ndarray
        """

        if type(points_1) is pd.DataFrame:
            points_1 = points_1[[DATASET_COLUMN_X, DATASET_COLUMN_Y]].values

        if type(points_2) is pd.DataFrame:
            points_2 = points_2[[DATASET_COLUMN_X, DATASET_COLUMN_Y]].values

        distances_matrix: Optional[np.ndarray] = np.zeros((len(points_1), len(points_2)), dtype=np.float32)
        points_2_x, points_2_y = points_2[:, 0], points_2[:, 1]
        for i, point in enumerate(points_1):  # type: pd.Series
            point_position_x = point[0]
            point_position_y = point[1]
            diff_x = point_position_x - points_2_x
            diff_y = point_position_y - points_2_y
            station_distances: np.ndarray = np.sqrt((diff_x * diff_x) + (diff_y * diff_y))
            distances_matrix[i, :] = station_distances
        return distances_matrix

    @staticmethod
    def import_numpy():
        # os.environ['OPENBLAS_NUM_THREADS'] = '1'
        import numpy as np
        return np

    @staticmethod
    def export_csv(parent: QWidget, dataset: pd.DataFrame, path: Optional[str] = None):
        """
        Export a dataset to a CSV file

        :param parent: Parent widget of the file dialog
        :type parent: QWidget

        :param dataset: Data to export
        :type dataset: pd.DataFrame

        :param path: Destination file. When None, a save dialog asks the user for it
        :type path: Optional[str]
        """
        if path is None:
            dialog_title = QApplication.translate("Main", "Export to CSV (Comma-Separated Values)")
            file_filter = QApplication.translate("Main", "CSV (*.csv)")
            path, _ = QFileDialog.getSaveFileName(parent, dialog_title, '', file_filter)
        if len(path) == 0:
            # Empty path: nothing is written
            return
        dataset.to_csv(path, sep=",", header=True, index=True, float_format='%.5f')

    @staticmethod
    def free_path(path: str):
        """
        Remove layers opened in QGIS pointing to a requested file

        Only the part of a layer source located before the first ``|`` is compared with
        ``path``, so sources carrying a ``|`` suffix are matched on their file part.
        Nothing is done when ``path`` does not exist.

        :param path: File whose layers must be removed from the project
        :type path: str
        """
        target: str = str(path)
        if not os.path.exists(target):
            # os.path.samefile needs both files to exist: no layer can match
            return
        # Iterate over a snapshot because layers are removed from the project inside the loop
        for map_layer in list(qgis_project_cache().mapLayers().values()):
            layer_file: str = map_layer.source().split('|')[0]
            # Test the existence of the file part, not of the whole source string
            if os.path.exists(layer_file) and os.path.samefile(layer_file, target):
                qgis_project_cache().removeMapLayer(map_layer)

    @staticmethod
    def rename_attr(old_cls_name: str, new_cls_name: str, attr_name: str) -> str:
        old_cls_identifier: str = "_{0}_".format(old_cls_name)
        new_cls_identifier: str = "_{0}_".format(new_cls_name)
        new_attr_name: str = attr_name
        if attr_name.startswith(old_cls_identifier):
            new_attr_name = attr_name.replace(old_cls_identifier, new_cls_identifier)
        return new_attr_name
