import abc
import datetime
from typing import Dict

import numpy as np
from osgeo import gdal, gdal_array
from qgis._core import QgsRasterLayer, QgsCoordinateReferenceSystem

from landsklim.lk.landsklim_analysis import LandsklimAnalysis
from landsklim.lk.landsklim_interpolation import LandsklimInterpolation, LandsklimInterpolationType


class NetCdfExporter(metaclass=abc.ABCMeta):
    """
    Abstract interface of the exporters writing a Landsklim interpolation to a NetCDF file.

    Concrete subclasses implement :meth:`export`. :meth:`_convert_layer_to_wgs` is a shared helper reprojecting a
    raster layer to WGS84 (EPSG:4326).
    """

    def _convert_layer_to_wgs(self, raster_layer: QgsRasterLayer) -> str:
        """
        Reproject a raster layer to WGS84 (EPSG:4326) with the QGIS "gdal:warpreproject" processing algorithm.

        :param raster_layer: Layer to reproject
        :type raster_layer: QgsRasterLayer

        :returns: ``OUTPUT`` entry of the algorithm result, i.e. the reprojected raster written to a temporary output
        :rtype: str
        """
        # Local import of the QGIS processing framework, resolved only when a reprojection is requested
        import processing

        warp_parameters = {
            'INPUT': raster_layer,
            'SOURCE_CRS': raster_layer.crs(),
            'TARGET_CRS': QgsCoordinateReferenceSystem('EPSG:4326'),
            'RESAMPLING': 1,  # bilinear in the algorithm's resampling list
            'NODATA': None,
            'DATA_TYPE': 0,  # keep the data type of the input layer
            'OPTIONS': '',
            'TARGET_RESOLUTION': None,
            'TARGET_EXTENT': None,
            'TARGET_EXTENT_CRS': None,
            'MULTITHREADING': False,
            'EXTRA': '',
            'OUTPUT': 'TEMPORARY_OUTPUT',
        }
        warp_result = processing.run("gdal:warpreproject", warp_parameters)
        return warp_result['OUTPUT']

    @abc.abstractmethod
    def export(self, interpolation: LandsklimInterpolation,
               interpolation_type: LandsklimInterpolationType, file: str, unit: str,
               long_name: str, standard_name: str, dates: Dict[int, datetime.datetime]):
        """
        Write the layers of an interpolation to a NetCDF file, one time step per situation.

        :param interpolation: Interpolation to convert
        :type interpolation: LandsklimInterpolation

        :param interpolation_type: Type of the interpolation to convert
        :type interpolation_type: LandsklimInterpolationType

        :param file: Write NetCDF file to this path
        :type file: str

        :param unit: Unit of the interpolated variable
        :type unit: str

        :param long_name: Full human-readable name of the interpolated variable
        :type long_name: str

        :param standard_name: CF Standard Name of the interpolated variable
        :type standard_name: str

        :param dates: Date of the interpolation for each situation, keyed by situation
        :type dates: Dict[int, datetime.datetime]

        :raises NotImplementedError: Always, subclasses must override this method
        """
        raise NotImplementedError()


