# -*- coding: utf-8 -*-

"""
/***************************************************************************
 dbpriskapp
                                 A QGIS plugin
 dbpriskapp
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2025-06-12
        copyright            : (C) 2025 by KIOS Smart Water Team
        email                : mkiria01@ucy.ac.cy
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'KIOS Smart Water Team'
__date__ = '2025-06-12'
__copyright__ = '(C) 2025 by KIOS Smart Water Team'

import os
import traceback

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

try:
    from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
    from matplotlib.figure import Figure
except:
    pass
from epyt import epanet


class WaterQualitySimulation:
    def __init__(self,
                 inpname,
                 msxname,
                 excel_file,
                 sensor_id=None,
                 t_d=1,
                 msx_timestep=300,
                 injection_rate=None,
                 multiplication_pattern=None,
                 multiplication_sensor_location=None,
                 custom_pattern_file_path=None,
                 monte_carlo_simulations=1,
                 MSX_uncertainty=None,
                 species_names=None,
                 species_types=None,
                 initial_concentration=None,
                 chemical_value=None,
                 chemical_param=None,
                 Demand_Uncertainty=None,
                 Input_Type=None,
                 scenario_id=None,
                 feedback=None
                 ):
        """
        Initializes a Water Quality Simulation scenario set.

        All list parameters are aligned per scenario: element ``i`` of each list
        belongs to scenario ``i`` of ``Input_Type``. Omitted list parameters
        default to a fresh empty list for every instance.

        Parameters:
        ----------
        inpname : str
            Path to the EPANET .inp file defining the water distribution network.

        msxname : str
            Path to the EPANET-MSX .msx file defining the chemical reactions and species.

        excel_file : pd.DataFrame
            Measurement table handed to ``load_data_from_df``; that method reads the
            columns 'SensorLocation', 'ParameterName', 'time' and 'value'.

        sensor_id : list of str
            Node IDs per scenario. The special values "all_nodes_injection" and
            "all_nodes_initial" address every node of the network.

        t_d : int, default=1
            Duration of the simulation in days (converted to seconds in ``self.t_sim``).

        msx_timestep : int, default=300
            MSX time step in seconds. It is stored only; the visible code does not
            apply it to the model.

        injection_rate : list of float
            Injection rates per scenario (Input_Type 1).

        multiplication_pattern : list of str
            Per scenario, the 'ParameterName' whose measured values are used as the
            MSX source pattern (Input_Type 1). A falsy entry means a flat pattern.

        multiplication_sensor_location : list of str
            Per scenario, the 'SensorLocation' from which the pattern values above
            are taken.

        custom_pattern_file_path : list of str
            Per scenario, a CSV file name inside ``../data/pattern_data`` (relative to
            this module) whose first column overrides the pattern values.

        monte_carlo_simulations : int, default=1
            Number of random base-demand samples generated for Input_Type 4.

        MSX_uncertainty : list of float
            Uncertainty per scenario in percent (divided by 100 in ``run_simulation``).

        species_names : list of str
            MSX species name per scenario.

        species_types : list of str
            Source type per scenario, e.g. "Set Point Booster", "Inflow Concentration",
            "Flow Paced Booster", "Mass Inflow Booster" or "No Source".

        initial_concentration : list of float
            Initial concentration per scenario (Input_Type 2).

        chemical_value : list of float
            Value of the MSX chemical parameter per scenario (Input_Type 3).

        chemical_param : list of str
            Name of the MSX chemical parameter per scenario (Input_Type 3).

        Demand_Uncertainty : list of float
            Demand uncertainty per scenario (Input_Type 4); passed unchanged to
            ``store_demands`` as the relative factor ``eta_bar``.

        Input_Type : list of int
            Input type indicator per scenario:
                1 → Injection Insertion
                2 → Initial Concentration
                3 → Chemical Parameter
                4 → Demand Uncertainty

        scenario_id : list of int
            Scenario identifiers; stored for reference.

        feedback : object or None
            Optional progress reporter exposing ``setProgress(int)`` and
            ``pushInfo(str)``. ``None`` disables progress reporting.
        """

        def _or_empty(value):
            # A new list per instance avoids the shared-mutable-default pitfall
            # while leaving caller-supplied objects untouched.
            return [] if value is None else value

        self.inpname = inpname
        self.msxname = msxname
        self.excel_file = excel_file
        self.sensor_id = _or_empty(sensor_id)
        self.t_d = t_d
        self.msx_timestep = msx_timestep

        # From parsed inputs
        self.injection_rate = _or_empty(injection_rate)
        self.MSX_uncertainty = _or_empty(MSX_uncertainty)
        self.species_names = _or_empty(species_names)
        self.species_types = _or_empty(species_types)
        self.initial_concentration = _or_empty(initial_concentration)
        self.chemical_value = _or_empty(chemical_value)
        self.chemical_param = _or_empty(chemical_param)
        self.Demand_Uncertainty = _or_empty(Demand_Uncertainty)
        self.Input_Type = _or_empty(Input_Type)
        self.scenario_id = _or_empty(scenario_id)
        self.multiplication_pattern = _or_empty(multiplication_pattern)
        self.multiplication_sensors = _or_empty(multiplication_sensor_location)
        self.custom_pattern_file = _or_empty(custom_pattern_file_path)
        self.monte_carlo_simulations = monte_carlo_simulations
        self.feedback = feedback

        # Timestamps of the first sensor series loaded by load_data_from_df.
        self.global_times = None
        # Simulation duration in seconds.
        self.t_sim = self.t_d * 24 * 60 * 60

        # Scenario index -> MSX pattern ID, filled by setup_simulation.
        self.patID = {}
        # initialize plugin directory
        self.plugin_dir = os.path.dirname(__file__)

    def load_data_from_df(self, df: pd.DataFrame, parameter: str = 'FCL') -> dict:
        """
        Given a pandas DataFrame, return a dictionary of sensor-location
        Parameters
        ----------
        df : pd.DataFrame
            The DataFrame containing at least the columns ['SensorLocation', 'ParameterName', 'time', 'value'].
        parameter : str, optional
            The name of the parameter to filter by (default is 'FCL').
        """
        # Ensure 'time' is datetime
        df['time'] = pd.to_datetime(df['time'], errors='coerce')
        unique_locations = df['SensorLocation'].unique()
        all_data = {}
        self.title = ""
        self.plot_times = {}
        self.global_times = None

        for loc in unique_locations:
            # Filter and sort
            param_data = df[(df['ParameterName'] == parameter) & (df['SensorLocation'] == loc)]
            param_data = param_data.dropna(subset=['time'])  # Drop rows with invalid dates
            param_data = param_data.sort_values(by='time')

            if param_data.empty:
                print(f"[WARNING] No data for ParameterName='{parameter}' at SensorLocation='{loc}'")
                continue

            values = param_data['value'].to_numpy()
            times = param_data['time'].to_numpy()

            all_data[loc] = values
            self.plot_times[loc] = times

            try:
                start_str = pd.Timestamp(times[0]).strftime("%d/%m/%y %H:%M")
                end_str = pd.Timestamp(times[-1]).strftime("%d/%m/%y %H:%M")
                self.title = f" from {start_str} to {end_str}"
            except IndexError:
                print(f"[ERROR] Time array is empty after filtering for location {loc}.")
                print(param_data)
                continue

            if self.global_times is None:
                self.global_times = times

        return all_data

    def fetch_custom_values(self, custom_excelfile):
        """Return the non-empty entries of the first column of a header-less CSV file as a list."""
        first_column = pd.read_csv(custom_excelfile, header=None).iloc[:, 0]
        return first_column.dropna().tolist()

    def setup_simulation(self):
        """
        Load the EPANET and MSX models and apply every configured scenario input.

        Steps, in order:
          1. Best-effort unload of a project left over in ``self.G``.
          2. Load ``self.inpname`` / ``self.msxname`` and cache node IDs, species
             names, per-scenario species indices and per-scenario node indices.
          3. Set the simulation duration to ``self.t_sim`` seconds.
          4. For each entry ``i`` of ``self.Input_Type``:
               1 -> MSX source at the scenario node (or at every node for
                    "all_nodes_injection") using MSX pattern ``pat<i>``
               2 -> initial concentration of one species at the scenario node
                    (or at every node for "all_nodes_initial")
               3 -> MSX chemical parameter applied to all tanks and pipes
               4 -> pressure-driven demand model plus random base-demand samples
                    generated by ``store_demands``

        Side effects: sets ``self.G``, ``self.species_names_function``,
        ``self.node_id``, ``self.species_indices``, ``self.sensor_index``; fills
        ``self.patID``; for type 1 also refreshes ``self.global_times`` through
        ``load_data_from_df``; for type 4 sets ``self.bd_mcs``.
        """
        # Release a previously loaded project. This is deliberately best-effort:
        # a stale or already closed handle must not block a fresh load.
        try:
            previous = getattr(self, 'G', None)
            if previous:
                previous.unloadMSX()
                previous.unload()
        except Exception:
            pass

        if self.feedback is not None:
            self.feedback.setProgress(10)
            self.feedback.pushInfo(
                "Loading models...")
        self.G = epanet(self.inpname)
        self.G.loadMSXFile(self.msxname)
        self.species_names_function = self.G.getMSXSpeciesNameID()
        self.node_id = self.G.getNodeNameID()
        self.species_indices = [
            self.G.getMSXSpeciesIndex([name])[0] if name is not None else None
            for name in self.species_names
        ]
        # The "all nodes" placeholders are not real node IDs, so they get no index.
        self.sensor_index = [
            self.G.getNodeIndex(node)
            if node not in ("all_nodes_initial", "all_nodes_injection", None)
            else None
            for node in self.sensor_id
        ]
        self.G.setTimeSimulationDuration(self.t_sim)

        # Display names accepted for a source type -> keyword passed to setMSXSources.
        # Any other value is forwarded unchanged.
        source_type_codes = {
            "Set Point Booster": 'SETPOINT',
            "Inflow Concentration": 'CONCEN',
            "Flow Paced Booster": 'FLOWPACED',
            "Mass Inflow Booster": 'MASS',
            "No Source": 'NOSOURCE',
        }

        for i, input_type in enumerate(self.Input_Type):
            # Case 1: injection -> NodeID, Species, SpeciesType, InjectionRate
            if input_type == 1:
                node = self.sensor_id[i]
                specie = self.species_names[i]
                specie_type = self.species_types[i]
                rate = self.injection_rate[i]

                if self.multiplication_pattern and self.multiplication_pattern[i]:
                    # Measured values of the chosen parameter at the chosen sensor
                    # location become the source pattern.
                    dataf = self.load_data_from_df(self.excel_file, self.multiplication_pattern[i])
                    pat_values = dataf[self.multiplication_sensors[i]]
                else:
                    # Called only for its side effect of refreshing self.global_times;
                    # the pattern itself is a flat multiplier of 1.
                    self.load_data_from_df(self.excel_file, 'FCL')
                    pat_values = np.ones(1)

                if self.custom_pattern_file and self.custom_pattern_file[i]:
                    # A custom CSV overrides whatever pattern was chosen above.
                    custom_pattern_file_path = os.path.join(self.plugin_dir, "..", 'data', 'pattern_data',
                                                            self.custom_pattern_file[i])
                    pat_values = self.fetch_custom_values(custom_pattern_file_path)

                patID = f"pat{i}"
                self.patID[i] = patID
                self.G.addMSXPattern(patID)
                self.G.setMSXPattern(patID, pat_values)

                specie_type = source_type_codes.get(specie_type, specie_type)

                if node == "all_nodes_injection":
                    for node_id in self.G.getNodeNameID():
                        self.G.setMSXSources(node_id, specie, specie_type, rate, patID)
                else:
                    self.G.setMSXSources(node, specie, specie_type, rate, patID)

            # Case 2: initial concentration -> NodeID, Species, value
            elif input_type == 2:
                node = self.sensor_id[i]
                node_index = self.sensor_index[i]
                species_position = self.species_indices[i]
                init_value = self.initial_concentration[i]
                apply_to_all_nodes = node == "all_nodes_initial"

                # Warning in case the parse is wrong or csv. A node index is only
                # required when a single node is addressed.
                if (species_position is None or init_value is None
                        or (node_index is None and not apply_to_all_nodes)):
                    print(f"[Warning] Missing data for Scenario {i}. Skipping initial concentration.")
                    continue

                species_index = species_position - 1  # toolkit indices are 1-based
                values = self.G.getMSXNodeInitqualValue()
                if apply_to_all_nodes:
                    for j in range(self.G.getNodeCount()):
                        values[j][species_index] = init_value
                else:
                    values[node_index - 1][species_index] = init_value
                self.G.setMSXNodeInitqualValue(values)

            # Case 3: chemical parameter -> name, value
            elif input_type == 3:
                param_name = self.chemical_param[i]
                param_value = self.chemical_value[i]

                # Warning in case the parse is wrong or csv.
                if param_name is None or param_value is None:
                    print(f"[Warning] Missing chemical param or value in Scenario {i}. Skipping.")
                    continue

                try:
                    param_index = self.G.getMSXParametersIndex([param_name])[0]
                except Exception as e:
                    print(f"[Action {i + 1}] Error finding parameter '{param_name}': {e}")
                    continue

                # Apply to tanks
                for tank_index in self.G.getNodeTankIndex():
                    self.G.setMSXParametersTanksValue(tank_index, param_index, param_value)

                # Apply to pipes
                for pipe_index in self.G.getLinkPipeIndex():
                    self.G.setMSXParametersPipesValue(pipe_index, param_index, param_value)

            # Case 4: demand uncertainty
            elif input_type == 4:
                # Pressure-driven analysis: pmin=0, preq=0.1, pexp=0.5
                self.G.setDemandModel('PDA', 0, 0.1, 0.5)
                base_demands = self.G.getNodeBaseDemands()[1]
                self.store_demands(self.monte_carlo_simulations, base_demands, self.Demand_Uncertainty[i])

    def store_demands(self, nsim, base_demands, eta_bar):
        # Seed number to always get the same random results
        np.random.seed(1)
        # Initialize matrix to save MCS pressures
        self.bd_mcs = []
        for i in range(nsim):
            # Compute new base demands
            delta_bd = (2 * np.random.rand(1, len(base_demands))[0] - 1) * eta_bar * base_demands
            new_base_demands = base_demands + delta_bd
            self.bd_mcs.append(new_base_demands.copy())
            # print(f"Epoch {i}")
        return self.bd_mcs

    def compute_for_demands(self, node_index):

        nsim = len(self.bd_mcs)
        pmcs = [None for _ in range(nsim)]

        for i in range(nsim):
            new_base_demands = self.bd_mcs[i]
            self.G.setNodeBaseDemands(new_base_demands)
            pmcs[i] = self.G.getComputedHydraulicTimeSeries().Pressure
            # print(f"Epoch {i}")

        # Compute upper and lower bounds
        pmulti = [pmcs[i][:, node_index - 1] for i in range(nsim)]
        pmulti = np.vstack(pmulti)
        ub = np.max(pmulti, axis=0)
        lb = np.min(pmulti, axis=0)
        meanb = np.mean(pmulti, axis=0)

        return pmulti, ub, lb, meanb

    def run_simulation(self):
        if self.feedback is not None:
            self.feedback.setProgress(20)
            self.feedback.pushInfo(
                "Starting to Simulate...")
        G = self.G
        self.MSX_comps = []

        default_result = G.getMSXComputedQualityNode()  # default scenario no uncertainties
        self.MSX_comps.append(default_result)
        self.MSX_comp = default_result

        demand_scenarios = self.bd_mcs if any(t == 4 for t in self.Input_Type) else [None]

        idx_type1 = [i for i, t in enumerate(self.Input_Type) if t == 1 and self.MSX_uncertainty[i] > 0]
        idx_type2 = [i for i, t in enumerate(self.Input_Type) if t == 2 and self.MSX_uncertainty[i] > 0]
        idx_type3 = [i for i, t in enumerate(self.Input_Type) if t == 3 and self.MSX_uncertainty[i] > 0]
        scenario_index = 0

        v1 = 3 ** len(idx_type1) if idx_type1 else 1
        v2 = 3 ** len(idx_type2) if idx_type2 else 1
        v3 = 3 ** len(idx_type3) if idx_type3 else 1
        v4 = len(demand_scenarios) if demand_scenarios != [None] else 1
        # Total number of scenarios
        total_scenarios = v1 * v2 * v3 * v4
        #  type 1
        for i1 in idx_type1 or [None]:
            variations_1 = [(None, None)]
            if i1 is not None:
                rate = self.injection_rate[i1]
                uncertainty = self.MSX_uncertainty[i1] / 100
                node = self.sensor_id[i1]
                specie = self.species_names[i1]
                specie_type = self.species_types[i1].upper().replace(" ", "")
                specie_type = {
                    "SETPOINTBOOSTER": "SETPOINT",
                    "INFLOWCONCENTRATION": "CONCEN",
                    "FLOWPACEDBOOSTER": "FLOWPACED",
                    "MASSINFLOWBOOSTER": "MASS",
                    "NOSOURCE": "NOSOURCE"
                }.get(specie_type, specie_type)
                variations_1 = [
                    (rate * (1 - uncertainty), self.patID[i1]),
                    (rate * (1 + uncertainty), self.patID[i1]),
                    (rate, self.patID[i1])
                ]

            for val1, pat1 in variations_1:
                if i1 is not None:
                    if node == "all_nodes_injection":
                        nodes_count = self.G.getNodeCount()
                        all_nodes = self.G.getNodeNameID()
                        for j in range(nodes_count):
                            node = all_nodes[j]
                            self.G.setMSXSources(node, specie, specie_type, rate, pat1)
                    else:
                        self.G.setMSXSources(node, specie, specie_type, val1, pat1)


                # type 2
                for i2 in idx_type2 or [None]:
                    variations_2 = [(None,)]
                    if i2 is not None:
                        node = self.sensor_id[i2]
                        init_val = self.initial_concentration[i2]
                        uncertainty = self.MSX_uncertainty[i2] / 100
                        node_idx = self.sensor_index[i2] - 1
                        species_idx = self.species_indices[i2] - 1
                        variations_2 = [
                            (init_val * (1 - uncertainty),),
                            (init_val * (1 + uncertainty),),
                            (init_val,)
                        ]

                    for val2 in variations_2:
                        if i2 is not None:
                            values = self.G.getMSXNodeInitqualValue()
                            if node == "all_nodes_initial":
                                nodes_count = self.G.getNodeCount()
                                for j in range(nodes_count):
                                    values[j][species_index] = init_value
                            else:
                                values[node_idx][species_idx] = val2[0]

                            self.G.setMSXNodeInitqualValue(values)

                        #  type 3
                        for i3 in idx_type3 or [None]:
                            variations_3 = [(None,)]
                            if i3 is not None:
                                val = self.chemical_value[i3]
                                param = self.chemical_param[i3]
                                uncertainty = self.MSX_uncertainty[i3] / 100
                                param_index = self.G.getMSXParametersIndex([param])[0]
                                variations_3 = [
                                    (val * (1 - uncertainty),),
                                    (val * (1 + uncertainty),),
                                    (val,)
                                ]

                            for val3 in variations_3:
                                if i3 is not None:
                                    for tank in self.G.getNodeTankIndex():
                                        self.G.setMSXParametersTanksValue(tank, param_index, val3[0])
                                    for pipe in self.G.getLinkPipeIndex():
                                        self.G.setMSXParametersPipesValue(pipe, param_index, val3[0])

                                #  type 4 (demand scenarios)
                                for bd in demand_scenarios:
                                    if bd is not None:
                                        self.G.setNodeBaseDemands(bd)

                                    quality_result = G.getMSXComputedQualityNode()
                                    self.MSX_comps.append(quality_result)
                                    scenario_index += 1
                                    if self.feedback is not None and total_scenarios > 0:
                                        progress = 20 + int((scenario_index / total_scenarios) * 80)
                                        self.feedback.setProgress(progress)
                                        self.feedback.pushInfo(
                                            f"Progress: {progress}% - Scenario {scenario_index}/{total_scenarios}")

        self.MSXunits = G.getMSXSpeciesUnits()
        self.speiciesindecesmsx = G.getMSXSpeciesIndex()
        self.soeciesnamesmsx = G.getMSXSpeciesNameID()
        self.dataframe = self.export_to_dataframe()
        self.dataframe_uncertainty = self.export_to_dataframe_uncertainty()
        self.G.unloadMSX()
        self.G.unload()
        if self.feedback is not None:
            self.feedback.setProgress(100)
            self.feedback.pushInfo(f"Processing step 100%...")

        mes = self.Measured_Chlorine()

        return (self.MSX_comps[0], self.node_id, self.species_names_function, self.MSX_comps, self.global_times,
                self.soeciesnamesmsx, self.MSXunits, mes)

    def Measured_Chlorine(self):
        """Return measured chlorine data as a dictionary with padded arrays."""
        self.dataf = self.load_data_from_df(self.excel_file, "FCL")
        keys_list = list(self.dataf.keys())
        if not keys_list:  # this case should never trigger
            return {}

        arrays_dict = {}

        # Convert to NumPy arrays and check type
        for key in keys_list:
            array = self.dataf[key]

            if not isinstance(array, np.ndarray):
                array = np.array(array)
            if not np.issubdtype(array.dtype, np.number):
                raise TypeError(f"Array for key '{key}' contains non-numeric data.")

            arrays_dict[key] = array

        # Determine the maximum length among all arrays
        max_length = max(arr.shape[0] for arr in arrays_dict.values())

        # Pad arrays and build the result dictionary
        padded_dict = {}
        for key, arr in arrays_dict.items():
            current_length = arr.shape[0]
            if current_length < max_length:
                padding = np.zeros(max_length - current_length, dtype=arr.dtype)
                padded_arr = np.concatenate([arr, padding])
            else:
                padded_arr = arr[:max_length]  # Trim if too long
            padded_dict[key] = padded_arr

        return padded_dict

    def export_to_excel(self, results, output_file='computedtoexcel.xlsx', selected_nodes=None,
                        selected_species=None,
                        header=True):
        if not output_file.endswith('.xlsx'):
            output_file += '.xlsx'

        if not hasattr(results, 'Time') or not hasattr(results, 'Quality'):
            raise ValueError("Simulation results are not properly initialized or run.")

        time_data = results.Time
        species_list = self.species_names_function

        # Get node IDs and indices
        node_ids = self.node_id
        node_indices = list(range(len(node_ids)))

        # Filter nodes if selected_nodes is provided
        if selected_nodes:
            selected_node_indices = []
            for node in selected_nodes:
                if isinstance(node, str):  # Node ID
                    if node in node_ids:
                        selected_node_indices.append(node_ids.index(node))
                    else:
                        raise ValueError(f"Node ID '{node}' not found.")
                elif isinstance(node, int):  # Node index
                    if 0 <= node < len(node_ids):
                        selected_node_indices.append(node)
                    else:
                        raise ValueError(f"Node index '{node}' is out of range.")
                else:
                    raise ValueError(f"Invalid node identifier: {node}")
        else:
            selected_node_indices = node_indices

        # Filter species if selected_species is provided
        if selected_species:
            selected_species_indices = []
            for species in selected_species:
                if isinstance(species, str):  # Species name
                    if species in species_list:
                        selected_species_indices.append(species_list.index(species))
                    else:
                        raise ValueError(f"Species name '{species}' not found.")
                elif isinstance(species, int):  # Species index
                    if 0 <= species < len(species_list):
                        selected_species_indices.append(species)
                    else:
                        raise ValueError(f"Species index '{species}' is out of range.")
                else:
                    raise ValueError(f"Invalid species identifier: {species}")
        else:
            selected_species_indices = list(range(len(species_list)))

        with pd.ExcelWriter(output_file, engine='xlsxwriter') as writer:
            node_keys = list(results.Quality.keys())

            for species_index in selected_species_indices:
                species_name = species_list[species_index]
                species_data = []

                for node_index in selected_node_indices:
                    node_key = node_keys[node_index]
                    quality_data = np.array(results.Quality[node_key])

                    # If quality_data has an extra leading dimension
                    if quality_data.ndim == 3 and quality_data.shape[0] == 1:
                        quality_data = quality_data[0]

                    num_timesteps = len(time_data)
                    num_species = len(species_list)
                    expected_shape = (num_timesteps, num_species)

                    if quality_data.shape != expected_shape:
                        raise ValueError(
                            f"Node {node_key}: quality_data does not match expected shape {expected_shape}. "
                            f"Actual shape: {quality_data.shape}"
                        )
                    species_data.append(quality_data[:, species_index])

                species_data_array = np.array(species_data)

                df = pd.DataFrame(species_data_array, columns=time_data,
                                  index=[node_ids[i] for i in selected_node_indices])
                df.insert(0, 'NODE INDEX', [node_indices[i] for i in selected_node_indices])
                df.insert(1, 'NODE ID', [node_ids[i] for i in selected_node_indices])

                # If header is False, remove the first data row from df
                if not header and len(df) > 0:
                    df = df.iloc[1:].copy()

                sheet_name = f"{species_name}"
                # If header=False, no column headers will be written to the Excel sheet.
                df.to_excel(writer, index=False, sheet_name=sheet_name, header=header)

                worksheet = writer.sheets[sheet_name]
                worksheet.set_column('A:A', 13.0)

        # print(f"Data successfully written to {output_file}")

    def plot_data(self, measured_data, simulated_data, sensor_index, species_index, species_names,
                  sensor_description, subtitle=None, show_measured=True):
        """
        Draw one time-series panel per sensor with measured and simulated concentrations.

        Species whose MSX unit contains "mg" share the left axis with the measured
        series; species in micrograms ("ug", "µg", "μg") go on a secondary right
        axis; any other unit is drawn on the left axis without touching its label.

        Parameters:
            measured_data: Mapping queried with ``.get(sensor_name)``; a missing
                entry (``None``) means no measured curve for that panel.
            simulated_data: Sequence of result objects; only ``simulated_data[0]``
                is read, through ``Quality[i][:, species_column]``.
            sensor_index: One ``Quality`` key per sensor. A key equal to 0 disables
                the simulated curves for that panel.
            species_index: 1-based species column numbers; ``None`` entries are skipped.
            species_names: Names parallel to ``species_index``; each must be listed in
                ``self.soeciesnamesmsx`` so its unit can be read from ``self.MSXunits``.
            sensor_description: Sensor names, used as panel titles and as keys into
                ``measured_data``.
            subtitle (str): Optional figure-level title.
            show_measured (bool): Draw the measured series and manage the left-axis
                range explicitly.

        Returns:
            FigureCanvas: Canvas holding the rendered figure.
        """
        import matplotlib.colors as mcolors
        from itertools import cycle
        import matplotlib.dates as mdates
        from matplotlib.ticker import MaxNLocator
        from datetime import datetime, timedelta

        base_pool = list(plt.rcParams['axes.prop_cycle'].by_key()['color'])
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is
        # available on both old and new releases.
        tab20 = plt.get_cmap('tab20')
        tab20_pool = [mcolors.to_hex(tab20(i)) for i in range(tab20.N)]
        # "blue" is used for the measured series, so it is never handed to a species.
        colour_pool = [c for c in base_pool + tab20_pool if c != "blue"]
        colour_cycle = cycle(colour_pool)

        species_colour_map = {}

        def colour_for_species(name):
            """Return the consistent colour assigned to *name*, assigning a new one if necessary."""
            if name not in species_colour_map:
                # Every other palette entry is skipped on purpose
                # (presumably to increase contrast between species — confirm).
                next(colour_cycle)
                species_colour_map[name] = next(colour_cycle)
            return species_colour_map[name]

        panels = list(zip(sensor_index, sensor_description))
        # Keep the historical 4-row grid for up to four sensors, but grow it when
        # there are more so add_subplot() never receives an out-of-range index.
        n_rows = max(4, len(panels))

        figure_height = max(3 * len(sensor_index), 8)
        figure = Figure(figsize=(10, figure_height))
        canvas = FigureCanvas(figure)
        # 288 samples per simulated day (presumably a 5-minute step — confirm
        # against the configured MSX time step).
        duration = self.t_d * 288
        times = self.global_times

        if panels:
            # Values shared by every panel; only computed when something is drawn
            # so an empty sensor list never touches the time axis.
            start_date = times[0].astype('M8[D]').astype(str)  # 'YYYY-MM-DD'
            start_dt = datetime.strptime(start_date, "%Y-%m-%d")
            end_dt = start_dt + timedelta(days=self.t_d)
            start_str = start_dt.strftime("%d/%m/%Y")
            end_str = end_dt.strftime("%d/%m/%Y")
            # Clamp to the available samples so a short time axis cannot raise IndexError.
            last_sample = min(duration, len(times)) - 1

        for k, (i, sensor_name) in enumerate(panels, start=1):
            ax_left = figure.add_subplot(n_rows, 1, k)
            ax_right = None
            # Running upper limits, so a later series can only widen an axis.
            left_top = None
            right_top = None

            measured_array = measured_data.get(sensor_name)
            ax_left.xaxis.set_major_formatter(mdates.DateFormatter('%d %H:%M'))
            ax_left.xaxis.set_major_locator(MaxNLocator(nbins=13))

            for label in ax_left.get_xticklabels():
                label.set_rotation(30)
                label.set_horizontalalignment('right')
                label.set_fontsize(8)

            #  Measured chlorine
            if measured_array is not None and show_measured:
                ax_left.plot(times[:duration], measured_array[:duration], label="CL2 measured", color="blue")
                ax_left.set_ylabel("(mg/L)", color="blue")
                ax_left.tick_params(axis="y", labelcolor="blue")
                left_top = max(0.01, measured_array[:duration].max() * 1.2)  # Add headroom
                ax_left.set_ylim(bottom=0, top=left_top)

            #  Simulated species
            for idx, sp_ind in enumerate(species_index):
                if sp_ind is None or i == 0:
                    continue

                quality_data = simulated_data[0].Quality[i][:, sp_ind - 1]
                species_name = species_names[idx]
                unit_idx = self.soeciesnamesmsx.index(species_name)
                unit = self.MSXunits[unit_idx].lower()
                series = quality_data[:duration]

                colour = colour_for_species(species_name)
                if "mg" in unit:
                    ax_left.plot(times[:duration], series, label=species_name, color=colour)
                    ax_left.set_ylabel("(mg/L)", color="blue")
                    ax_left.tick_params(axis="y", labelcolor="blue")
                    # The left range is only managed when measured data is requested;
                    # otherwise Matplotlib autoscaling is left in charge.
                    if show_measured:
                        candidate = max(0.01, series.max() * 1.2)  # Add headroom
                        if left_top is None or candidate > left_top:
                            left_top = candidate
                            ax_left.set_ylim(bottom=0, top=left_top)
                elif any(u in unit for u in ("ug", "µg", "μg")):
                    if ax_right is None:
                        ax_right = ax_left.twinx()
                    ax_right.plot(times[:duration], series, label=species_name, color=colour)
                    ax_right.set_ylabel("(ug/L)", color="red")
                    ax_right.tick_params(axis="y", labelcolor="red")
                    candidate = max(0.01, series.max() * 1.2)  # Add headroom
                    if right_top is None or candidate > right_top:
                        right_top = candidate
                        ax_right.set_ylim(bottom=0, top=right_top)
                else:
                    ax_left.plot(times[:duration], series, label=species_name, color=colour)

            ax_left.set_xlim(times[0], times[last_sample])

            title = f"{sensor_name} ({start_str}-{end_str})"
            ax_left.set_title(title, fontsize=9, fontweight="bold")
            ax_left.grid(True)
            ax_left.legend(loc="upper left", fontsize="x-small", frameon=True, framealpha=0.5, facecolor='white')
            if ax_right is not None:
                ax_right.legend(loc="upper right", fontsize="x-small", frameon=True, framealpha=0.5, facecolor='white')

        if subtitle:
            figure.suptitle(subtitle, fontsize=12)

        figure.tight_layout(rect=[0, 0, 0.94, 0.96])
        figure.subplots_adjust(hspace=0.5)
        canvas.draw()
        return canvas

    def export_to_dataframe(self, selected_nodes=None, selected_species=None, header=True):
        """
        Collect the results in ``self.MSX_comp`` into one long-format DataFrame.

        The frame has one row per (species, node) pair. Its columns are
        ``NODE INDEX``, ``NODE ID`` and ``SPECIES`` followed by one column per
        entry of ``results.Time``. Rows are ordered species-major.

        Parameters:
            selected_nodes: Optional iterable of node IDs (``str``) and/or 0-based
                node positions (``int`` or NumPy integer). Falsy selects every node.
            selected_species: Optional iterable of species names (``str``) and/or
                0-based species positions. Falsy selects every species.
            header (bool): When False, the first row of every per-species frame is
                dropped before concatenation.
                NOTE(review): this removes a data row (the first selected node),
                not a header line — confirm this is intended.

        Returns:
            pd.DataFrame: The concatenated per-species frames.

        Raises:
            ValueError: If the results lack ``Time``/``Quality``, a selection entry
                is unknown, out of range or of an unsupported type, or a node's
                quality array is not shaped ``(len(Time), number of species)``.
        """
        results = self.MSX_comp
        if not hasattr(results, 'Time') or not hasattr(results, 'Quality'):
            raise ValueError("Simulation results are not properly initialized or run.")

        time_data = results.Time
        species_list = self.species_names_function
        node_ids = self.node_id

        def _resolve(selection, names, id_label, index_label, kind):
            """Translate a mixed name/position selection into 0-based positions."""
            if not selection:
                return list(range(len(names)))
            positions = []
            for item in selection:
                if isinstance(item, str):
                    if item not in names:
                        raise ValueError(f"{id_label} '{item}' not found.")
                    positions.append(names.index(item))
                # NumPy integers (e.g. taken from an array or a DataFrame) are as
                # valid as builtin ints for positional selection.
                elif isinstance(item, (int, np.integer)):
                    if not 0 <= item < len(names):
                        raise ValueError(f"{index_label} '{item}' is out of range.")
                    positions.append(int(item))
                else:
                    raise ValueError(f"Invalid {kind} identifier: {item}")
            return positions

        selected_node_indices = _resolve(selected_nodes, node_ids, "Node ID", "Node index", "node")
        selected_species_indices = _resolve(selected_species, species_list,
                                            "Species name", "Species index", "species")

        # Quality is keyed per node; positions are mapped to keys by iteration order.
        node_keys = list(results.Quality.keys())
        expected_shape = (len(time_data), len(species_list))

        # Convert and validate every selected node once instead of once per species.
        node_arrays = []
        for node_index in selected_node_indices:
            node_key = node_keys[node_index]
            quality_data = np.asarray(results.Quality[node_key])

            # Drop a leading singleton axis so the array is (time, species).
            if quality_data.ndim == 3 and quality_data.shape[0] == 1:
                quality_data = quality_data[0]

            if quality_data.shape != expected_shape:
                raise ValueError(
                    f"Node {node_key}: quality_data does not match expected shape {expected_shape}. "
                    f"Actual shape: {quality_data.shape}"
                )
            node_arrays.append(quality_data)

        node_index_column = list(selected_node_indices)
        node_id_column = [node_ids[i] for i in selected_node_indices]

        all_dataframes = []
        for species_index in selected_species_indices:
            species_data_array = np.array([arr[:, species_index] for arr in node_arrays])

            df = pd.DataFrame(species_data_array, columns=time_data, index=node_id_column)
            df.insert(0, 'NODE INDEX', node_index_column)
            df.insert(1, 'NODE ID', node_id_column)
            df.insert(2, 'SPECIES', species_list[species_index])

            if not header and len(df) > 0:
                df = df.iloc[1:].copy()

            all_dataframes.append(df.reset_index(drop=True))

        return pd.concat(all_dataframes, ignore_index=True)

    def export_to_dataframe_uncertainty(self, selected_nodes=None, selected_species=None, header=True):
        """
        Collect every run in ``self.MSX_comps`` into one long-format DataFrame.

        The frame has one row per (run, species, node) triple. Its columns are
        ``RUN INDEX``, ``NODE INDEX``, ``NODE ID`` and ``SPECIES`` followed by one
        column per entry of the run's ``Time``. Rows are ordered run-major, then
        species, then node.

        Parameters:
            selected_nodes: Optional iterable of node IDs (``str``) and/or 0-based
                node positions (``int`` or NumPy integer). Falsy selects every node.
            selected_species: Optional iterable of species names (``str``) and/or
                0-based species positions. Falsy selects every species.
            header (bool): When False, the first row of every per-species frame is
                dropped before concatenation.
                NOTE(review): this removes a data row (the first selected node),
                not a header line — confirm this is intended.

        Returns:
            pd.DataFrame: The concatenated frames of all runs.

        Raises:
            ValueError: If ``self.MSX_comps`` is missing, not a list or empty, a run
                lacks ``Time``/``Quality``, a selection entry is unknown, out of
                range or of an unsupported type, or a node's quality array is not
                shaped ``(len(Time), number of species)``.
        """
        if not hasattr(self, 'MSX_comps') or not isinstance(self.MSX_comps, list):
            raise ValueError("self.MSX_comps must be a list of simulation results.")
        if not self.MSX_comps:
            # Fail with a clear message instead of pandas' "No objects to concatenate".
            raise ValueError("self.MSX_comps is empty; there are no simulation results to export.")

        species_list = self.species_names_function
        node_ids = self.node_id

        def _resolve(selection, names, id_label, index_label, kind):
            """Translate a mixed name/position selection into 0-based positions."""
            if not selection:
                return list(range(len(names)))
            positions = []
            for item in selection:
                if isinstance(item, str):
                    if item not in names:
                        raise ValueError(f"{id_label} '{item}' not found.")
                    positions.append(names.index(item))
                # NumPy integers (e.g. taken from an array or a DataFrame) are as
                # valid as builtin ints for positional selection.
                elif isinstance(item, (int, np.integer)):
                    if not 0 <= item < len(names):
                        raise ValueError(f"{index_label} '{item}' is out of range.")
                    positions.append(int(item))
                else:
                    raise ValueError(f"Invalid {kind} identifier: {item}")
            return positions

        selected_node_indices = _resolve(selected_nodes, node_ids, "Node ID", "Node index", "node")
        selected_species_indices = _resolve(selected_species, species_list,
                                            "Species name", "Species index", "species")

        node_index_column = list(selected_node_indices)
        node_id_column = [node_ids[i] for i in selected_node_indices]

        all_dataframes = []
        for run_index, results in enumerate(self.MSX_comps):
            if not hasattr(results, 'Time') or not hasattr(results, 'Quality'):
                raise ValueError(f"Simulation result at index {run_index} is not properly initialized or run.")

            time_data = results.Time
            # Quality is keyed per node; positions are mapped to keys by iteration order.
            node_keys = list(results.Quality.keys())
            expected_shape = (len(time_data), len(species_list))

            # Convert and validate every selected node once per run, not once per species.
            node_arrays = []
            for node_index in selected_node_indices:
                node_key = node_keys[node_index]
                quality_data = np.asarray(results.Quality[node_key])

                # Drop a leading singleton axis so the array is (time, species).
                if quality_data.ndim == 3 and quality_data.shape[0] == 1:
                    quality_data = quality_data[0]

                if quality_data.shape != expected_shape:
                    raise ValueError(
                        f"Simulation {run_index}, Node {node_key}: quality_data shape mismatch. "
                        f"Expected {expected_shape}, got {quality_data.shape}."
                    )
                node_arrays.append(quality_data)

            for species_index in selected_species_indices:
                species_data_array = np.array([arr[:, species_index] for arr in node_arrays])

                df = pd.DataFrame(species_data_array, columns=time_data, index=node_id_column)
                df.insert(0, 'RUN INDEX', run_index)
                df.insert(1, 'NODE INDEX', node_index_column)
                df.insert(2, 'NODE ID', node_id_column)
                df.insert(3, 'SPECIES', species_list[species_index])

                if not header and len(df) > 0:
                    df = df.iloc[1:].copy()

                all_dataframes.append(df.reset_index(drop=True))

        return pd.concat(all_dataframes, ignore_index=True)

    def plot_data_with_uncertainty(self, measured_data,
                                   sensor_index, species_index, species_names,
                                   sensor_description, subtitle=None, show_measured=False):
        """
        Draw one panel per sensor showing a baseline run and its uncertainty band.

        Only the FIRST entry of ``species_index`` / ``species_names`` is plotted.
        ``self.MSX_comps[0]`` supplies the baseline curve (labelled "with no
        uncertainty"); the remaining runs in ``self.MSX_comps[1:]`` supply the
        shaded min/max envelope. A species in micrograms ("ug", "µg", "μg") is
        drawn on a secondary right axis, anything else on the left axis.

        Parameters:
            measured_data: Mapping queried with ``.get(sensor_name)``; a missing
                entry (``None``) means no measured curve for that panel.
            sensor_index: One ``Quality`` key per sensor.
            species_index: 1-based species column numbers; only the first is used.
            species_names: Names parallel to ``species_index``; the first must be
                listed in ``self.soeciesnamesmsx``.
            sensor_description: Sensor names, used as panel titles and as keys into
                ``measured_data``.
            subtitle (str): Optional figure-level title.
            show_measured (bool): Draw the measured series. Forced to True when any
                species name contains "CL2".

        Returns:
            FigureCanvas: Canvas holding the rendered figure.
        """
        # Imported locally so a missing Qt backend fails here with a clear
        # ImportError rather than a NameError.
        from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
        from matplotlib.figure import Figure
        import matplotlib.dates as mdates
        from matplotlib.ticker import MaxNLocator
        from datetime import datetime, timedelta

        panels = list(zip(sensor_index, sensor_description))
        # Keep the historical 4-row grid for up to four sensors, but grow it when
        # there are more so add_subplot() never receives an out-of-range index.
        n_rows = max(4, len(panels))

        figure_height = max(3 * len(sensor_index), 8)
        figure = Figure(figsize=(10, figure_height))
        canvas = FigureCanvas(figure)

        # 288 samples per simulated day (presumably a 5-minute step — confirm
        # against the configured MSX time step).
        duration = self.t_d * 288
        if any("CL2" in name for name in species_names):
            show_measured = True

        times = self.global_times
        legend_style = dict(fontsize="x-small", frameon=True, framealpha=0.5, facecolor='white')

        if panels:
            # Values shared by every panel; only computed when something is drawn
            # so an empty sensor list never touches the species or time inputs.
            sp_ind = species_index[0]
            species_name = species_names[0]
            unit_idx = self.soeciesnamesmsx.index(species_name)
            unit = self.MSXunits[unit_idx].lower()
            is_micro = any(u in unit for u in ("ug", "µg", "μg"))
            axis_colour = "red" if is_micro else "blue"
            ylabel = "(ug/L)" if is_micro else "(mg/L)"

            start_date = times[0].astype('M8[D]').astype(str)
            start_dt = datetime.strptime(start_date, "%Y-%m-%d")
            end_dt = start_dt + timedelta(days=self.t_d)
            start_str = start_dt.strftime("%d/%m/%Y")
            end_str = end_dt.strftime("%d/%m/%Y")
            # Clamp to the available samples so a short time axis cannot raise IndexError.
            last_sample = min(duration, len(times)) - 1

        for k, (i, sensor_name) in enumerate(panels, start=1):
            ax_left = figure.add_subplot(n_rows, 1, k)

            measured_array = measured_data.get(sensor_name)
            ax_left.xaxis.set_major_formatter(mdates.DateFormatter('%d %H:%M'))
            ax_left.xaxis.set_major_locator(MaxNLocator(nbins=13))
            for label in ax_left.get_xticklabels():
                label.set_rotation(30)
                label.set_horizontalalignment('right')
                label.set_fontsize(8)

            measured_drawn = measured_array is not None and show_measured
            if measured_drawn:
                ax_left.plot(times[:duration], measured_array[:duration], label="CL2 measured", color="blue")
                ax_left.set_ylabel("(mg/L)", color="blue")
                ax_left.tick_params(axis="y", labelcolor="blue")

            # Microgram species get their own right-hand axis.
            plot_ax = ax_left.twinx() if is_micro else ax_left

            baseline_vals = self.MSX_comps[0].Quality[i][:, sp_ind - 1]

            # Envelope (pointwise min/max) across every run after the baseline.
            if len(self.MSX_comps) > 1:
                uncertainty_stack = np.stack(
                    [comp.Quality[i][:, sp_ind - 1] for comp in self.MSX_comps[1:]]
                )  # shape: (n_comps, n_times)
                min_vals = np.min(uncertainty_stack, axis=0)
                max_vals = np.max(uncertainty_stack, axis=0)

                plot_ax.fill_between(times[:duration], min_vals[:duration], max_vals[:duration],
                                     color="orange", alpha=0.3, label=f"{species_name} range")

            plot_ax.plot(times[:duration], baseline_vals[:duration], color="red", linewidth=0.8,
                         label=f"{species_name} with no uncertainty")

            # Axis labeling
            plot_ax.set_ylabel(ylabel, color=axis_colour)
            plot_ax.tick_params(axis="y", labelcolor=axis_colour)

            # X-axis and title formatting
            ax_left.set_xlim(times[0], times[last_sample])
            ax_left.set_title(f"{sensor_name} ({start_str}-{end_str})", fontsize=9, fontweight="bold")
            ax_left.grid(True)

            # Legends
            plot_ax.legend(loc="upper right" if is_micro else "upper left", **legend_style)
            if is_micro and measured_drawn:
                # The measured curve lives on the left axis, which otherwise
                # would have no legend when the species is on the right axis.
                ax_left.legend(loc="upper left", **legend_style)

        if subtitle:
            figure.suptitle(subtitle, fontsize=12)

        figure.tight_layout(rect=[0, 0, 0.94, 0.96])
        figure.subplots_adjust(hspace=0.5)
        canvas.draw()
        return canvas

    def exportMSXstatistics(self, output_path="summary_output.xlsx", nodeids=True, nodeindex=True):
        """
        Write per-node min, max and mean values to Excel, one sheet per species.

        The statistics are computed from ``self.dataframe``. Every row of that
        frame yields one summary row; time columns are the columns whose label
        is an ``int`` or ``float``. Rows without any numeric time value are
        skipped.

        Parameters:
            output_path (str): Path to save the output summary Excel file.
            nodeids (bool): Include node IDs in the summary.
            nodeindex (bool): Include node indices in the summary.

        Returns:
            dict: Maps each species name to its summary DataFrame with columns
            ``NodeID`` / ``NodeIndex`` (when enabled), ``Min``, ``Max``, ``Mean``.
        """
        combined_df = self.dataframe
        time_columns = [col for col in combined_df.columns if isinstance(col, (int, float))]

        output_data = {}
        for species_name, species_df in combined_df.groupby("SPECIES"):
            summary_rows = []

            for _, row in species_df.iterrows():
                values = pd.to_numeric(row[time_columns], errors='coerce').dropna()

                if values.empty:
                    continue

                # Insertion order defines the column order of the sheet.
                summary = {}
                if nodeids:
                    summary['NodeID'] = row['NODE ID']
                if nodeindex:
                    summary['NodeIndex'] = row['NODE INDEX']
                summary['Min'] = values.min()
                summary['Max'] = values.max()
                summary['Mean'] = values.mean()

                summary_rows.append(summary)

            output_data[species_name] = pd.DataFrame(summary_rows)

        with pd.ExcelWriter(output_path) as writer:
            for species_name, df in output_data.items():
                # Excel limits sheet names to 31 characters; longer names are
                # rejected by the writer, so truncate like the uncertainty export.
                safe_sheet_name = f"{species_name}_stats"[:31]
                df.to_excel(writer, sheet_name=safe_sheet_name, index=False)

        # Optional: return the data if needed
        return output_data

    def exportMSXstatistics_uncertainty(self, output_path="summary_output.xlsx", nodeids=True, nodeindex=True):
        """
        Write per-node min, max and mean values pooled over all runs to Excel.

        Reads ``self.dataframe_uncertainty``, groups it by species and then by
        (NODE ID, NODE INDEX), and pools the time-column values of every row in
        a group — i.e. across runs — before computing the statistics. The run
        index itself is ignored.

        Parameters:
            output_path (str): Path to save the output summary Excel file.
            nodeids (bool): Include node IDs in the summary.
            nodeindex (bool): Include node indices in the summary.

        Returns:
            dict: Maps each species name to its summary DataFrame.
        """
        frame = self.dataframe_uncertainty

        # Time columns are the ones labelled with a number (time steps).
        numeric_columns = [label for label in frame.columns if isinstance(label, (int, float))]

        per_species = {}
        for species_name, species_rows in frame.groupby('SPECIES'):
            records = []

            # One group per node gathers that node's rows from every run.
            for (node_id, node_index), runs in species_rows.groupby(['NODE ID', 'NODE INDEX']):
                flattened = runs[numeric_columns].values.flatten()
                pooled = pd.Series(pd.to_numeric(flattened, errors='coerce')).dropna()

                if pooled.empty:
                    continue

                # Insertion order defines the column order of the sheet.
                record = {}
                if nodeids:
                    record['NodeID'] = node_id
                if nodeindex:
                    record['NodeIndex'] = node_index
                record['Min'] = pooled.min()
                record['Max'] = pooled.max()
                record['Mean'] = pooled.mean()
                records.append(record)

            per_species[species_name] = pd.DataFrame(records)

        # One sheet per species; names are cut to 31 characters.
        with pd.ExcelWriter(output_path) as writer:
            for species_name, summary_df in per_species.items():
                summary_df.to_excel(writer, sheet_name=f"{species_name}_stats"[:31], index=False)

        return per_species