# -*- coding: utf-8 -*-
"""
/***************************************************************************
 ThermalMetrics
                                 A QGIS plugin
 This plugin helps to calculate basic metrics and indices from thermal images. 
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2021-10-02
        git sha              : $Format:%H$
        copyright            : (C) 2021 by Florian Ellsäßer
        email                : info@ecothermographylab.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""
    
from qgis.PyQt.QtCore import QSettings, QTranslator, QCoreApplication
from qgis.PyQt.QtGui import QIcon
from qgis.PyQt.QtWidgets import QAction

# Initialize Qt resources from file resources.py
from .resources import *
# Import the code for the dialog
from .thermalmetrics_dialog import ThermalMetricsDialog
import os.path

# Import libraries
from qgis.PyQt.QtCore import QSettings, QTranslator, QCoreApplication
from qgis.PyQt.QtGui import QIcon
from qgis.PyQt.QtWidgets import QAction, QFileDialog
from qgis.core import QgsProject, Qgis, QgsRasterLayer

# Import Libraries for the ThermalMetrics Plugin
from osgeo import gdal
import numpy as np
import pandas as pd
import operator
import math
import itertools
from scipy import stats
from scipy.stats import kurtosis, skew
import scipy.ndimage as ndi
from scipy.ndimage.morphology import grey_erosion, grey_dilation, grey_closing, grey_opening, grey_closing



class ThermalMetrics:
    """QGIS Plugin Implementation."""

    def __init__(self, iface):
        """Constructor.

        :param iface: An interface instance that will be passed to this class
            which provides the hook by which you can manipulate the QGIS
            application at run time.
        :type iface: QgsInterface
        """
        # Save reference to the QGIS interface
        self.iface = iface
        # initialize plugin directory
        self.plugin_dir = os.path.dirname(__file__)
        # initialize locale; the setting may be missing (None), in which case
        # no language-specific translation file is looked up
        locale = (QSettings().value('locale/userLocale') or '')[0:2]
        locale_path = os.path.join(
            self.plugin_dir,
            'i18n',
            'ThermalMetrics_{}.qm'.format(locale))

        if os.path.exists(locale_path):
            self.translator = QTranslator()
            self.translator.load(locale_path)
            QCoreApplication.installTranslator(self.translator)

        # Declare instance attributes
        self.actions = []
        self.menu = self.tr(u'&ThermalMetrics')

        # Check if plugin was started the first time in current QGIS session
        # Must be set in initGui() to survive plugin reloads
        self.first_start = None

    # noinspection PyMethodMayBeStatic
    def tr(self, message):
        """Translate ``message`` in the 'ThermalMetrics' translation context.

        The plugin class does not derive from QObject, so ``QObject.tr`` is
        not available; this wrapper calls the Qt translation API directly.

        :param message: String for translation.
        :type message: str, QString

        :returns: Translated version of message.
        :rtype: QString
        """
        context = 'ThermalMetrics'
        # noinspection PyTypeChecker,PyArgumentList,PyCallByClass
        translated = QCoreApplication.translate(context, message)
        return translated


    def add_action(
        self,
        icon_path,
        text,
        callback,
        enabled_flag=True,
        add_to_menu=True,
        add_to_toolbar=True,
        status_tip=None,
        whats_this=None,
        parent=None):
        """Create a QAction and register it in the toolbar and/or plugin menu.

        :param icon_path: Path to the icon for this action. Can be a resource
            path (e.g. ':/plugins/foo/bar.png') or a normal file system path.
        :type icon_path: str

        :param text: Text that should be shown in menu items for this action.
        :type text: str

        :param callback: Function to be called when the action is triggered.
        :type callback: function

        :param enabled_flag: Whether the action starts out enabled.
            Defaults to True.
        :type enabled_flag: bool

        :param add_to_menu: Also add the action to the plugin menu
            (``self.menu``). Defaults to True.
        :type add_to_menu: bool

        :param add_to_toolbar: Also add the action to the Plugins toolbar.
            Defaults to True.
        :type add_to_toolbar: bool

        :param status_tip: Optional status tip for the action (shown in the
            status bar while the pointer hovers over it).
        :type status_tip: str

        :param whats_this: Optional "What's This?" help text for the action.

        :param parent: Parent widget for the new action. Defaults None.
        :type parent: QWidget

        :returns: The created action; it is also appended to ``self.actions``.
        :rtype: QAction
        """
        action = QAction(QIcon(icon_path), text, parent)
        action.triggered.connect(callback)
        action.setEnabled(enabled_flag)

        # optional help texts are only applied when they were supplied
        if status_tip is not None:
            action.setStatusTip(status_tip)
        if whats_this is not None:
            action.setWhatsThis(whats_this)

        # register the action in the requested places of the QGIS GUI
        if add_to_toolbar:
            self.iface.addToolBarIcon(action)
        if add_to_menu:
            self.iface.addPluginToMenu(self.menu, action)

        self.actions.append(action)
        return action

    def initGui(self):
        """Add the ThermalMetrics toolbar icon and menu entry to the QGIS GUI."""
        self.add_action(
            ':/plugins/thermalmetrics/icon.png',
            text=self.tr(u'ThermalMetrics'),
            callback=self.run,
            parent=self.iface.mainWindow())

        # reset on every (re)load of the plugin; run() (not shown here) is
        # expected to set it back to False
        self.first_start = True


    def unload(self):
        """Take every registered action out of the plugin menu and the toolbar."""
        for registered_action in self.actions:
            self.iface.removePluginMenu(self.tr(u'&ThermalMetrics'), registered_action)
            self.iface.removeToolBarIcon(registered_action)


    def select_input_file(self):
        '''Let the user pick the input GeoTIFF and load it into QGIS.

        Fills the ``input_name`` line edit, stores the full path in
        ``self.file_name`` and the bare file name in ``self.in_file_name``,
        and adds the raster (named after its full path) to the current QGIS
        project. If the dialog is cancelled nothing is changed.
        '''
        file_name, _filter = QFileDialog.getOpenFileName(
                self.dlg, 'Select input raster name','','*.tif')
        if not file_name:
            # dialog cancelled: keep the previous selection and do not add an
            # invalid, empty-path layer to the project
            return
        self.dlg.input_name.setText(file_name)
        self.file_name = file_name
        in_raster = QgsRasterLayer(file_name, file_name)
        QgsProject.instance().addMapLayer(in_raster)
        # Qt file dialogs return '/'-separated paths
        self.in_file_name = file_name.split('/')[-1]
               
    def select_output_file(self):
        '''Ask for the CSV output path and show it in the ``output_name`` line edit.

        If the dialog is cancelled the previously entered path is kept.
        '''
        filename, _filter = QFileDialog.getSaveFileName(
                self.dlg, 'Select output file name','','*.csv')
        if filename:
            self.dlg.output_name.setText(filename)
        
    def select_patch_output_file(self):
        '''Ask for the patch GeoTIFF output path and show it in ``patch_output_name``.

        If the dialog is cancelled the previously entered path is kept.
        '''
        filename, _filter = QFileDialog.getSaveFileName(
                self.dlg, 'Select output file name','','*.tif')
        if filename:
            self.dlg.patch_output_name.setText(filename)
        
     
    def read_lst_img(self,in_file, na_val=None):
        '''Read band 1 of a thermal image into ``self.dlg.lst`` as Kelvin.

        Also stores the opened GDAL dataset in ``self.new_raster``, the
        projection in ``self.dlg.prj``, the geotransform in ``self.dlg.geo``
        and its origin in ``self.dlg.lon`` / ``self.dlg.lat``.

        Processing steps:

        * integer bands are cast to float64 so they can hold NaN;
        * pixels equal to ``na_val`` (if given) become NaN;
        * if the mean of the valid pixels is above 180 the data is taken to be
          Kelvin and converted to °C;
        * pixels equal to 0 are replaced by ``self.dlg.na_val`` (NaN when it
          is None), and negative values become NaN;
        * the data is converted to Kelvin if its mean is below 180.

        :in_file : path and file name of the raster
        :na_val : optional raw pixel value that marks no-data
        :raises IOError: if GDAL cannot open ``in_file``
        '''
        self.new_raster = gdal.Open(in_file,gdal.GA_ReadOnly)
        if self.new_raster is None:
            raise IOError('Could not open raster file: ' + str(in_file))

        lst = self.new_raster.GetRasterBand(1).ReadAsArray()
        # integer arrays cannot store NaN, which all masking below relies on
        if not np.issubdtype(lst.dtype, np.floating):
            lst = lst.astype(np.float64)
        if na_val is not None:
            lst[lst == na_val] = np.nan
        self.dlg.lst = lst

        # if input is in Kelvin, convert to °C
        if  np.mean(self.dlg.lst[~np.isnan(self.dlg.lst)]) > 180.0:
            self.dlg.lst = self.dlg.lst - 273.15
        self.dlg.prj = self.new_raster.GetProjection()
        self.dlg.geo = self.new_raster.GetGeoTransform()
        self.dlg.lon = float(self.dlg.geo[0])
        self.dlg.lat = float(self.dlg.geo[3])
        # set all zeros to the configured NaN value (NaN when none is set);
        # for Kelvin input this runs after the °C conversion
        zero_fill = getattr(self.dlg, 'na_val', None)
        self.dlg.lst[self.dlg.lst == 0.0] = np.nan if zero_fill is None else zero_fill
        # negative values (e.g. 0 K no-data, now -273.15 °C) become NaN
        # NOTE(review): this also discards genuine temperatures below 0 °C -
        # confirm that is intended
        self.dlg.lst[self.dlg.lst < 0] = np.nan
        # the mean of the valid pixels tells whether the values are °C (< 180)
        lst_mean = np.mean(self.dlg.lst[~np.isnan(self.dlg.lst)])
        if lst_mean < 180:
            # convert °C to Kelvin
            self.dlg.lst = self.dlg.lst + 273.15
        # check again for negative values and set them to NaN
        self.dlg.lst[self.dlg.lst < 0] = np.nan
        
    def get_model_parameters(self,in_file):
        '''Read the thermal image and copy the dialog settings onto the plugin.

        Loads ``in_file`` with ``read_lst_img`` and then stores:

        * ``self.dlg.na_val`` - float parsed from ``nan_input``, or None
        * the ``self.dlg.calculate_*`` flags and ``self.dlg.shannonEquitabilityIndex``,
          ``self.dlg.giniSimpsonDiversityIndex``, ``self.dlg.simpsonReciprocalIndex``
          from the check boxes
        * ``self.bin_number`` / ``self.class_number`` - positive int or 'auto'
        * ``self.dilation_value`` / ``self.closing_value`` / ``self.erosion_value``
          - int, defaults 5 / 5 / 20
        * ``self.dlg.thermalmetrics_out_raster = None`` when patch richness and
          density is selected

        Empty or unparsable text fields fall back to the defaults above.
        '''
        self.dlg.na_val = None
        # NOTE(review): the image is read before ``nan_input`` is parsed, so
        # read_lst_img always runs with na_val None here - confirm intended
        self.read_lst_img(in_file)
        self.dlg.na_val = self._parse_text_input(self.dlg.nan_input, float, None)

        ## simple metrics ##
        self.dlg.calculate_imageMetrics = bool(self.dlg.imageMetrics_checkBox.isChecked())
        self.dlg.calculate_geoMetrics = bool(self.dlg.geoMetrics_checkBox.isChecked())
        self.dlg.calculate_standardMetrics = bool(self.dlg.standardMetrics_checkBox.isChecked())

        ## histogram and distribution metrics ##
        self.dlg.calculate_distributionMetrics = bool(
            self.dlg.distributionMetrics_checkBox.isChecked())
        self.bin_number = self._parse_count_input(self.dlg.bin_number_input)

        ## patch richness and density ##
        self.dlg.calculate_patchRichnessAndDensity = bool(
            self.dlg.patchRichnessAndDensity_checkBox.isChecked())
        self.class_number = self._parse_count_input(self.dlg.class_number_input)

        # morphology parameters used for the patches (arbitrary defaults)
        self.dilation_value = self._parse_text_input(self.dlg.dilation, int, 5)
        self.closing_value = self._parse_text_input(self.dlg.closing, int, 5)
        self.erosion_value = self._parse_text_input(self.dlg.erosion, int, 20)

        ## the indices checkboxes ##
        self.dlg.calculate_shannonDiversityIndex = bool(
            self.dlg.shannonDiversityIndex_checkBox.isChecked())
        # the next flags have no ``calculate_`` prefix; the names are kept as-is
        # because code outside this method may read them under these names
        self.dlg.shannonEquitabilityIndex = bool(
            self.dlg.shannonEquitabilityIndex_checkBox.isChecked())
        self.dlg.calculate_simpsonDiversityIndex = bool(
            self.dlg.simpsonDiversityIndex_checkBox.isChecked())
        self.dlg.giniSimpsonDiversityIndex = bool(
            self.dlg.giniSimpsonDiversityIndex_checkBox.isChecked())
        self.dlg.simpsonReciprocalIndex = bool(
            self.dlg.simpsonReciprocalIndex_checkBox.isChecked())

        # results -> create a result out raster only if desired
        if self.dlg.calculate_patchRichnessAndDensity:
            self.dlg.thermalmetrics_out_raster = None

    def _parse_text_input(self, widget, cast, default):
        '''Return ``cast(widget.text())``, or ``default`` if the text cannot be converted.'''
        try:
            return cast(widget.text())
        except (TypeError, ValueError):
            return default

    def _parse_count_input(self, widget):
        '''Return a positive int read from ``widget``, otherwise 'auto' (automatic histogram bins).'''
        count = self._parse_text_input(widget, int, 'auto')
        if count != 'auto' and count < 1:
            # np.histogram rejects a bin count below 1
            return 'auto'
        return count
        
    # get the NaN (no-data) value from the input field
    def get_nan_value(self):
        if self.dlg.na_val == 'None':
            self.dlg.na_val = None
        else:
            try:
                self.dlg.na_val = float(self.dlg.nan_input.text())
            except:
                self.dlg.na_val = None
                
    def get_covered_area(self):
        self.area_full_raster = None  
        self.area_only_non_nan = None  
        # get x and y resolution
        x_resolution, y_resolution = operator.itemgetter(1,5)(self.new_raster.GetGeoTransform())
        # convert resolution from 100 km to 1m and absolute values 
        x_resolution = abs(x_resolution)*100000
        y_resolution = abs(y_resolution)*100000
        # area all pixels
        self.area_full_raster = (self.dlg.lst.shape[0]*self.dlg.lst.shape[1]) * (x_resolution * y_resolution)
        # area covered by non-nan pixels
        self.area_only_non_nan = (np.count_nonzero(~np.isnan(self.dlg.lst))) * (x_resolution * y_resolution)
        
        return self.area_full_raster, self.area_only_non_nan
        # return x_resolution, y_resolution
        
    def get_distribution(self):
        
        # now create a histogram
        self.histogram = np.histogram(self.dlg.lst[~np.isnan(self.dlg.lst)], bins=self.bin_number ) #, range=None, normed=None, weights=None, density=None)
        # now make that histogram printable
        out_histogram_pixel_number = ','.join(map(str,self.histogram[0]))
        out_histogram_bins = ','.join(map(str,self.histogram[1]))
        final_bin_number = len(list(map(str,self.histogram[1])))
        # now get kurtosis and skew 
        out_skewness = skew(self.dlg.lst[~np.isnan(self.dlg.lst)], bias=False)
        out_kurtosis = kurtosis(self.dlg.lst[~np.isnan(self.dlg.lst)], bias=False)
        # calculate Fisher-Pearson coefficient of skewness
        # Fisher-Pearson Coefficient of Skewness
        out_fpcs = ((np.mean(self.dlg.lst[~np.isnan(self.dlg.lst)]) 
                        -float((" ".join(map(str,stats.mode(self.dlg.lst[~np.isnan(self.dlg.lst)])[0] ))))) 
                       / np.std(self.dlg.lst[~np.isnan(self.dlg.lst)]))
             
        return final_bin_number, out_histogram_bins, out_histogram_pixel_number, out_skewness, out_kurtosis, out_fpcs
    
    # helper functions used for the patch detection

    # square, all-True structuring element
    def get_square(self,in_value):
        """Return an ``in_value`` x ``in_value`` boolean array filled with True."""
        return np.full((in_value, in_value), True, dtype=bool)

    # create a dilation function
    def multi_dil(self,im, num):
        square = self.get_square(3)
        for i in range(num):
            #im = dilation(im, element) 
            # get the grey dilation
            im = grey_dilation(im, structure = square)
        return im

    # create an erosion function
    def multi_ero(self,im, num):
        square = self.get_square(3)
        for i in range(num):
            im = grey_erosion(im,structure = square)
        return im
    
    # classify an array according to histogram bins
    def define_class_array(self,in_array,bin_number):
        """Classify ``in_array`` into histogram bins.

        Bin borders come from ``np.histogram`` of the non-NaN values with
        ``bin_number`` bins (an int or a string such as 'auto'). A value
        strictly between border ``i`` and border ``i + 1`` gets class
        ``i + 1``. Values lying exactly on a border (including the minimum
        and maximum) and NaN values get 0.

        :returns: float64 array with the shape of ``in_array``
        """
        bin_borders = np.histogram(in_array[~np.isnan(in_array)], bins=bin_number)[1]
        # n borders delimit n - 1 equally wide bins
        bin_size = (bin_borders.max() - bin_borders.min()) / (len(bin_borders) - 1)
        # start from zeros: np.empty_like returned uninitialised memory that
        # was then summed into the result
        out_array = np.zeros(np.shape(in_array), dtype=np.float64)
        for index, lower_border in enumerate(bin_borders):
            # use the next border directly so no gap is left between classes;
            # the last border keeps its one-bin-wide class (which cannot
            # contain data) so the class labels stay as before
            if index + 1 < len(bin_borders):
                upper_border = bin_borders[index + 1]
            else:
                upper_border = lower_border + bin_size
            mask = (in_array > lower_border) & (in_array < upper_border)
            # label with index + 1 so that 0 stays "no class"
            out_array[mask] += index + 1
        return out_array

    # get the patches
    def get_patches(self,in_array, dilate, close, erode):
        square = self.get_square(3)
        # dilate
        multi_dilated = self.multi_dil(in_array, dilate)
        # close areas
        area_closed = grey_closing(multi_dilated, close)
        # erode areas
        multi_eroded = self.multi_ero(area_closed, erode)
        opened = grey_opening(multi_eroded,structure = square)
        return opened
    
    def calculate_patches(self):
        # define the classes
        out_array = self.define_class_array(self.dlg.lst,self.class_number)
        # run the patch function
        patch_raster = self.get_patches(out_array, self.dilation_value, self.closing_value, self.erosion_value)
        return patch_raster
        
    
    
   
    
    # frequency and type of patches
    def get_frequency(self,in_array):
        """Return ``(label, pixel count)`` rows for every non-zero label in ``in_array``.

        Label 0 is treated as background (the value ``ndi.label`` uses for
        unlabelled pixels) and is left out. Rows are sorted by label; the
        result has shape ``(n_labels, 2)``.
        """
        (unique, counts) = np.unique(in_array, return_counts=True)
        frequencies = np.asarray((unique, counts)).T
        # drop the background label 0 if present; blindly dropping the first
        # row removed a real patch whenever no background pixel existed
        return frequencies[unique != 0]
    
    def get_patch_mean(self,in_temp,patch,patch_id):
        """Mean of ``in_temp`` over the pixels labelled ``patch_id`` in ``patch``.

        Pixels outside the patch, pixels whose value is exactly 0 and NaN
        pixels are ignored. If no pixel remains the result is NaN (numpy emits
        a RuntimeWarning).
        """
        in_patch = np.equal(patch, patch_id)
        values = np.where(in_patch, in_temp, 0)
        # 0 doubles as the "outside the patch" marker and is excluded
        values[values == 0] = np.nan
        return np.nanmean(values)
    
    # compute histogram bin borders and the bin width
    def define_bin_borders(self,in_array,bin_number):
        """Return ``(bin_borders, bin_size)`` for the non-NaN values of ``in_array``.

        ``bin_borders`` are the edges from ``np.histogram`` with ``bin_number``
        bins (int or e.g. 'auto'). ``bin_size`` is the width of one bin,
        (max border - min border) / (number of borders - 1).
        """
        bin_borders = np.histogram(in_array[~np.isnan(in_array)], bins=bin_number)[1]
        # n borders delimit n - 1 bins; dividing by n made every class too
        # narrow and left unclassified gaps between neighbouring bins
        bin_size = (bin_borders.max() - bin_borders.min()) / (len(bin_borders) - 1)
        return bin_borders, bin_size
    
    # this function gets the mean of the patches
    def get_mean_patches(self,patches_details,in_lst,patch,patch_mean_list):
        # first wer for loop through all the patches in the patches details list
        for item in patches_details:
            patch_mean = self.get_patch_mean(in_lst,patch,item[0])
            patch_mean_list.append(patch_mean)
            
    def mask_the_array(self,in_array,lower_border,upper_border,index):
        """Return an int array holding ``index + 1`` where lower < value < upper, else 0."""
        inside = np.logical_and(in_array > lower_border, in_array < upper_border)
        labelled = inside.astype(int)
        # index + 1 so that class 0 never collides with the background
        labelled[inside] = index + 1
        return labelled
    
    # this is the main function that gets the pachtes and returns a messy file with the details from the array
    def get_patches_from_array(self,in_array,class_number):
        square = self.get_square(3)
        # call the histogram function first 
        bin_borders, bin_size = self.define_bin_borders(in_array,class_number)
        
        # create an empty list to fill
        patch_list = []
        # now iterate through the classes
        for index, item in enumerate(bin_borders):
            # get the upper and lower borders of the bin
            lower_border = item 
            upper_border = item + bin_size 
            # create a mask to get only the values in this range
            mask = self.mask_the_array(in_array,lower_border,upper_border,index)
            
            # label the single patches
            label_im, numpatches = ndi.label(mask,structure=square)
            # make a numpy array of the patches 
            patch = np.array(label_im)
            
            # get the patch details
            patches_details = self.get_frequency(patch)
            
            # we create an empty list here
            patch_mean_list =[]
            
            # run the get_mean_patches function to fill the empty list
            self.get_mean_patches(patches_details,self.dlg.lst,patch,patch_mean_list)
                        
            # now append the patches to a list 
            patch_list.append([patches_details, lower_border, upper_border, patch_mean_list, numpatches])
        return patch_list
    
    # convert the patch list into a pandas DataFrame
    def create_df_from_list(self,in_area,in_lst,patch_list):
        """Turn ``patch_list`` into a DataFrame with one row per patch.

        :param in_area: area covered by the non-NaN pixels of ``in_lst``; it is
            divided by their count to obtain the area of one pixel.
        :param in_lst: temperature array (only its NaN pattern is used).
        :param patch_list: entries ``[details, lower, upper, means, ...]`` as
            built by ``get_patches_from_array``; ``details`` rows hold
            ``(patch label, pixel count)``.
        :returns: DataFrame with the columns number_of_pixels, mean, area,
            lower_bin_border and upper_bin_border, sorted by area (largest
            first), keeping only rows whose mean is >= 0 (drops NaN means).
        """
        valid_pixels = np.count_nonzero(~np.isnan(in_lst))
        pixel_area = in_area / valid_pixels

        # classes without any patch contribute no rows
        filled = [entry for entry in patch_list if len(entry[0]) > 0]

        pixel_counts = [row[1] for entry in filled for row in entry[0]]
        lower_borders = [entry[1] for entry in filled for _row in entry[0]]
        upper_borders = [entry[2] for entry in filled for _row in entry[0]]
        patch_means = [value for entry in filled for value in entry[3]]
        patch_areas = [count * pixel_area for count in pixel_counts]

        frame = pd.DataFrame(
            list(zip(pixel_counts, patch_means, patch_areas, lower_borders, upper_borders)),
            columns=['number_of_pixels', 'mean', 'area', 'lower_bin_border', 'upper_bin_border'])
        frame = frame.sort_values('area', ascending=False)
        # NaN means compare False here and are dropped as well
        return frame[frame['mean'] >= 0]
    
    def get_patch_data_frame(self):
        """Build the per-patch DataFrame for ``self.patches_out``.

        Uses ``self.class_number`` for the classes and the non-NaN area from
        ``get_covered_area`` to convert pixel counts into areas.
        """
        patches = self.get_patches_from_array(self.patches_out, self.class_number)
        _full_area, covered_area = self.get_covered_area()
        return self.create_df_from_list(covered_area, self.dlg.lst, patches)
    
   
    ############# Indices
    # patch richness density
    # -> https://www.umass.edu/landeco/research/fragstats/documents/Metrics/Diversity%20Metrics/Metrics/L125%20-%20PRD.htm
    # PRD equals the number of different patch types present within the landscape boundary divided by the total landscape area (m2), multiplied by 10,000 and 100 (to convert to 100 hectares). Note, total landscape area (A) includes any internal background present.
    # TODO: replace the m² placeholder with a real value later
    # NOTE(review): the code counts individual patches (DataFrame rows), not
    # distinct patch types as in the FRAGSTATS definition - confirm intended.

    def get_number_of_patches(self,in_data):
        """Return the number of patches, i.e. the number of rows of ``in_data``."""
        n_rows = np.shape(in_data)[0]
        return n_rows
    
    def get_patch_richness_density(self,in_data,square_meters,hectares):
        # now calculate it 
        patch_richness_density = np.round((self.get_number_of_patches(in_data)/square_meters)*10000*hectares)
        return patch_richness_density
    
    # Shannon diversity Index - measure the diversity of species in a community
    # -> http://www.tiem.utk.edu/~gross/bioed/bealsmodules/shannonDI.html
    # -> https://www.statology.org/shannon-diversity-index/
    # calculate p_i proportions
    def calculate_p_i_portions(self,in_df):
        # get a total number of pixels included in the patches
        all_pixels = in_df['area'].sum()
        in_df['p_i'] = in_df['area'] / all_pixels
        # now get the natural log of each proportion
        in_df['nat_log_p_i'] = np.log10(in_df['p_i'])
        return in_df
    
    # get the shannon diversity index
    def get_shannon_diversity_index(self,in_data):
        data_cleaned = self.calculate_p_i_portions(in_data)
        # multiply the proportions by the natural log of the proportions
        data_cleaned['pi_multi_nat_log_p_i'] = data_cleaned['p_i'] * data_cleaned['nat_log_p_i']
        # now calculate the shannon diversity index by summing up the last column
        shannon_diversity_index = data_cleaned['pi_multi_nat_log_p_i'].sum() * (-1)
        return shannon_diversity_index
    
    # Shannon Equitability index -> shannon diversity index / ln(number of patches)
    def get_shannon_equitability_index(self,in_data):
        # get the p_i portions
        data_cleaned = self.calculate_p_i_portions(in_data)
        # get the shannon diversity index
        shannon_diversity_index = self.get_shannon_diversity_index(data_cleaned)
        # get the shannon equitability index
        shannon_equitability_index = shannon_diversity_index / np.log10(data_cleaned.shape[0])
        return shannon_equitability_index
    
    # Simpson diversity index
    # probability that two pixels drawn at random (without replacement) belong to the same patch
    # -> http://www.tiem.utk.edu/~gross/bioed/bealsmodules/simpsonDI.html
    # -> https://www.omnicalculator.com/statistics/simpsons-diversity-index#what-is-simpson's-index
    def get_simpson_diversity_index(self,in_data):
        """Return D = sum(n_i * (n_i - 1)) / (N * (N - 1)).

        n_i is the number of pixels of patch i (column ``number_of_pixels``)
        and N their total. The "- 1" terms count individuals, so pixel
        counts are used instead of areas in m². With areas, sub-m² patches
        gave negative terms. If ``number_of_pixels`` is missing, the ``area``
        column is used as before.

        Adds the helper columns ``N_observ_multi_N_minus_1``,
        ``ni_multi_ni_minus_1`` and ``ni_x_divided_by_N_x`` to ``in_data``.
        """
        counts_column = 'number_of_pixels' if 'number_of_pixels' in in_data.columns else 'area'
        # float avoids integer overflow in N * (N - 1) for large images
        counts = in_data[counts_column].astype(float)
        all_pixels = counts.sum()
        # N * (N - 1)
        in_data['N_observ_multi_N_minus_1'] = all_pixels * (all_pixels - 1)
        # n_i * (n_i - 1)
        in_data['ni_multi_ni_minus_1'] = counts * (counts - 1)
        in_data['ni_x_divided_by_N_x'] = in_data['ni_multi_ni_minus_1'] / in_data['N_observ_multi_N_minus_1']
        # sum up the index
        simpson_diversity_index = in_data['ni_x_divided_by_N_x'].sum()
        return simpson_diversity_index
    
    # ohoo that's a difficult one :) get the gini simpson diversity index
    def get_gini_simpson_diversity_index(self,in_data):
        simpson_diversity_index = self.get_simpson_diversity_index(in_data)
        # Gini_Simpson index
        gini_simpson_diversity_index = 1 - simpson_diversity_index
        return gini_simpson_diversity_index
    
    # Simpson's reciprocal index
    def get_simpson_reciprocal_index(self,in_data):
        simpson_diversity_index = self.get_simpson_diversity_index(in_data)
        simpson_reciprocal_index = 1 / simpson_diversity_index
        return simpson_reciprocal_index
    
    ## To-do:
    # landscape shape index
    # -> https://www.umass.edu/landeco/research/fragstats/documents/Metrics/Area%20-%20Density%20-%20Edge%20Metrics/Metrics/L9%20-%20LSI.htm
    # patch cohesion index
    # -> https://www.umass.edu/landeco/research/fragstats/documents/Metrics/Connectivity%20Metrics/Metrics/C121%20-%20COHESION.htm
    # aggregation index
    # -> https://www.umass.edu/landeco/research/fragstats/documents/Metrics/Contagion%20-%20Interspersion%20Metrics/Metrics/C116%20-%20AI.htm
        
    def write_output_images(self):
        '''Compute the patch raster, save it as a GeoTIFF and load it into QGIS.

        The result of ``calculate_patches`` is stored in ``self.patches_out``.
        A copy with negative values replaced by NaN is written as one Float32
        band (nodata value 0) to the path in ``patch_output_name``, using the
        geotransform and projection of the input image. The file is then
        added to the current project as
        ``'ThermalMetrics_Output_' + self.in_file_name``.

        :raises RuntimeError: if GDAL cannot create the output file
        '''
        rows, cols = np.shape(self.dlg.lst)
        driver = gdal.GetDriverByName('GTiff')
        nbands = 1

        # calculate patches (kept on self; get_patch_data_frame reads them)
        self.patches_out = self.calculate_patches()
        out_raster_patches = self.patches_out.copy()
        # remove negative values
        out_raster_patches[out_raster_patches < 0] = np.nan

        out_path = self.dlg.patch_output_name.text()
        out_raster = driver.Create(out_path, cols, rows, nbands, gdal.GDT_Float32)
        if out_raster is None:
            # without GDAL exceptions enabled, Create returns None on failure
            raise RuntimeError('Could not create output raster: ' + out_path)
        out_raster.SetGeoTransform(self.dlg.geo)
        out_raster.SetProjection(self.dlg.prj)
        # write the patch classes to band 1
        band_1 = out_raster.GetRasterBand(1)
        band_1.SetNoDataValue(0)
        band_1.WriteArray(out_raster_patches)
        band_1.FlushCache()
        out_raster.FlushCache()
        # release the band before the dataset; dropping the dataset closes the file
        band_1 = None
        del out_raster

        # load the written raster into QGIS
        qgis_raster = QgsRasterLayer(out_path, 'ThermalMetrics_Output_' + str(self.in_file_name))
        QgsProject.instance().addMapLayer(qgis_raster)
        
        
    def write_output_stats(self):
        '''Write the selected metrics and indices of the thermal image to a .csv file.

        The file named in the dialog's ``output_name`` field is (over)written
        with one ``label,value(s),`` row per metric. Metric groups whose
        option is not ticked are still listed, with the value
        'this feature was not selected', so the file always contains the
        same rows in the same order.

        Reads ``self.dlg.lst`` (the temperature raster; the code treats its
        values as Kelvin when deriving °C) and ``self.file_name``. The patch
        and diversity rows also read ``self.data_frame``.
        '''
        not_selected = 'this feature was not selected'
        dlg = self.dlg

        def one_line(value):
            # str() of a long numpy array wraps onto several lines, which
            # would spread one metric over several CSV rows; collapse all
            # whitespace (newlines included) into single spaces.
            return ' '.join(str(value).split())

        def or_nan(func, values):
            # np.max/np.min raise ValueError on an empty array (all-NaN
            # image); write 'nan' instead of aborting with a half-written file.
            return func(values) if values.size else float('nan')

        # Query the helpers once instead of once per row. This also happens
        # before the output file is truncated, so a failure here leaves any
        # previous file intact.
        area = None
        if (dlg.calculate_geoMetrics == True
                or dlg.calculate_patchRichnessAndDensity == True):
            # (whole image area, non-NaN area) in m², per the labels below
            area = self.get_covered_area()
        distribution = None
        if dlg.calculate_distributionMetrics == True:
            distribution = self.get_distribution()

        # Explicit UTF-8: the labels contain '²' and '°', which the platform's
        # default locale encoding is not guaranteed to support.
        with open(dlg.output_name.text(), 'w', encoding='utf-8') as output_file:

            def write_row(*fields):
                # one CSV row; callers pass '' last to get a trailing comma
                output_file.write(','.join(fields) + '\n')

            def write_not_selected(*labels):
                for label in labels:
                    write_row(label, not_selected, '')

            write_row('Metrics and indices of the thermal image:', str(self.file_name))

            # image metrics
            image_labels = ('Total number of pixels in the thermal image:',
                            'Non-Nan pixels in the thermal image:',
                            'Percentage of Non-Nan pixels in the thermal image [%]:',
                            'X and Y shape of the thermal image [X & Y]:')
            if dlg.calculate_imageMetrics == True:
                n_rows, n_cols = dlg.lst.shape[0], dlg.lst.shape[1]
                n_total = n_rows * n_cols
                n_valid = np.count_nonzero(~np.isnan(dlg.lst))
                percent = n_valid / n_total * 100 if n_total else float('nan')
                write_row(image_labels[0], str(n_total))
                write_row(image_labels[1], str(n_valid))
                write_row(image_labels[2], '{:.2f}'.format(percent))
                write_row(image_labels[3], str(n_cols), str(n_rows))
            else:
                write_not_selected(*image_labels)

            # geographical metrics (areas in m² and km²)
            geo_labels = ('Area covered by whole image [m²]:',
                          'Area covered by non-NaN pixels in the image [m²]:',
                          'Area covered by whole image [km²]:',
                          'Area covered by non-NaN pixels in the image [km²]:')
            if dlg.calculate_geoMetrics == True:
                write_row(geo_labels[0], '{:.2f}'.format(area[0]), '')
                write_row(geo_labels[1], '{:.2f}'.format(area[1]), '')
                write_row(geo_labels[2], '{:.4f}'.format(area[0] / 1000000), '')
                write_row(geo_labels[3], '{:.4f}'.format(area[1] / 1000000), '')
            else:
                write_not_selected(*geo_labels)

            # standard temperature statistics over the non-NaN pixels
            if dlg.calculate_standardMetrics == True:
                valid = dlg.lst[~np.isnan(dlg.lst)]
                for label, func in (('Mean temperature [in Kelvin and in °C]:', np.mean),
                                    ('Max temperature [in Kelvin and in °C]:', np.max),
                                    ('Min temperature [in Kelvin and in °C]:', np.min)):
                    kelvin = or_nan(func, valid)
                    write_row(label, '{:.4f}'.format(kelvin),
                              '{:.4f}'.format(kelvin - 273.15), '')
                write_row('Standard deviation of temperature:',
                          '{:.6f}'.format(or_nan(np.std, valid)), '')
                write_row('Variance of temperature:',
                          '{:.6f}'.format(or_nan(np.var, valid)), '')
            else:
                write_not_selected('Mean temperature:', 'Max temperature:',
                                   'Min temperature:',
                                   'Standard deviation of temperature:',
                                   'Variance of temperature:')

            # distribution metrics
            distribution_labels = ('Selected number of histogram bins:',
                                   'Histogram bins:',
                                   'Histogram number of pixels in bins:',
                                   'Skewness:',
                                   'Kurtosis:',
                                   'Fisher-Pearson Coefficient of Skewness (FPCS):')
            if dlg.calculate_distributionMetrics == True:
                for label, value in zip(distribution_labels[:3], distribution[:3]):
                    write_row(label, one_line(value), '')
                for label, value in zip(distribution_labels[3:], distribution[3:6]):
                    write_row(label, '{:.6f}'.format(value), '')
            else:
                write_not_selected(*distribution_labels)

            # patch richness and density
            if dlg.calculate_patchRichnessAndDensity == True:
                write_row('Patch richness:',
                          str(self.get_number_of_patches(self.data_frame)), '')
                write_row('Patch density:',
                          str(self.get_patch_richness_density(
                              self.data_frame, area[0], area[0] / 10000)), '')
            else:
                # same label as in the selected case so the rows line up
                write_not_selected('Patch richness:', 'Patch density:')

            # diversity indices, each computed only when its option is ticked
            diversity_indices = (
                ('Shannon Diversity Index:', dlg.calculate_shannonDiversityIndex,
                 lambda: self.get_shannon_diversity_index(self.data_frame)),
                ('Shannon Equitability Index:', dlg.shannonEquitabilityIndex,
                 lambda: self.get_shannon_equitability_index(self.data_frame)),
                ('Simpson Diversity Index:', dlg.calculate_simpsonDiversityIndex,
                 lambda: self.get_simpson_diversity_index(self.data_frame)),
                ('Gini Simpson Diversity Index:', dlg.giniSimpsonDiversityIndex,
                 lambda: self.get_gini_simpson_diversity_index(self.data_frame)),
                ('Simpson Reciprocal Index:', dlg.simpsonReciprocalIndex,
                 lambda: self.get_simpson_reciprocal_index(self.data_frame)),
            )
            for label, selected, compute in diversity_indices:
                if selected == True:
                    write_row(label, str(compute()), '')
                else:
                    write_row(label, not_selected, '')
                
       
    def write_output_data_frame(self):
        '''Write the per-patch pandas data frame (``self.data_frame``) to a .csv file.

        The file is placed next to the stats output file named in the dialog.
        The extension of that name is replaced by the suffix
        ``_patch_detais_data_frame.csv``; the suffix spelling is kept as-is so
        existing output file names do not change.
        '''
        # splitext strips only the final extension. Splitting at the first
        # '.' would truncate paths such as './out.csv' or 'run.v2/out.csv'
        # and write the file to the wrong place.
        stem = os.path.splitext(self.dlg.output_name.text())[0]
        self.data_frame.to_csv(stem + '_patch_detais_data_frame.csv')
        
        
                
                        
    
    def run(self):
        """Run method that performs all the real work.

        Shows the dialog, creating it and wiring its file-picker buttons on
        the first call only. If the user confirms with OK, it reads the input
        parameters, writes the requested GeoTIFF/.csv outputs and reports
        success in the QGIS message bar.
        """
        # Create the dialog with elements (after translation) and keep a
        # reference. It is created only once, so the GUI is built only when
        # the plugin is first started.
        if self.first_start == True:
            self.first_start = False
            self.dlg = ThermalMetricsDialog()
            self.dlg.search_input_button.clicked.connect(self.select_input_file)
            self.dlg.search_output_button.clicked.connect(self.select_output_file)
            self.dlg.search_patch_output_button.clicked.connect(self.select_patch_output_file)
        # show the dialog
        self.dlg.show()
        # Run the dialog event loop; the result is truthy when OK was pressed
        result = self.dlg.exec_()
        if result:
            # read the input raster and all model parameters
            in_file = self.dlg.input_name.text()
            self.get_model_parameters(in_file)
            # get NaN values from the input field
            self.get_nan_value()

            patches_selected = self.dlg.calculate_patchRichnessAndDensity == True
            indices_selected = (self.dlg.calculate_shannonDiversityIndex == True
                                or self.dlg.shannonEquitabilityIndex == True
                                or self.dlg.calculate_simpsonDiversityIndex == True
                                or self.dlg.giniSimpsonDiversityIndex == True
                                or self.dlg.simpsonReciprocalIndex == True)

            if patches_selected:
                # computes self.patches_out and writes the patch GeoTIFF
                self.write_output_images()

            # write_output_stats reads self.data_frame for the patch
            # richness/density rows as well as for the diversity indices.
            # Rebuild it whenever either is requested, so it is never missing
            # or left over from a previous run on a different image.
            if patches_selected or indices_selected:
                self.data_frame = self.get_patch_data_frame()
                # save that data frame in a separate .csv
                self.write_output_data_frame()

            # write output stats file
            self.write_output_stats()

            # Display a push message that ThermalMetrics was successful
            self.iface.messageBar().pushMessage(
                    'Success', 'Metrics and Indices calculated successfully!',
                    level=Qgis.Success, duration=3)