class NetCdfExporterGdal(NetCdfExporter):  # pragma: no cover
    """
    NetCDF exporter relying only on GDAL's netCDF driver, so no additional Python package is required.
    """

    def export(self, interpolation: LandsklimInterpolation,
               interpolation_type: LandsklimInterpolationType, file: str, unit: str,
               long_name: str, standard_name: str, dates: Dict[int, datetime.datetime]):
        """
        Export every situation layer of an interpolation as the time steps of a NetCDF file, using GDAL only.

        Each layer is reprojected to WGS84 and stacked as one band of an in-memory dataset. netCDF driver metadata
        (``NETCDF_DIM_*``) turns the bands into a ``time`` dimension expressed in days since 1858-11-17, then the
        dataset is copied to ``file`` with DEFLATE compression.
        See https://gdal.org/drivers/raster/netcdf.html

        :param interpolation: Interpolation to convert
        :type interpolation: LandsklimInterpolation

        :param interpolation_type: Type of the interpolation to convert
        :type interpolation_type: LandsklimInterpolationType

        :param file: Write NetCDF file to this path
        :type file: str

        :param unit: Unit of the interpolated variable
        :type unit: str

        :param long_name: Full human-readable name of the interpolated variable
        :type long_name: str

        :param standard_name: CF Standard Name of the interpolated variable, also used as the NetCDF variable name
        :type standard_name: str

        :param dates: Date of each situation, keyed like the layers returned by ``interpolation.get_layers``
        :type dates: Dict[int, datetime.datetime]

        :raises ValueError: If the interpolation has no layer for ``interpolation_type``
        :raises RuntimeError: If GDAL fails to create the NetCDF file
        """
        layers: Dict[int, "MapLayer"] = interpolation.get_layers(interpolation_type)
        if not layers:
            raise ValueError("No layer to export")

        basedate = datetime.datetime(1858, 11, 17, 0, 0, 0)
        # Size the stack on the exported layers so the band count always matches the number of time values
        situations = len(layers)
        array = None
        geotransform = None
        times = []

        for i, (situation, layer) in enumerate(layers.items()):
            # Days since basedate, matching the 'time#units' metadata written below
            times.append((dates[situation] - basedate).total_seconds() / 86400.)
            tmn = gdal.Open(self._convert_layer_to_wgs(layer.qgis_layer()))
            a = tmn.ReadAsArray()
            if array is None:
                # The first reprojected layer defines the output grid
                array = np.zeros((situations, a.shape[0], a.shape[1]))
                geotransform = tmn.GetGeoTransform()
            array[i, :, :] = a

        src_ds = gdal_array.OpenArray(array)
        for i, layer in enumerate(layers.values()):
            band = src_ds.GetRasterBand(i + 1)
            band.SetNoDataValue(layer.no_data())
            band.SetMetadata({'NETCDF_VARNAME': standard_name, 'long_name': long_name,
                              'standard_name': standard_name, 'units': unit})

        # Declare an extra 'time' dimension holding one value per band, stored as NC_DOUBLE (netCDF type code 6)
        str_times = ",".join(map(str, times))
        src_ds.SetMetadataItem('NETCDF_DIM_EXTRA', '{time}')
        src_ds.SetMetadataItem('NETCDF_DIM_time_DEF', f"{{{len(times)},6}}")
        src_ds.SetMetadataItem('NETCDF_DIM_time_VALUES', f"{{{str_times}}}")
        src_ds.SetMetadataItem('time#units', "days since 1858-11-17 00:00:00")
        src_ds.SetMetadataItem('time#standard_name', "time")
        src_ds.SetMetadataItem('time#axis', 'T')
        src_ds.SetMetadataItem('crs#long_name', 'Lon/Lat Coords in WGS84')
        src_ds.SetMetadataItem('crs#grid_mapping_name', 'latitude_longitude')
        src_ds.SetMetadataItem('crs#longitude_of_prime_meridian', '0.0')
        src_ds.SetMetadataItem('crs#semi_major_axis', '6378137.0')
        src_ds.SetMetadataItem('crs#inverse_flattening', '298.257223563')
        src_ds.SetGeoTransform(geotransform)

        dst_ds = gdal.GetDriverByName('netCDF').CreateCopy(file, src_ds, options=['COMPRESS=DEFLATE'])
        if dst_ds is None:
            raise RuntimeError(f"Unable to write NetCDF file {file}")
        # Dropping the last references flushes and closes the GDAL datasets
        dst_ds = None
        src_ds = None

