# -*- coding: utf-8 -*-
"""
/***************************************************************************
 GroundwaterVulnerability
                                 A QGIS plugin
 Groundwater Vulnerability Mapping (GLA method)
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2018-10-29
        git sha              : $Format:%H$
        copyright            : (C) 2018 by Christian Böhnke
        email                : christian@home-boehnke.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

import os
import numpy as np
from osgeo import gdal, ogr, osr
import csv

from PyQt5.QtGui import QIcon
from PyQt5 import uic, QtCore
from PyQt5.QtWidgets import *
from qgis.core import QgsProject, QgsMapLayer
from .gla_loose import LithLoose
from .gla_solid import LithSolid
from .resources import *

## Designer form of the GLA dialog, compiled at import time; the generated
## FORM_CLASS is mixed into GLAMethod below
_UI_FILE = os.path.join(os.path.dirname(__file__), 'ui_files', 'ui_gla_dialog.ui')
FORM_CLASS, _ = uic.loadUiType(_UI_FILE)

class GLAMethod(QDialog, FORM_CLASS):
    """Dialog for groundwater vulnerability mapping with the GLA method.

    The dialog collects the input rasters (AWC, percolation, optional perched
    and artesian groundwater, loose and solid lithology), the CSV lookup tables
    used to score them and the output file name. ``accept`` combines the scored
    layers as ``(awc + loose + solid) * percolation + perched + artesian`` and
    writes the result as a Float32 GeoTIFF.
    """

    def __init__(self, parent=None):
        """Set up widgets, lookup tables and signal connections.

        :param parent: optional parent widget passed on to ``QDialog``.
        """
        super(GLAMethod, self).__init__(parent)
        # Set up the user interface from Designer.
        # After setupUI you can access any designer object by doing
        # self.<objectname>, and you can use autoconnect slots - see
        # http://qt-project.org/doc/qt-4.8/designer-using-a-ui-file.html
        # #widgets-and-dialogs-with-auto-connect
        self.setupUi(self)

        #################################
        ##### DEFINE SUB FUNCTIONS ######
        #################################

        def switch_radio_button(cmb_tab_on, cmb_tab_off):
            ## enable the preferred combobox, disable the other one
            cmb_tab_on.setEnabled(True)
            cmb_tab_off.setEnabled(False)

        def enable_additional_input(chb_additional, cmb_additional):
            ## the optional input combobox follows the state of its checkbox
            cmb_additional.setEnabled(chb_additional.isChecked())

        def load_lookup_tables(default_lookup_dir=None):
            ## refill every lookup combobox with the files found in the
            ## directory shown in le_lookup_dir
            ## (combobox, file name preselected when using the default dir)
            lookup_combos = (
                (self.cmb_awc_tab, 'awc.csv'),
                (self.cmb_gwr_tab, 'gwr.csv'),
                (self.cmb_cwb_tab, 'cwb.csv'),
                (self.cmb_lith_l_tab, 'loose_rock.csv'),
                (self.cmb_struc_tab, 'solid_rock_structure.csv'),
                (self.cmb_type_tab, 'solid_rock_type.csv'),
            )

            current_lookup_dir = self.le_lookup_dir.text()

            ## only files can be lookup tables; sort because the order of
            ## os.listdir is arbitrary
            lookup_tables = sorted(
                entry for entry in os.listdir(current_lookup_dir)
                if os.path.isfile(os.path.join(current_lookup_dir, entry)))

            use_defaults = current_lookup_dir == default_lookup_dir

            for combo, default_table in lookup_combos:
                combo.clear()
                combo.addItems(lookup_tables)

                ## preselect each default table independently, so a single
                ## missing file does not affect the other comboboxes
                if use_defaults and default_table in lookup_tables:
                    combo.setCurrentIndex(lookup_tables.index(default_table))
                else:
                    combo.setCurrentIndex(0)

        def open_custom_lookup_dir(current_lookup_dir):
            ## ask for a new lookup directory; keep the current one on cancel
            self.le_lookup_dir.clear()

            new_lookup_dir = QFileDialog.getExistingDirectory(self)

            if not new_lookup_dir:
                self.le_lookup_dir.insert(current_lookup_dir)
            else:
                self.le_lookup_dir.insert(new_lookup_dir)
                ## the new directory is passed as "default", so tables with the
                ## standard file names are preselected when present
                load_lookup_tables(new_lookup_dir)

        def reload_default_lookup_dir(default_lookup_dir):
            self.le_lookup_dir.clear()
            self.le_lookup_dir.insert(default_lookup_dir)
            load_lookup_tables(default_lookup_dir)

        def declare_output():
            ## let the user choose the output GeoTiff and show it in le_output
            self.le_output.clear()

            ## QFileDialog returns a two element tuple: (file name, selected filter)
            output_name = QFileDialog.getSaveFileName(self, 'Raster output file', '.', 'GeoTiff (*.tif)')[0]

            ## append the GeoTiff extension unless one is already given
            if output_name and os.path.splitext(output_name)[-1].lower() not in ('.tif', '.tiff'):
                output_name = output_name + '.tif'

            self.le_output.insert(output_name)

        def open_dialog_l():
            ## open the modal dialog for loose lithology; the lookup path is
            ## resolved now so the current table selection is used
            lookup_dict = self._build_lookup_dict()
            lith_l = LithLoose(lookup_dict['lith_l'], self.layer_l_dict)
            lith_l.exec_()

            self.layer_l_dict = lith_l.layer_dict

            ## mirror the dialog's layer selection, so layers removed in the
            ## dialog are no longer used in the calculation
            loose_fields = ('path', 'name', 'lith_short', 'org', 'score')
            self.gla_layers['loose'] = {
                key: {field: entry[field] for field in loose_fields}
                for key, entry in self.layer_l_dict.items()}

        def open_dialog_s():
            ## open the modal dialog for solid lithology; the lookup paths are
            ## resolved now so the current table selection is used
            lookup_dict = self._build_lookup_dict()
            lith_s = LithSolid(lookup_dict['lith_type'], lookup_dict['lith_struc'], self.layer_s_dict)
            lith_s.exec_()

            self.layer_s_dict = lith_s.layer_dict

            ## mirror the dialog's layer selection, so layers removed in the
            ## dialog are no longer used in the calculation
            solid_fields = ('path', 'name', 'lith_type', 'lith_struc', 'score_type', 'score_struc')
            self.gla_layers['solid'] = {
                key: {field: entry[field] for field in solid_fields}
                for key, entry in self.layer_s_dict.items()}

        def insert_layer(combo, item):
            ## store the item data (layer path) and display text (layer name)
            ## of the combobox selection under gla_layers[item]
            self.gla_layers[item]['path'] = combo.itemData(combo.currentIndex())
            self.gla_layers[item]['name'] = combo.currentText()

        ################################
        ## INITIALLY FILL COMBOBOXES ###
        ################################

        ## path and name of the layers selected for each input dataset
        self.gla_layers = {'awc': {'path': None, 'name': None},
                           'perc': {'path': None, 'name': None},
                           'perched': {'path': None, 'name': None},
                           'art': {'path': None, 'name': None},
                           'loose': {},
                           'solid': {}}

        insert_layer(self.cmb_awc, item='awc')
        insert_layer(self.cmb_perc, item='perc')
        insert_layer(self.cmb_perched, item='perched')
        insert_layer(self.cmb_art, item='art')

        ################################
        ##### DEFINE LOOKUP TABLES #####
        ################################

        ## the default lookup directory ships with the plugin
        current_dir = os.path.dirname(__file__)
        default_lookup_dir = os.path.join(current_dir, 'lookup')
        self.le_lookup_dir.insert(default_lookup_dir)
        load_lookup_tables(default_lookup_dir)

        #####################################
        ##### MANAGE PERCOLATING LOOKUP #####
        #####################################

        ## groundwater recharge is the default; climatic water balance can be
        ## chosen with the radio buttons
        self.cmb_cwb_tab.setEnabled(False)
        self.rbtn_gwr.setChecked(True)

        self.rbtn_gwr.clicked.connect(lambda: switch_radio_button(self.cmb_gwr_tab, self.cmb_cwb_tab))
        self.rbtn_cwb.clicked.connect(lambda: switch_radio_button(self.cmb_cwb_tab, self.cmb_gwr_tab))

        ###################################
        ##### MANAGE ADDITIONAL INPUT #####
        ###################################

        ## perched and artesian groundwater are optional; their comboboxes are
        ## enabled only while the corresponding checkbox is checked
        self.cmb_perched.setEnabled(False)
        self.cmb_art.setEnabled(False)

        self.chb_perched.stateChanged.connect(lambda: enable_additional_input(self.chb_perched, self.cmb_perched))
        self.chb_art.stateChanged.connect(lambda: enable_additional_input(self.chb_art, self.cmb_art))

        ################################
        ######## MANAGE BUTTONS ########
        ################################

        self.btn_lookup_reload.setIcon(QIcon(os.path.join(current_dir, 'icons', 'reload.png')))

        self.btn_lookup_dir.clicked.connect(lambda: open_custom_lookup_dir(self.le_lookup_dir.text()))
        self.btn_lookup_reload.clicked.connect(lambda: reload_default_lookup_dir(default_lookup_dir))

        ## layer dictionaries for lithology association, filled by the dialogs
        self.layer_l_dict = {}
        self.layer_s_dict = {}

        ## lookup table paths are built at click time, so changes of the lookup
        ## directory or of the table selection are taken into account
        self.btn_lith_l.clicked.connect(lambda: open_dialog_l())
        self.btn_lith_s.clicked.connect(lambda: open_dialog_s())
        self.btn_output.clicked.connect(declare_output)
        self.btn_close.clicked.connect(self.close)
        self.btn_run.clicked.connect(lambda: self.accept(self._build_lookup_dict()))

        self.cmb_awc.activated.connect(lambda: insert_layer(self.cmb_awc, item='awc'))
        self.cmb_perc.activated.connect(lambda: insert_layer(self.cmb_perc, item='perc'))
        self.cmb_perched.activated.connect(lambda: insert_layer(self.cmb_perched, item='perched'))
        self.cmb_art.activated.connect(lambda: insert_layer(self.cmb_art, item='art'))

        self.msg_window.clear()
        self.msg_window.append('Ready!')

    def _build_lookup_dict(self):
        """Map each lookup key to the full path of the currently selected table.

        :returns: dict with the keys 'awc', 'gwr', 'cwb', 'lith_l',
            'lith_struc' and 'lith_type'.
        """
        lookup_dir = self.le_lookup_dir.text()
        selection = (
            ('awc', self.cmb_awc_tab),
            ('gwr', self.cmb_gwr_tab),
            ('cwb', self.cmb_cwb_tab),
            ('lith_l', self.cmb_lith_l_tab),
            ('lith_struc', self.cmb_struc_tab),
            ('lith_type', self.cmb_type_tab),
        )
        return {key: os.path.join(lookup_dir, combo.currentText()) for key, combo in selection}

    @staticmethod
    def _to_number(value):
        """Convert a lookup value to int if possible, otherwise to float."""
        try:
            return int(value)
        except ValueError:
            return float(value)

    def accept(self, lookup_dict=None):
        """Compute the groundwater vulnerability and write it to disk.

        Overrides ``QDialog.accept`` but does not close the dialog, so the
        progress messages stay visible.

        :param lookup_dict: mapping of lookup keys ('awc', 'gwr', 'cwb',
            'lith_l', 'lith_struc', 'lith_type') to CSV paths. When omitted
            (e.g. when Qt calls ``accept()`` without arguments) the current
            selection of the lookup tab is used.
        """
        if lookup_dict is None:
            lookup_dict = self._build_lookup_dict()

        if not self.le_output.text():
            QMessageBox.information(None, 'Info', 'No output file specified')
            return

        ## AWC (also the master layer) and percolation are mandatory inputs
        if not self.gla_layers['awc']['path'] or not self.gla_layers['perc']['path']:
            QMessageBox.information(None, 'Info', 'AWC and percolation layers must be specified')
            return

        ## retrieve projection information from the master layer (AWC)
        master_layer = self.gla_layers['awc']['path']
        _, (cols, rows) = self.get_geometry(master_layer)
        master_dataset = gdal.Open(master_layer)
        master_geotransform = master_dataset.GetGeoTransform()
        master_proj = master_dataset.GetProjection()
        master_datatype = gdal.GetDataTypeName(master_dataset.GetRasterBand(1).DataType)
        master_dataset = None

        array_awc = self.get_array_from_layer(self.gla_layers['awc']['path'])
        array_perc = self.get_array_from_layer(self.gla_layers['perc']['path'])

        self.msg_window.append('Start calculation of AWC layer.')
        array_awc = self.calc_awc_perc(array_awc, lookup_dict['awc'])
        self.msg_window.append('Calculation of AWC layer finished.')

        self.msg_window.append('Start calculation of percolation layer.')

        if self.rbtn_gwr.isChecked():
            self.msg_window.append('GWR selected as percolation method.')
            array_perc = self.calc_awc_perc(array_perc, lookup_dict['gwr'])

        elif self.rbtn_cwb.isChecked():
            self.msg_window.append('CWB selected as percolation method.')
            array_perc = self.calc_awc_perc(array_perc, lookup_dict['cwb'])

        self.msg_window.append('Calculation of percolation layer finished.')

        ## optional layers default to the scalar 0, which is neutral in the
        ## final formula; they are replaced by arrays when data is available
        array_l = 0
        array_s = 0
        array_perched = 0
        array_art = 0

        if self.gla_layers['loose']:
            self.msg_window.append('Start calculation of unconsolidated lithology layer.')

            ## sum the scores of all loose lithology layers
            array_l = np.zeros((rows, cols), dtype=np.float64)
            for item in self.gla_layers['loose']:
                array_l = array_l + self.calc_lith_l(item)

            self.msg_window.append('Calculation of unconsolidated lithology layer finished.')

        if self.gla_layers['solid']:
            self.msg_window.append('Start calculation of consolidated lithology layer.')

            ## sum the scores of all solid lithology layers
            array_s = np.zeros((rows, cols), dtype=np.float64)
            for item in self.gla_layers['solid']:
                array_s = array_s + self.calc_lith_s(item)

            self.msg_window.append('Calculation of consolidated lithology layer finished.')

        if self.chb_perched.isChecked() and self.gla_layers['perched']['path'] is not None:
            self.msg_window.append('Occurence of perched groundwater selected.')
            self.msg_window.append('Start calculation of perched groundwater layer.')

            ## cast to float before scaling so integer rasters cannot overflow
            array_perched = self.get_array_from_layer(self.gla_layers['perched']['path']).astype(np.float64) * 500

            self.msg_window.append('Calculation of perched groundwater layer finished.')

        if self.chb_art.isChecked() and self.gla_layers['art']['path'] is not None:
            self.msg_window.append('Occurence of artesian groundwater selected.')
            self.msg_window.append('Start calculation of artesian groundwater layer.')

            ## cast to float before scaling so integer rasters cannot overflow
            array_art = self.get_array_from_layer(self.gla_layers['art']['path']).astype(np.float64) * 1500

            self.msg_window.append('Calculation of artesian groundwater layer finished.')

        self.msg_window.append('Calculation of input layers finished.')
        self.msg_window.append('Start calculation of groundwater vulnerability.')

        array_gla = (array_awc + array_l + array_s) * array_perc + array_perched + array_art

        self.msg_window.append('Calculation of groundwater vulnerability finished.')

        output_name = self.le_output.text()

        self.msg_window.append('Save grounwater vulnerability result to {}.'.format(output_name))

        ## the final result is always written; intermediate results are extra
        self.save_array_to_raster(array_gla, output_name, master_geotransform, master_proj, master_datatype)

        if self.chb_proc.isChecked():
            proc_dict = {'proc_awc': array_awc, 'proc_perc': array_perc, 'proc_l': array_l,
                         'proc_s': array_s, 'proc_perched': array_perched, 'proc_art': array_art}

            path = os.path.dirname(output_name)

            self.msg_window.append('Additionally, save intermediate results to {}.'.format(path))

            for key, proc_array in proc_dict.items():
                ## optional layers that were not calculated are still scalar 0
                if not isinstance(proc_array, np.ndarray):
                    continue

                proc_name = os.path.join(path, key + '.tif')
                self.save_array_to_raster(proc_array, proc_name, master_geotransform, master_proj, master_datatype)

        self.msg_window.append('Processing finished. Result written to disk.')

    def save_array_to_raster(self, array, out_file, geotransform, proj, datatype):
        """Write a 2D array to a single band Float32 GeoTIFF.

        :param array: 2D array indexed as (rows, cols).
        :param out_file: target path; an existing file is replaced.
        :param geotransform: GDAL geotransform of the output raster.
        :param proj: projection string as returned by GDAL's GetProjection.
        :param datatype: ignored, kept for signature compatibility; the output
            is always Float32 because the scored arrays are floats.
        :raises IOError: if GDAL cannot create the output file.
        """
        ## a missing previous result is fine; other errors must surface
        try:
            os.remove(out_file)
        except FileNotFoundError:
            pass

        rows, cols = array.shape[0], array.shape[1]

        driver_out = gdal.GetDriverByName('GTiff')
        dataset_out = driver_out.Create(out_file, cols, rows, 1, gdal.GDT_Float32)
        if dataset_out is None:
            raise IOError('Could not create output raster {}'.format(out_file))

        dataset_out.SetGeoTransform(geotransform)
        dataset_out.SetProjection(proj)
        band_out = dataset_out.GetRasterBand(1)
        band_out.WriteArray(array)
        dataset_out.FlushCache()

        ## dereference band and dataset so GDAL closes the file
        band_out = None
        dataset_out = None

    def calc_awc_perc(self, array, lookup_csv_path):
        """Classify an array with a Min/Max/Score lookup table.

        The lookup table is a ';' separated CSV file with the columns 'Min',
        'Max' and 'Score'. Cells with ``Min < value <= Max`` receive the
        class score; cells matching no class stay 0. With overlapping classes
        the later row wins.

        :param array: input values (e.g. AWC or percolation).
        :param lookup_csv_path: path to the lookup CSV file.
        :returns: float array of the same shape with the scores.
        """
        score_array = np.zeros(array.shape)

        ## float() accepts integer and real strings alike, so rows mixing both
        ## (e.g. Min "0", Max "2.5") are handled
        with open(lookup_csv_path, 'r', newline='') as lookup_file:
            classes = [(float(row['Min']), float(row['Max']), float(row['Score']))
                       for row in csv.DictReader(lookup_file, delimiter=';')]

        for min_value, max_value, score in classes:
            score_array[(array > min_value) & (array <= max_value)] = score

        return score_array

    def calc_lith_l(self, item):
        """Return the score array of one loose (unconsolidated) lithology layer.

        The layer raster is multiplied by its lithology score, plus 75 if the
        layer is flagged organic ('org' == 'yes').

        :param item: key of the layer in ``self.gla_layers['loose']``.
        :returns: float64 array.
        :raises ValueError: if 'org' is neither 'yes' nor 'no'.
        """
        entry = self.gla_layers['loose'][item]
        score = self._to_number(entry['score'])
        lith_org = entry['org']

        if lith_org == 'yes':
            org = 75
        elif lith_org == 'no':
            org = 0
        else:
            raise ValueError('Unexpected organic flag {!r} for loose lithology {}'.format(lith_org, item))

        ## multiply in float: an in-place product on an integer raster fails
        ## for float scores and overflows small integer types
        return self.get_array_from_layer(entry['path']).astype(np.float64) * (score + org)

    def calc_lith_s(self, item):
        """Return the score array of one solid (consolidated) lithology layer.

        The layer raster is multiplied by the product of its rock type score
        and its structure score.

        :param item: key of the layer in ``self.gla_layers['solid']``.
        :returns: float64 array.
        """
        entry = self.gla_layers['solid'][item]
        score_type = self._to_number(entry['score_type'])
        score_struc = self._to_number(entry['score_struc'])

        ## multiply in float: an in-place product on an integer raster fails
        ## for float scores and overflows small integer types
        return self.get_array_from_layer(entry['path']).astype(np.float64) * (score_type * score_struc)

    def get_array_from_layer(self, layer):
        """Read band 1 of a raster and set its nodata cells to 0.

        :param layer: raster path readable by GDAL.
        :returns: 2D array in the band's native data type.
        :raises IOError: if GDAL cannot open the raster.
        """
        dataset = gdal.Open(layer)
        if dataset is None:
            raise IOError('Could not open raster {}'.format(layer))

        band = dataset.GetRasterBand(1)
        nodata_value = band.GetNoDataValue()
        array = np.array(band.ReadAsArray())

        band = None
        dataset = None

        if nodata_value is not None:
            if np.isnan(nodata_value):
                ## NaN never compares equal, so it needs an explicit test
                array[np.isnan(array)] = 0
            else:
                array[array == nodata_value] = 0

        return array

    def get_geometry(self, layer):
        """Return extent and size of a raster.

        :param layer: raster path readable by GDAL.
        :returns: ``([x_min, x_max, y_min, y_max], [cols, rows])``; the extent
            ignores the skew terms of the geotransform.
        :raises IOError: if GDAL cannot open the raster.
        """
        dataset = gdal.Open(layer)
        if dataset is None:
            raise IOError('Could not open raster {}'.format(layer))

        cols = dataset.RasterXSize
        rows = dataset.RasterYSize

        ## geotransform: (x_min, pixel width, x skew, y_max, y skew, pixel height)
        x_min, width, _, y_max, _, height = dataset.GetGeoTransform()
        dataset = None

        x_max = x_min + (cols * width)
        ## pixel height is negative for north-up rasters
        y_min = y_max - (rows * abs(height))

        return [x_min, x_max, y_min, y_max], [cols, rows]
