import time
from typing import Optional, List, Dict
import numpy as np
import pandas as pd
from landsklim.lk.landsklim_constants import DATASET_COLUMN_X, DATASET_COLUMN_Y
from landsklim.lk.logger import Log
from landsklim.lk.map_layer import RasterLayer

from qgis._core import QgsRasterLayer, QgsPointXY


class PolygonsDefinition:
    """
    Class containing data about polygons: the path of the raster describing them,
    their definition, their connectedness and their centroids.
    """

    def __init__(self, polygons_definition: Optional[List[List[int]]], polygons_connectedness: List[List[int]], polygons_centroids: Optional[np.ndarray]):
        """
        :param polygons_definition: Polygons definition
        :type polygons_definition: Optional[List[List[int]]]

        :param polygons_connectedness: Connected polygons for each polygon
        :type polygons_connectedness: List[List[int]]

        :param polygons_centroids: Centroid of each polygon
        :type polygons_centroids: Optional[np.ndarray]

        """
        # The raster path is not a constructor argument: it starts as None and is
        # provided later through set_polygons() or set_polygon_path().
        self.__polygons_path: Optional[str] = None
        self.__polygons_definition: Optional[List[List[int]]] = polygons_definition
        self.__polygons_connectedness: List[List[int]] = polygons_connectedness
        self.__polygons_centroids: Optional[np.ndarray] = polygons_centroids

    def set_polygons(self, polygons_path: Optional[str], polygons_definition: Optional[List[List[int]]], polygons_connectedness: List[List[int]], polygons_centroids: Optional[np.ndarray]):
        """
        Replace every piece of polygon data at once.

        :param polygons_path: Path of the raster with the polygons
        :type polygons_path: Optional[str]

        :param polygons_definition: Polygons definition
        :type polygons_definition: Optional[List[List[int]]]

        :param polygons_connectedness: Connected polygons for each polygon
        :type polygons_connectedness: List[List[int]]

        :param polygons_centroids: Centroid of each polygon
        :type polygons_centroids: Optional[np.ndarray]

        """
        self.__polygons_path = polygons_path
        self.__polygons_definition = polygons_definition
        self.__polygons_connectedness = polygons_connectedness
        self.__polygons_centroids = polygons_centroids

    def set_polygon_path(self, polygons_path: str):
        """
        Set only the path of the raster with the polygons, leaving the other data untouched.

        :param polygons_path: Path of the raster with the polygons
        :type polygons_path: str
        """
        self.__polygons_path = polygons_path

    def polygons_path(self) -> Optional[str]:
        """
        :returns: Polygons layer path, or None if no path has been set yet
        :rtype: Optional[str]
        """
        return self.__polygons_path

    def polygons_definition(self) -> Optional[List[List[int]]]:
        """
        :returns: List[List[int]] of shape (*polygons_count*, n) : Definition of each polygon,
            i.e. the stations included on the station's neighborhood
        :rtype: Optional[List[List[int]]]
        """
        return self.__polygons_definition

    def polygons_connectedness(self) -> List[List[int]]:
        """
        :returns: List[List[int]] of connectedness : List of each polygon's neighbor polygons
        :rtype: List[List[int]]
        """
        return self.__polygons_connectedness

    def polygons_centroids(self) -> Optional[np.ndarray]:
        """
        :returns: np.ndarray of center of gravity : List of each polygon's center of gravity
        :rtype: Optional[np.ndarray]
        """
        return self.__polygons_centroids

    def polygons_count(self) -> int:
        """
        :returns: Number of polygons, i.e. the length of the polygons definition.
            When no definition is set, 1 is returned.
        :rtype: int
        """
        return len(self.__polygons_definition) if self.__polygons_definition is not None else 1

    def get_polygon_of_points(self, dataset: pd.DataFrame) -> np.ndarray:
        """
        Associate each row of a dataset with its polygon.

        The polygon of a row is the value of band 1 of the polygons raster (see
        :meth:`polygons_path`), sampled at the row's (DATASET_COLUMN_X, DATASET_COLUMN_Y)
        coordinates. Rows containing a NaN in any column are not sampled and get NaN.

        :param dataset: Dataset in the same projection as the DEM.
        :type dataset: pd.DataFrame

        :returns: Float array of length ``len(dataset)`` holding the polygon number of each
            dataset row (NaN for rows containing NaN)
        :rtype: np.ndarray
        """
        polygons_layer: QgsRasterLayer = QgsRasterLayer(self.__polygons_path)

        # Rows with NaN (in any column) will not be interpolated: they are skipped here
        # and keep NaN as polygon. Only valid rows are sampled to speed up calculations.
        nan_mask: np.ndarray = dataset.isna().any(axis=1).values
        points: pd.DataFrame = dataset.loc[~nan_mask, [DATASET_COLUMN_X, DATASET_COLUMN_Y]]

        start = time.perf_counter()
        raster_data_provider = polygons_layer.dataProvider()
        xs = points[DATASET_COLUMN_X].values
        ys = points[DATASET_COLUMN_Y].values
        # sample() returns a tuple whose first item is the sampled value of the requested band (1)
        points_polygons = np.array([raster_data_provider.sample(QgsPointXY(x, y), 1)[0] for x, y in zip(xs, ys)])
        Log.info("[get_polygon_of_points] {0:.3f}s".format(time.perf_counter() - start))

        dataset_polygons = np.full(len(dataset), np.nan)
        dataset_polygons[~nan_mask] = points_polygons
        return dataset_polygons

    def to_json(self) -> Dict:
        """
        :returns: Shallow copy of the instance attributes. Keys are the name-mangled private
            attribute names (e.g. ``_PolygonsDefinition__polygons_path``); values are not copied.
        :rtype: Dict
        """
        state_dict: Dict = self.__dict__.copy()
        return state_dict

