import os
import numpy as np
from osgeo import gdal
from PyQt5.QtCore import QObject, pyqtSignal
import shutil

# Enable GDAL exception mode: errors reported by GDAL calls (gdal.Open,
# driver.Create, ...) are raised as Python exceptions instead of being
# reported only through return values.
gdal.UseExceptions()

class VulnerabilityMap(QObject):
    '''
    Raster computations for the vulnerability map: Negligible Risk Threshold
    (NRT), geometric risk classification and GeoTIFF/RST output helpers.
    Progress (0-100) is reported through the ``progress_updated`` signal.
    '''
    progress_updated = pyqtSignal(int)

    def __init__(self):
        super(VulnerabilityMap, self).__init__()
        # Folder holding the input data; set by set_working_directory()
        self.data_folder = None
        # Not read anywhere in this class; presumably used by the caller
        self.initial_directory = None

    def set_working_directory(self, directory):
        '''
        Set up the working directory.

        The process working directory is changed to ``directory``, so relative
        file names given to the other methods are resolved against it.

        :param directory: your local directory with all data files
        '''
        self.progress_updated.emit(0)
        self.data_folder = directory
        os.chdir(self.data_folder)

    def image_to_array(self, image):
        '''
        Read the first band of a raster into a NumPy array.

        :param image: path of a GDAL-readable raster
        :return: NumPy array with the values of band 1
        '''
        # Keep the dataset bound to a name while the band is read
        in_ds = gdal.Open(image)
        in_band = in_ds.GetRasterBand(1)
        return in_band.ReadAsArray()

    @staticmethod
    def _geometric_edges(lower, upper, n_classes):
        '''
        Boundaries of a geometric progression between two limits.

        The boundaries are ``upper * r**k`` for ``k = 0 .. n_classes`` with
        ``r = (lower / upper) ** (1 / n_classes)``, i.e. they run from ``upper``
        (k = 0) to ``lower`` (k = n_classes). Both end points are set exactly:
        ``upper * r**n_classes`` is rounded in floating point and could land
        slightly above ``lower``, excluding pixels lying exactly on the limit.

        :param lower: lower limit of the range
        :param upper: upper limit of the range
        :param n_classes: number of intervals
        :return: 1D float array of n_classes + 1 boundaries, upper first
        '''
        ratio = np.power(lower / upper, 1 / n_classes)
        edges = np.multiply(upper, np.power(ratio, np.arange(n_classes + 1)))
        edges[0] = upper
        edges[-1] = lower
        return edges

    def _emit_class_progress(self, done, total):
        '''Report progress in the 30-90 range while ``done`` of ``total`` classes are assigned.'''
        self.progress_updated.emit(30 + (60 * done) // total)

    def nrt_calculation(self, in_fn, deforestation_hrp, mask):
        '''
        NRT calculation.

        The distances of the deforested pixels inside the mask are binned with
        a bin width equal to the pixel size; the NRT is the midpoint (truncated
        to int) of the first bin where the cumulative proportion reaches 99.5 %.

        :param in_fn: map of distance from the forest edge in CAL
        :param deforestation_hrp: deforestation binary map in HRP
        :param mask: mask of the non-excluded jurisdiction (binary map)
        :return: NRT: Negligible Risk Threshold
        :raises ValueError: if no masked deforestation pixel has a non-zero
            distance, or the pixel size is smaller than one map unit
        '''
        # Convert images to NumPy arrays
        self.progress_updated.emit(10)
        distance_arr_cal = self.image_to_array(in_fn)
        self.progress_updated.emit(30)
        deforestation_hrp_arr = self.image_to_array(deforestation_hrp)
        self.progress_updated.emit(40)
        mask_arr = self.image_to_array(mask)
        self.progress_updated.emit(50)

        # Keep the distance only within deforestation pixels and the study area
        distance_arr_masked = distance_arr_cal * mask_arr * deforestation_hrp_arr
        self.progress_updated.emit(60)

        # np.histogram works on the flattened array; drop the 0 values, which
        # are the pixels removed by the masks above
        distances = distance_arr_masked.flatten()
        self.progress_updated.emit(80)
        distances = distances[distances != 0]
        if distances.size == 0:
            raise ValueError(
                "No deforestation pixel with a non-zero distance inside the mask; "
                "cannot compute the NRT")

        # Bin width = spatial resolution (pixel width, truncated to int)
        in_ds = gdal.Open(in_fn)
        bin_width = int(in_ds.GetGeoTransform()[1])
        if bin_width <= 0:
            raise ValueError(
                "Pixel size of {} is smaller than one map unit; "
                "it cannot be used as histogram bin width".format(in_fn))

        bins = np.arange(distances.min(), distances.max() + bin_width, bin_width)
        if bins.size < 2:
            # All distances are equal: np.arange returns a single edge, but
            # np.histogram needs at least two to form a bin
            bins = np.array([distances.min(), distances.min() + bin_width])
        hist, bin_edges = np.histogram(distances, bins=bins)

        # Normalize the histogram to proportions and accumulate them
        hist_normalized = hist / np.sum(hist)
        cumulative_prop = np.cumsum(hist_normalized)

        # First bin whose cumulative proportion is >= 0.995
        index_995 = np.argmax(cumulative_prop >= 0.995)

        nrt_bin_start = bin_edges[index_995]
        nrt_bin_end = bin_edges[index_995 + 1]
        self.progress_updated.emit(90)

        # NRT = midpoint of that bin, truncated to int
        NRT = int((nrt_bin_start + nrt_bin_end) / 2)
        self.progress_updated.emit(100)
        return NRT

    def geometric_classification(self, in_fn, NRT, n_classes):
        '''
        Geometric classification of the distance from the forest edge.

        Pixels at a distance >= NRT get class 1. The range between the spatial
        resolution (LL) and the NRT (UL) is split into ``n_classes`` geometric
        intervals labelled 2 (just below the NRT) to n_classes + 1 (closest to
        the forest edge). Pixels in neither range (e.g. distance 0, or values
        below LL) keep their original value.

        :param in_fn: map of distance from the forest edge
        :param NRT: Negligible Risk Threshold (converted to int)
        :param n_classes: number of classes within the NRT (converted to int)
        :return: mask_arr: classified array
        '''
        in_ds = gdal.Open(in_fn)
        in_band = in_ds.GetRasterBand(1)
        distance_arr = in_band.ReadAsArray()

        # The lower limit of the highest class = spatial resolution
        # (the minimum distance possible without being in non-forest)
        LL = int(in_ds.GetGeoTransform()[1])

        self.progress_updated.emit(10)
        # The upper limit of the lowest class = the Negligible Risk Threshold
        UL = NRT = int(NRT)
        n_classes = int(n_classes)

        # edges[i] = UL * r**i with r = (LL/UL)**(1/n_classes): from UL down to LL
        edges = self._geometric_edges(LL, UL, n_classes)

        self.progress_updated.emit(20)
        # Labels go into a copy while every test reads the untouched distances,
        # so a label written for one class can never be matched by a later one
        mask_arr = distance_arr.copy()
        # Areas beyond the NRT: class 1
        mask_arr[distance_arr >= NRT] = 1

        self.progress_updated.emit(30)
        # Class i + 2 covers [edges[i + 1], edges[i])
        for i in range(n_classes):
            upper, lower = edges[i], edges[i + 1]
            mask_arr[(upper > distance_arr) & (distance_arr >= lower)] = i + 2
            self._emit_class_progress(i + 1, n_classes)
        self.progress_updated.emit(90)
        return mask_arr

    def geometric_classification_alternative(self, in_fn, n_classes, mask, fmask):
        '''
        Geometric classification for the alternative vulnerability map.

        The empirical vulnerability map is rescaled to [1.0, 2.0] with
        ``1 + value / max`` and multiplied by the jurisdiction and forest masks.
        The [1.0, 2.0] range is split into ``n_classes`` geometric intervals
        labelled 1 (lowest values) to n_classes (highest values, including the
        maximum itself). Pixels outside the masks are 0 and stay 0.

        :param in_fn: Empirical vulnerability map [0.0, 1.0] range
        :param n_classes: number of classes (converted to int)
        :param mask: mask of the non-excluded jurisdiction (binary map)
        :param fmask: mask of the forest areas (binary map)
        :return: mask_arr: classified array
        '''
        in_ds = gdal.Open(in_fn)
        in_band = in_ds.GetRasterBand(1)
        arr = in_band.ReadAsArray()

        self.progress_updated.emit(10)
        max_value = in_band.GetMaximum()
        if max_value is None:
            # No maximum stored with the raster: compute the exact one
            max_value = in_band.ComputeRasterMinMax(False)[1]

        # Rescale the empirical vulnerability map to a [1.0-2.0] range
        arr_rescale = 1 + arr * 1 / max_value

        # Lower limit of the lowest class = 1, upper limit of the highest class = 2
        LL = 1
        UL = 2
        n_classes = int(n_classes)

        # Ascending boundaries: edges[0] = LL ... edges[n_classes] = UL
        edges = self._geometric_edges(LL, UL, n_classes)[::-1]

        self.progress_updated.emit(20)

        # Mask jurisdiction and forest area
        mask_arr = self.image_to_array(mask)
        fmask_arr = self.image_to_array(fmask)
        values = arr_rescale * mask_arr * fmask_arr

        self.progress_updated.emit(30)
        # Labels go into a copy; tests always read the unlabelled values
        classified = values.copy()
        for label in range(1, n_classes + 1):
            lower, upper = edges[label - 1], edges[label]
            if label == n_classes:
                # Closed on the right: the band maximum rescales to exactly UL
                in_class = (values >= lower) & (values <= upper)
            else:
                in_class = (values >= lower) & (values < upper)
            classified[in_class] = label
            self._emit_class_progress(label, n_classes)
        self.progress_updated.emit(90)
        return classified

    def array_to_image(self, in_fn, out_fn, data, data_type, nodata=None):
        '''
        Create image from array.

        The GDAL driver is chosen from the extension of ``out_fn``: .tif/.tiff
        use GTiff, .rst uses RST, any other extension is used (upper-cased) as
        the driver name.

        :param in_fn: datasource to copy projection and geotransform from
        :param out_fn: path to the file to create
        :param data: NumPy array containing data to write
        :param data_type: output data type (GDAL data type code passed to Driver.Create)
        :param nodata: optional NoData value
        :raises ValueError: if no GDAL driver matches the extension of out_fn
        '''
        in_ds = gdal.Open(in_fn)
        extension = os.path.splitext(out_fn)[1][1:].upper()
        driver_name = {'TIF': 'GTiff', 'TIFF': 'GTiff', 'RST': 'RST'}.get(extension, extension)
        driver = gdal.GetDriverByName(driver_name) if driver_name else None
        if driver is None:
            raise ValueError("No GDAL driver found for output file: {}".format(out_fn))

        # BIGTIFF is a GTiff creation option; other drivers do not support it
        options = ["BigTIFF=YES"] if driver.ShortName == 'GTiff' else []
        out_ds = driver.Create(out_fn, in_ds.RasterXSize, in_ds.RasterYSize, 1, data_type, options=options)
        out_ds.SetProjection(in_ds.GetProjection().encode('utf-8', 'backslashreplace').decode('utf-8'))
        out_ds.SetGeoTransform(in_ds.GetGeoTransform())
        out_band = out_ds.GetRasterBand(1)
        if nodata is not None:
            out_band.SetNoDataValue(nodata)
        out_band.WriteArray(data)
        out_band.FlushCache()
        out_ds.FlushCache()
        # Dereference explicitly so GDAL closes the file (and writes any
        # sidecar such as the RST .rdc) before the caller touches it
        out_band = None
        out_ds = None

    def replace_ref_system(self, in_fn, out_fn):
        '''
        RST raster format: correct the reference system name in the .rdc file.

        Copies the "ref. system :" line of the .rdc file next to ``in_fn`` into
        the .rdc file next to ``out_fn``. Nothing is done unless ``out_fn`` has
        an .rst extension (any case) or when the source .rdc has no such line.

        :param in_fn: datasource to copy the correct projection name from
        :param out_fn: rst raster file
        '''
        read_file_name, _ = os.path.splitext(in_fn)
        write_file_name, extension = os.path.splitext(out_fn)
        if extension.lower() != '.rst':
            return

        ref_prefix = "ref. system :"
        correct_name = None
        with open(read_file_name + '.rdc', 'r') as read_file:
            for line in read_file:
                if line.startswith(ref_prefix):
                    correct_name = line
                    break
        if correct_name is None:
            return

        target_path = write_file_name + '.rdc'
        # Build the corrected copy next to the target rather than in the
        # current working directory, then move it over the original
        temp_file_path = target_path + '.tmp'
        with open(target_path, 'r') as read_file, open(temp_file_path, 'w') as write_file:
            for line in read_file:
                write_file.write(correct_name if line.startswith(ref_prefix) else line)

        shutil.move(temp_file_path, target_path)
        self.progress_updated.emit(100)