class NetCdfExporterNetCdf4(NetCdfExporter):
    """
    NetCDF exporter writing a CF-1.6 file with the ``netCDF4`` Python package.
    """

    def export(self, interpolation: LandsklimInterpolation,
               interpolation_type: LandsklimInterpolationType, file: str, unit: str, long_name: str,
               standard_name: str, dates: Dict[int, datetime.datetime]):
        """
        Export every situation layer of an interpolation as one time step of a CF-1.6 NetCDF file.

        Every layer is reprojected to WGS84 and the first one defines the output lat/lon grid. The interpolated
        variable is named after ``standard_name``, stored as float64 with zlib compression, and times are written in
        hours since 1858-11-17. Adapted from https://gis.stackexchange.com/a/70487

        :param interpolation: Interpolation to convert
        :type interpolation: LandsklimInterpolation

        :param interpolation_type: Type of the interpolation to convert
        :type interpolation_type: LandsklimInterpolationType

        :param file: Write NetCDF file to this path (replaced if it already exists)
        :type file: str

        :param unit: Unit of the interpolated variable
        :type unit: str

        :param long_name: Full human-readable name of the interpolated variable
        :type long_name: str

        :param standard_name: CF Standard Name of the interpolated variable, also used as the NetCDF variable name
        :type standard_name: str

        :param dates: Date of each situation, keyed like the layers returned by ``interpolation.get_layers``
        :type dates: Dict[int, datetime.datetime]

        :raises ValueError: If the interpolation has no layer for ``interpolation_type``
        """
        import netCDF4
        layers: Dict[int, "MapLayer"] = interpolation.get_layers(interpolation_type)
        if not layers:
            raise ValueError("No layer to export")

        # Reproject every layer exactly once, up front
        reprojected_layers: Dict[int, str] = {situation: self._convert_layer_to_wgs(layer.qgis_layer())
                                              for situation, layer in layers.items()}
        no_data = next(iter(layers.values())).no_data()

        # The first reprojected layer defines the output grid
        ref_ds = gdal.Open(next(iter(reprojected_layers.values())))
        nlat, nlon = ref_ds.RasterYSize, ref_ds.RasterXSize
        b = ref_ds.GetGeoTransform()
        ref_ds = None
        # GDAL's geotransform origin is the outer corner of the upper-left pixel, whereas CF coordinate values
        # locate cell centres: offset by half a pixel
        lon = (np.arange(nlon) + 0.5) * b[1] + b[0]
        lat = (np.arange(nlat) + 0.5) * b[5] + b[3]

        basedate = datetime.datetime(1858, 11, 17, 0, 0, 0)
        # A chunk may not be larger than a fixed-size dimension, so clamp to the raster size
        chunk_lon = min(16, nlon)
        chunk_lat = min(16, nlat)
        chunk_time = 12

        minimum_value = interpolation.get_minimum_value()
        maximum_value = interpolation.get_maximum_value()

        # The context manager closes the file even if writing fails
        with netCDF4.Dataset(file, 'w', clobber=True) as nco:
            nco.createDimension('lon', nlon)
            nco.createDimension('lat', nlat)
            nco.createDimension('time', None)  # unlimited
            # float64 keeps sub-hour precision (float32 only resolves ~0.125 h at current dates)
            timeo = nco.createVariable('time', 'f8', ('time',))
            timeo.units = 'hours since 1858-11-17 00:00:00'
            timeo.standard_name = 'time'
            timeo.calendar = 'proleptic_gregorian'

            lono = nco.createVariable('lon', 'f4', ('lon',))
            lono.units = 'degrees_east'
            lono.standard_name = 'longitude'

            lato = nco.createVariable('lat', 'f4', ('lat',))
            lato.units = 'degrees_north'
            lato.standard_name = 'latitude'

            crso = nco.createVariable('crs', 'i4')
            crso.long_name = 'Lon/Lat Coords in WGS84'
            crso.grid_mapping_name = 'latitude_longitude'
            crso.longitude_of_prime_meridian = 0.0
            crso.semi_major_axis = 6378137.0
            crso.inverse_flattening = 298.257223563

            tmno = nco.createVariable(standard_name, 'f8', ('time', 'lat', 'lon'),
                                      zlib=True, chunksizes=[chunk_time, chunk_lat, chunk_lon],
                                      fill_value=no_data)
            tmno.units = unit
            tmno.scale_factor = 1.00
            tmno.add_offset = 0.00
            tmno.long_name = long_name
            tmno.standard_name = standard_name
            tmno.grid_mapping = 'crs'
            # valid_min / valid_max are cast to float so they share the variable's floating-point type
            if minimum_value is not None:
                tmno.valid_min = float(minimum_value)
            if maximum_value is not None:
                tmno.valid_max = float(maximum_value)
            tmno.set_auto_maskandscale(False)

            nco.Conventions = 'CF-1.6'
            nco.source = "Landsklim"

            lono[:] = lon
            lato[:] = lat

            for itime, (situation, reprojected_layer) in enumerate(reprojected_layers.items()):
                # Hours since basedate, matching timeo.units
                timeo[itime] = (dates[situation] - basedate).total_seconds() / 3600.
                tmno[itime, :, :] = gdal.Open(reprojected_layer).ReadAsArray()

