# -*- coding: utf-8 -*-
"""
/***************************************************************************
 FLO2DMapCrafter
                                 A QGIS plugin
 This plugin creates maps from FLO-2D output files.
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2023-09-21
        git sha              : $Format:%H$
        copyright            : (C) 2023 by FLO-2D
        email                : contact@flo-2d.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""
import os

import numpy as np
from PyQt5.QtWidgets import QMessageBox
from osgeo import gdal
from qgis._core import QgsProject, QgsRasterLayer, QgsUnitTypes, QgsMessageLog

from flo2d_mapcrafter.mapping.check_data import check_project_id, check_mapping_group, check_raster_file
from flo2d_mapcrafter.mapping.scripts import read_ASCII, remove_layer, set_raster_style

import processing

class HazardMaps:

    def __init__(self, units_switch):
        """
        Store the unit system used by the FLO-2D results.

        :param units_switch: unit system flag, 0 for English and 1 for metric.
            The map builders of this class compare it against the string "1".
        """
        self.units_switch = units_switch

    def check_hazard_files(self, output_dir):
        """
        Report which hazard maps can be created from the files in a folder.

        A required output counts as available when at least one entry of
        ``output_dir`` has a name starting with it (e.g. ``DEPTH.OUT``).

        :param output_dir: directory holding the FLO-2D output files
        :return: dictionary keyed by hazard method. "ARR" is a bool, "Swiss"
            is a list of two bools and "USBR" a list of five bools; the
            remaining methods ("Austrian", "FLO-2D", "UK", "FEMA") are always
            reported as False by this check.
        """
        # Output files each hazard family needs before it can be mapped.
        required_outputs = {
            "ARR": ("DEPTH.OUT", "VELFP.OUT", "VEL_X_DEPTH.OUT"),
            "Swiss": ("DEPTH.OUT", "VEL_X_DEPTH.OUT", "VELFP.OUT"),
            "USBR": ("DEPTH.OUT", "VELFP.OUT"),
        }

        hazard_maps = {
            "ARR": False,
            "Austrian": False,
            "FLO-2D": False,
            "Swiss": [False, False],
            "UK": False,
            "USBR": [False, False, False, False, False],
            "FEMA": False,
        }

        entries = os.listdir(output_dir)

        def outputs_available(family):
            """Return True when every file required by ``family`` is present."""
            return all(
                any(entry.startswith(prefix) for entry in entries)
                for prefix in required_outputs[family]
            )

        if outputs_available("ARR"):
            hazard_maps["ARR"] = True

        # Each family is checked against its own requirement list (the Swiss
        # check used to look at the ARR list).
        if outputs_available("Swiss"):
            hazard_maps["Swiss"] = [True, True]

        if outputs_available("USBR"):
            hazard_maps["USBR"] = [True, True, True, True, True]

        return hazard_maps

    def create_maps(self, hazard_rbs, flo2d_results_dir, map_output_dir, mapping_group, crs, project_id):
        """
        Create the selected hazard maps and add them to the layer tree.

        A sub-group per hazard family is looked up (or created) inside the
        "Hazard Maps" group. Every selected map is generated as a raster,
        styled and inserted at the top of its family group. Finally all the
        layers of the mapping group are unchecked and collapsed.

        :param hazard_rbs: selection dictionary. ``"ARR"`` is a single flag,
            ``"USBR"`` is a sequence of flags ordered houses, mobile homes,
            vehicles, adults, children, and ``"Swiss"`` is a sequence ordered
            flood intensity, debris intensity. A missing key means that
            nothing of that family is selected.
        :param flo2d_results_dir: folder containing DEPTH.OUT, VELFP.OUT and
            VEL_X_DEPTH.OUT
        :param map_output_dir: folder where the rasters are written
        :param mapping_group: layer tree group handed to check_mapping_group
        :param crs: coordinate reference system assigned to the rasters
        :param project_id: project identifier handed to check_project_id
        """
        mapping_group_name = check_project_id("Hazard Maps", project_id)
        mapping_group = check_mapping_group(mapping_group_name, mapping_group)

        def get_or_create_group(group_name):
            """Return the named sub-group, inserting it at the top if missing."""
            return mapping_group.findGroup(group_name) or mapping_group.insertGroup(0, group_name)

        ARR_group = get_or_create_group("Australian Rainfall and Runoff (ARR)")
        SWISS_group = get_or_create_group("SWISS")
        USBR_group = get_or_create_group("US Bureau of Reclamation")

        depth_file = os.path.join(flo2d_results_dir, "DEPTH.OUT")
        vel_file = os.path.join(flo2d_results_dir, "VELFP.OUT")
        vel_x_depth_file = os.path.join(flo2d_results_dir, "VEL_X_DEPTH.OUT")

        # Result files can be large: read each one at most once and share the
        # array between all the maps that need it.
        loaded_outputs = {}

        def load_output(path):
            """Return the parsed content of a FLO-2D output file (cached)."""
            if path not in loaded_outputs:
                loaded_outputs[path] = np.loadtxt(path, skiprows=0)
            return loaded_outputs[path]

        def add_to_group(layer, group, style):
            """Register the layer in the project, style it and show it in group."""
            QgsProject.instance().addMapLayer(layer, False)
            set_raster_style(layer, style)
            group.insertLayer(0, layer)

        # Australian Rainfall and Runoff
        if hazard_rbs.get("ARR"):
            name = check_project_id("ARR_FLOOD_HAZARD", project_id)
            name, raster = check_raster_file(name, map_output_dir)
            hydro_risk_raster = self.create_arr_map(
                map_output_dir, raster, depth_file, vel_file, vel_x_depth_file, crs, project_id
            )
            add_to_group(hydro_risk_raster, ARR_group, 2)

        # US Bureau of Reclamation: the position in the selection list is the
        # map type expected by create_usbr_map.
        usbr_layer_names = (
            "USBR_HOUSES_HAZARD",
            "USBR_MOBILE_HAZARD",
            "USBR_VEHICLE_HAZARD",
            "USBR_ADULTS_HAZARD",
            "USBR_CHILDREN_HAZARD",
        )
        for index, selected in enumerate(hazard_rbs.get("USBR") or ()):
            if not selected or index >= len(usbr_layer_names):
                continue
            vel_data = load_output(vel_file)
            depth_data = load_output(depth_file)
            name = check_project_id(usbr_layer_names[index], project_id)
            name, raster = check_raster_file(name, map_output_dir)
            hydro_risk_raster = self.create_usbr_map(
                name, raster, depth_data, vel_data, index, crs
            )
            add_to_group(hydro_risk_raster, USBR_group, 8)

        # Swiss method: the position in the selection list is the map type
        # expected by create_swiss_map.
        swiss_layer_names = ("SWISS_FLOOD_INTENSITY", "SWISS_DEBRIS_INTENSITY")
        for index, selected in enumerate(hazard_rbs.get("Swiss") or ()):
            if not selected or index >= len(swiss_layer_names):
                continue
            vel_x_depth_data = load_output(vel_x_depth_file)
            depth_data = load_output(depth_file)
            vel_data = load_output(vel_file)
            name = check_project_id(swiss_layer_names[index], project_id)
            name, raster = check_raster_file(name, map_output_dir)
            hydro_risk_raster = self.create_swiss_map(
                name, raster, depth_data, vel_data, vel_x_depth_data, index, crs
            )
            add_to_group(hydro_risk_raster, SWISS_group, 8)

        # Uncheck and collapse the layers of the mapping group
        root = QgsProject.instance().layerTreeRoot()
        for tree_layer in mapping_group.findLayers():
            node = root.findLayer(tree_layer.layerId())
            if node is None:
                continue
            node.setItemVisibilityChecked(False)
            node.setExpanded(False)


    def create_swiss_map(self, name, hydro_risk, depth_data, vel_data, vel_x_depth_data, map_type, crs):
        """
        Create a Swiss-method intensity raster and return it as a layer.

        Only cells with a non-zero depth are classified. Classes written to
        the raster are 3 (high), 2 (moderate) and 1 (low):

        * flood intensity (``map_type`` 0): high when depth > 2 or
          depth x velocity > 2, moderate when depth > 0.5 or
          depth x velocity > 0.5, low otherwise;
        * debris intensity (``map_type`` 1): high when depth > 1 and
          velocity > 1, moderate otherwise.

        The thresholds are metric. Results of English-unit projects are
        converted to metric (divided by 3.28, or by 3.28 squared for
        depth x velocity) before they are compared.

        :param name: name given to the returned layer
        :param hydro_risk: path of the GeoTIFF file to write
        :param depth_data: rows of (id, x, y, depth)
        :param vel_data: rows of (id, x, y, velocity)
        :param vel_x_depth_data: rows of (id, x, y, depth x velocity)
        :param map_type: 0 for flood intensity, 1 for debris intensity
        :param crs: coordinate reference system; its ``toWkt()`` is written
            to the raster
        :return: QgsRasterLayer pointing to ``hydro_risk``
        :raises ValueError: if ``map_type`` is unknown, no cell is flooded or
            the grid spacing cannot be determined
        :raises RuntimeError: if the GeoTIFF cannot be created
        """
        if map_type not in (0, 1):
            raise ValueError(f"Unknown Swiss map type: {map_type!r}")

        # Feet per metre for English projects, 1 for metric ones.
        uc = 1 if str(self.units_switch) == "1" else 3.28

        def as_rows(data):
            """Return the data as a 2-D float array (one row per grid cell)."""
            table = np.asarray(data, dtype=float)
            if table.size == 0:
                return np.empty((0, 4))
            return np.atleast_2d(table)

        depth_rows = as_rows(depth_data)
        if map_type == 0:
            other_rows = as_rows(vel_x_depth_data)
            coord_rows = other_rows
        else:
            other_rows = as_rows(vel_data)
            coord_rows = depth_rows

        # Rows are paired by position; ignore any unmatched tail.
        count = min(len(depth_rows), len(other_rows))
        xs = coord_rows[:count, 1]
        ys = coord_rows[:count, 2]
        raw_depth = depth_rows[:count, 3]
        depth = raw_depth / uc

        if map_type == 0:
            depth_x_velocity = other_rows[:count, 3] / (uc ** 2)
            high = (depth > 2) | (depth_x_velocity > 2)
            moderate = (depth > 0.5) | (depth_x_velocity > 0.5)
            level = np.where(high, 3, np.where(moderate, 2, 1))
        else:
            velocity = other_rows[:count, 3] / uc
            level = np.where((depth > 1) & (velocity > 1), 3, 2)

        wet = raw_depth != 0
        if not wet.any():
            raise ValueError("No flooded cells found: the Swiss map cannot be created.")

        # Grid spacing: smallest gap between distinct coordinates of the whole
        # grid. Flooded cells alone may not contain two neighbouring cells.
        steps = np.concatenate((np.diff(np.unique(xs)), np.diff(np.unique(ys))))
        if steps.size == 0:
            raise ValueError("Grid spacing cannot be determined from a single cell.")
        cell_size = float(steps.min())

        # Extent and raster indices of the flooded cells
        wet_x = xs[wet]
        wet_y = ys[wet]
        min_x = float(wet_x.min())
        max_y = float(wet_y.max())
        cols = np.rint((wet_x - min_x) / cell_size).astype(int)
        rows = np.rint((max_y - wet_y) / cell_size).astype(int)
        num_cols = int(cols.max()) + 1
        num_rows = int(rows.max()) + 1

        raster_data = np.full((num_rows, num_cols), -9999, dtype=np.float32)
        raster_data[rows, cols] = level[wet]

        # Write the GeoTIFF; coordinates are cell centres, so the origin is
        # shifted by half a cell to the top-left corner.
        driver = gdal.GetDriverByName("GTiff")
        raster = driver.Create(hydro_risk, num_cols, num_rows, 1, gdal.GDT_Float32)
        if raster is None:
            raise RuntimeError(f"Could not create raster file: {hydro_risk}")
        raster.SetGeoTransform(
            (
                min_x - cell_size / 2,
                cell_size,
                0,
                max_y + cell_size / 2,
                0,
                -cell_size,
            )
        )
        raster.SetProjection(crs.toWkt())

        band = raster.GetRasterBand(1)
        band.SetNoDataValue(-9999)
        band.WriteArray(raster_data)

        # Flush and release the dataset so the file is complete on disk
        # before it is opened as a layer.
        raster.FlushCache()
        band = None
        raster = None

        return QgsRasterLayer(hydro_risk, name)

    def create_arr_map(
        self, map_output_dir, hydro_risk, depth_file, vel_file, vel_x_depth_file, crs, project_id
    ):
        """
        Create the ARR (Australian Rainfall and Runoff) hazard raster.

        The depth, speed and depth x velocity rasters are created from the
        FLO-2D outputs when they do not exist yet, temporarily added to the
        project so the raster calculator can reference them by name, and
        removed again once the classification is done (also on failure).

        Classes are 0 (dry) and 1 to 6. The class limits are metric and are
        scaled to the project units: depth and speed limits by 3.28 and
        depth x velocity limits by 3.28 squared for English projects.

        :param map_output_dir: folder where the rasters are written
        :param hydro_risk: path of the hazard raster to create
        :param depth_file: path of DEPTH.OUT
        :param vel_file: path of VELFP.OUT
        :param vel_x_depth_file: path of VEL_X_DEPTH.OUT
        :param crs: coordinate reference system of the rasters
        :param project_id: project identifier handed to check_project_id
        :return: QgsRasterLayer of the classified raster, named after the
            ``hydro_risk`` file
        """
        name_speed = check_project_id("FLOW_SPEED", project_id)
        _, flow_speed = check_raster_file(name_speed, map_output_dir)

        name_depth = check_project_id("FLOOD_DEPTH", project_id)
        _, flood_depth = check_raster_file(name_depth, map_output_dir)

        name_hxv = check_project_id("HxV", project_id)
        _, h_x_v = check_raster_file(name_hxv, map_output_dir)

        # Feet per metre for English projects, 1 for metric ones. Depth x
        # velocity is an area rate, so its limits scale with the square.
        uc = 1 if str(self.units_switch) == "1" else 3.28
        uc2 = uc ** 2

        r0_e = f'"{name_hxv}@1" = 0 AND "{name_depth}@1" = 0 AND "{name_speed}@1" = 0'
        r1_e = f'"{name_hxv}@1" <= {0.3 * uc2} AND "{name_depth}@1" < {0.3 * uc} AND "{name_speed}@1" < {2 * uc}'
        r2_e = f'"{name_hxv}@1" <= {0.6 * uc2} AND "{name_depth}@1" < {0.5 * uc} AND "{name_speed}@1" < {2 * uc}'
        r3_e = f'"{name_hxv}@1" <= {0.6 * uc2} AND "{name_depth}@1" < {1.2 * uc} AND "{name_speed}@1" < {2 * uc}'
        r4_e = f'"{name_hxv}@1" <= {1.0 * uc2} AND "{name_depth}@1" < {2.0 * uc} AND "{name_speed}@1" < {2 * uc}'
        r5_e = f'"{name_hxv}@1" <= {4.0 * uc2} AND "{name_depth}@1" < {4.0 * uc} AND "{name_speed}@1" < {4 * uc}'
        r6_e = f'"{name_hxv}@1" > {4.0 * uc2} OR "{name_depth}@1" >= {4.0 * uc} OR "{name_speed}@1" >= {4 * uc}'

        # The first matching condition wins, from the lowest class upwards.
        expression = (
            f"IF({r0_e},0,if({r1_e},1,if({r2_e},2,if({r3_e},3,"
            f"if({r4_e},4,if({r5_e},5,if({r6_e},6,0)))))))"
        )

        # (FLO-2D output, raster path, layer name) of the calculator inputs
        inputs = (
            (depth_file, flood_depth, name_depth),
            (vel_file, flow_speed, name_speed),
            (vel_x_depth_file, h_x_v, name_hxv),
        )

        project = QgsProject.instance()
        added_layers = []
        try:
            for ascii_file, raster_path, layer_name in inputs:
                # Reuse a raster already on disk, otherwise build it.
                if not os.path.isfile(raster_path):
                    read_ASCII(ascii_file, raster_path, layer_name, crs)
                project.addMapLayer(QgsRasterLayer(raster_path, layer_name), True)
                added_layers.append(layer_name)

            QgsMessageLog.logMessage(expression)

            # Australian Rainfall and Runoff Classification
            arr_class = processing.run(
                "qgis:rastercalculator",
                {
                    "EXPRESSION": expression,
                    "LAYERS": [flood_depth],
                    "CELLSIZE": 0,
                    "EXTENT": None,
                    "CRS": crs,
                    "OUTPUT": hydro_risk,
                },
            )["OUTPUT"]
        finally:
            # The helper layers are only needed by the calculator.
            for layer_name in added_layers:
                remove_layer(layer_name)

        name_arr = os.path.splitext(os.path.basename(hydro_risk))[0]

        return QgsRasterLayer(arr_class, name_arr)

    def create_usbr_map(self, name, hydro_risk, depth_data, vel_data, map_type, crs):
        """
        Create a USBR (US Bureau of Reclamation) hazard raster as a layer.

        Only cells with non-zero depth and non-zero velocity are classified.
        For each of them two depth-versus-velocity curves are evaluated at
        the cell velocity: the class is 1 (low danger) when the depth is
        below the lower curve, 3 (high danger) when it is above the upper
        curve and 2 (judgment zone) in between. Cells exceeding the maximum
        depth or velocity of the map type are always class 3.

        The curves are evaluated in English units: results of metric
        projects are multiplied by 3.28 first.

        :param name: name given to the returned layer
        :param hydro_risk: path of the GeoTIFF file to write
        :param depth_data: rows of (id, x, y, depth)
        :param vel_data: rows of (id, x, y, velocity)
        :param map_type: 0 houses, 1 mobile homes, 2 vehicles, 3 adults,
            4 children
        :param crs: coordinate reference system; its ``toWkt()`` is written
            to the raster
        :return: QgsRasterLayer pointing to ``hydro_risk``
        :raises ValueError: if ``map_type`` is unknown, no cell can be
            classified or the grid spacing cannot be determined
        :raises RuntimeError: if the GeoTIFF cannot be created
        """
        # Per map type: lower curve, upper curve (polynomial coefficients,
        # highest power first), maximum depth and maximum velocity.
        usbr_curves = {
            0: (
                (0.0004, -0.0121, -0.0809, 3.076),
                (0.0007, -0.0276, 0.0206, 5.9005),
                10,
                25,
            ),
            1: (
                (-0.0007, -0.0308, 1.9458),
                (-0.0009, -0.0262, 2.5373),
                3,
                16,
            ),
            2: (
                (0.0002, -0.0009, -0.0904, 2.0311),
                (0.0004, -0.0056, -0.1036, 3.0877),
                4,
                16,
            ),
            3: (
                (-0.0053, 0.1241, -1.0323, 3.1671),
                (-0.0011, 0.0282, -0.1888, -0.2374, 4.633),
                5,
                12,
            ),
            4: (
                (0.0726, -0.6786, 1.5994),
                (0.0029, -0.0526, 0.3337, -0.7657, -0.2936, 3.0475),
                4,
                8,
            ),
        }
        if map_type not in usbr_curves:
            raise ValueError(f"Unknown USBR map type: {map_type!r}")
        low_curve, high_curve, max_depth, max_velocity = usbr_curves[map_type]

        # Feet per metre for metric projects, 1 for English ones.
        uc = 3.28 if str(self.units_switch) == "1" else 1

        def as_rows(data):
            """Return the data as a 2-D float array (one row per grid cell)."""
            table = np.asarray(data, dtype=float)
            if table.size == 0:
                return np.empty((0, 4))
            return np.atleast_2d(table)

        vel_rows = as_rows(vel_data)
        depth_rows = as_rows(depth_data)

        # Rows are paired by position; ignore any unmatched tail.
        count = min(len(vel_rows), len(depth_rows))
        xs = vel_rows[:count, 1]
        ys = vel_rows[:count, 2]
        velocity = vel_rows[:count, 3] * uc
        depth = depth_rows[:count, 3] * uc

        wet = (depth != 0) & (velocity != 0)
        if not wet.any():
            raise ValueError("No flooded cells found: the USBR map cannot be created.")

        low_curve_value = np.polyval(low_curve, velocity)
        high_curve_value = np.polyval(high_curve, velocity)
        level = np.where(
            depth < low_curve_value, 1, np.where(depth > high_curve_value, 3, 2)
        )
        # Beyond the maximum depth or velocity the danger is always high.
        level[(depth > max_depth) | (velocity > max_velocity)] = 3

        # Grid spacing: smallest gap between distinct coordinates of the whole
        # grid. Flooded cells alone may not contain two neighbouring cells.
        steps = np.concatenate((np.diff(np.unique(xs)), np.diff(np.unique(ys))))
        if steps.size == 0:
            raise ValueError("Grid spacing cannot be determined from a single cell.")
        cell_size = float(steps.min())

        # Extent and raster indices of the classified cells
        wet_x = xs[wet]
        wet_y = ys[wet]
        min_x = float(wet_x.min())
        max_y = float(wet_y.max())
        cols = np.rint((wet_x - min_x) / cell_size).astype(int)
        rows = np.rint((max_y - wet_y) / cell_size).astype(int)
        num_cols = int(cols.max()) + 1
        num_rows = int(rows.max()) + 1

        raster_data = np.full((num_rows, num_cols), -9999, dtype=np.float32)
        raster_data[rows, cols] = level[wet]

        # Write the GeoTIFF; coordinates are cell centres, so the origin is
        # shifted by half a cell to the top-left corner.
        driver = gdal.GetDriverByName("GTiff")
        raster = driver.Create(hydro_risk, num_cols, num_rows, 1, gdal.GDT_Float32)
        if raster is None:
            raise RuntimeError(f"Could not create raster file: {hydro_risk}")
        raster.SetGeoTransform(
            (
                min_x - cell_size / 2,
                cell_size,
                0,
                max_y + cell_size / 2,
                0,
                -cell_size,
            )
        )
        raster.SetProjection(crs.toWkt())

        band = raster.GetRasterBand(1)
        band.SetNoDataValue(-9999)
        band.WriteArray(raster_data)

        # Flush and release the dataset so the file is complete on disk
        # before it is opened as a layer.
        raster.FlushCache()
        band = None
        raster = None

        return QgsRasterLayer(hydro_risk, name)