class NetCdfExporterScipy(NetCdfExporter):  # pragma: no cover
    """
    NetCDF exporter writing a classic-format CF-1.6 file with ``scipy.io.netcdf_file``.
    """

    def export(self, interpolation: LandsklimInterpolation,
               interpolation_type: LandsklimInterpolationType, file: str, unit: str,
               long_name: str, standard_name: str, dates: Dict[int, datetime.datetime]):
        """
        Export every situation layer of an interpolation as one time step of a NetCDF file, using scipy.

        All layers are reprojected to WGS84 and stacked in memory; the first one defines the output lat/lon grid and
        the fill value. Times are written in hours since 1858-11-17.

        :param interpolation: Interpolation to convert
        :type interpolation: LandsklimInterpolation

        :param interpolation_type: Type of the interpolation to convert
        :type interpolation_type: LandsklimInterpolationType

        :param file: Write NetCDF file to this path
        :type file: str

        :param unit: Unit of the interpolated variable
        :type unit: str

        :param long_name: Full human-readable name of the interpolated variable
        :type long_name: str

        :param standard_name: CF Standard Name of the interpolated variable, also used as the NetCDF variable name
        :type standard_name: str

        :param dates: Date of each situation, keyed like the layers returned by ``interpolation.get_layers``
        :type dates: Dict[int, datetime.datetime]

        :raises ValueError: If the interpolation has no layer for ``interpolation_type``
        """
        from scipy.io import netcdf_file

        layers: Dict[int, "MapLayer"] = interpolation.get_layers(interpolation_type)
        if not layers:
            raise ValueError("No layer to export")

        basedate = datetime.datetime(1858, 11, 17, 0, 0, 0)
        # Size the stack on the exported layers so the data and the time axis always have the same length
        situations = len(layers)
        times = []
        array = None
        geotransform = None
        no_data = None

        for i, (situation, layer) in enumerate(layers.items()):
            # Hours since basedate, matching the time units written below
            times.append((date_offset := dates[situation] - basedate).total_seconds() / 3600.)
            tmn = gdal.Open(self._convert_layer_to_wgs(layer.qgis_layer()))
            a = tmn.ReadAsArray()
            if array is None:
                # The first reprojected layer defines the output grid and the fill value
                array = np.zeros((situations, a.shape[0], a.shape[1]))
                geotransform = tmn.GetGeoTransform()
                no_data = layer.no_data()
            array[i, :, :] = a

        nlat, nlon = array.shape[1], array.shape[2]
        # GDAL's geotransform origin is the outer corner of the upper-left pixel, whereas CF coordinate values
        # locate cell centres: offset by half a pixel
        lon = (np.arange(nlon) + 0.5) * geotransform[1] + geotransform[0]
        lat = (np.arange(nlat) + 0.5) * geotransform[5] + geotransform[3]

        minimum_value = interpolation.get_minimum_value()
        maximum_value = interpolation.get_maximum_value()

        # scipy writes plain Python floats as 32-bit NC_FLOAT attributes: numpy float64 values are used wherever the
        # attribute must match the double-precision variable or keep full precision.
        # The context manager closes (and writes) the file even if an error occurs.
        with netcdf_file(file, 'w', maskandscale=False) as f:
            f.createDimension('time', None)
            f.createDimension('lon', nlon)
            f.createDimension('lat', nlat)
            # Double precision so fractional hours are not truncated
            vtime = f.createVariable('time', 'd', ('time',))
            vtime[:] = np.array(times)
            vtime.units = 'hours since 1858-11-17'
            vtime.standard_name = 'time'
            vtime.calendar = 'proleptic_gregorian'

            vlon = f.createVariable('lon', 'f4', ('lon',))
            vlon.units = 'degrees_east'
            vlon.standard_name = 'longitude'

            vlat = f.createVariable('lat', 'f4', ('lat',))
            vlat.units = 'degrees_north'
            vlat.standard_name = 'latitude'

            crs = f.createVariable('crs', 'i4', ())
            crs.long_name = 'Lon/Lat Coords in WGS84'
            crs.grid_mapping_name = 'latitude_longitude'
            crs.longitude_of_prime_meridian = np.float64(0.0)
            crs.semi_major_axis = np.float64(6378137.0)
            crs.inverse_flattening = np.float64(298.257223563)

            tmno = f.createVariable(standard_name, 'f8', ('time', 'lat', 'lon'))
            tmno.units = unit
            tmno.scale_factor = np.float64(1.00)
            tmno.add_offset = np.float64(0.00)
            tmno.long_name = long_name
            tmno.standard_name = standard_name
            tmno.grid_mapping = 'crs'
            if no_data is not None:
                tmno._FillValue = np.float64(no_data)
            tmno[:, :, :] = array

            if minimum_value is not None:
                tmno.valid_min = np.float64(minimum_value)
            if maximum_value is not None:
                tmno.valid_max = np.float64(maximum_value)

            f.Conventions = 'CF-1.6'
            f.source = "Landsklim"

            vlon[:] = lon
            vlat[:] = lat
