# -*- coding: utf-8 -*-
"""
/***************************************************************************
 GeosimulationLandChanges
                                 A QGIS plugin
 This plugin is a tool used in spatial modeling to predict changes in land cover or land use
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-12-13
        git sha              : $Format:%H$
        copyright            : (C) 2024 by Albertus Deliar, Ananda Diva Victorya Gunawan, Elnurmas Zaetun Faesyari, Tedy Imanuel Selan
        email                : albertus.deliar@gmail.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""
from qgis.PyQt.QtCore import QSettings, QTranslator, QCoreApplication
from qgis.PyQt.QtGui import QIcon, QColor
from qgis.PyQt.QtWidgets import QAction, QFileDialog, QMessageBox
from qgis.core import QgsProject, QgsRasterLayer, Qgis, QgsColorRampShader, QgsRasterShader, QgsSingleBandPseudoColorRenderer
from qgis.gui import QgsFileWidget
import numpy as np
from osgeo import gdal
import datetime
import time

# Initialize Qt resources from file resources.py
from .resources import *
# Import the code for the dialog
from .geosimulation_land_changes_dialog import GeosimulationLandChangesDialog
import os.path


class GeosimulationLandChanges:
    """QGIS Plugin Implementation."""

    def __init__(self, iface):
        """Constructor.

        Creates the plugin dialog, stores the QGIS interface and installs a
        translator for the user's locale when a matching ``.qm`` file exists
        in the plugin's ``i18n`` folder.

        :param iface: An interface instance that will be passed to this class
            which provides the hook by which you can manipulate the QGIS
            application at run time.
        :type iface: QgsInterface
        """
        # Created once per plugin instance; initGui() connects its widgets.
        self.dlg = GeosimulationLandChangesDialog()

        # Save reference to the QGIS interface
        self.iface = iface
        # initialize plugin directory
        self.plugin_dir = os.path.dirname(__file__)
        # Mirrors checkBox_1: validate the prediction against a reference raster
        self.reference_required = False
        # Mirrors checkBox_2: add the saved prediction to the map
        self.show_processed_output = False

        # initialize locale. QSettings.value() returns None when the key is not
        # set, so fall back to '' instead of failing on the slice.
        locale = (QSettings().value('locale/userLocale') or '')[0:2]
        locale_path = os.path.join(
            self.plugin_dir,
            'i18n',
            'GeosimulationLandChanges_{}.qm'.format(locale))

        if os.path.exists(locale_path):
            self.translator = QTranslator()
            self.translator.load(locale_path)
            QCoreApplication.installTranslator(self.translator)

        # Declare instance attributes
        self.actions = []
        self.menu = self.tr(u'&Geosimulation Land Changes')

        # Check if plugin was started the first time in current QGIS session
        # Must be set in initGui() to survive plugin reloads
        self.first_start = None

    # noinspection PyMethodMayBeStatic
    def tr(self, message):
        """Translate *message* in the 'GeosimulationLandChanges' context.

        The plugin class is not a QObject, so QObject.tr() is not available;
        the lookup is delegated to QCoreApplication.translate instead.

        :param message: Source text to translate.
        :type message: str

        :returns: The translated text (Qt returns the source text unchanged
            when no translation is installed).
        :rtype: str
        """
        context = 'GeosimulationLandChanges'
        # noinspection PyTypeChecker,PyArgumentList,PyCallByClass
        translated = QCoreApplication.translate(context, message)
        return translated


    def add_action(
        self,
        icon_path,
        text,
        callback,
        enabled_flag=True,
        add_to_menu=False,
        add_to_toolbar=True,
        status_tip=None,
        whats_this=None,
        parent=None):
        """Create a QAction and register it in the toolbar and/or plugin menu.

        :param icon_path: Path to the icon for this action. Can be a resource
            path (e.g. ':/plugins/foo/bar.png') or a normal file system path.
        :type icon_path: str

        :param text: Text that should be shown in menu items for this action.
        :type text: str

        :param callback: Function to be called when the action is triggered.
        :type callback: function

        :param enabled_flag: A flag indicating if the action should be enabled
            by default. Defaults to True.
        :type enabled_flag: bool

        :param add_to_menu: Flag indicating whether the action should also
            be added to the plugin menu (under ``self.menu``). Defaults to False.
        :type add_to_menu: bool

        :param add_to_toolbar: Flag indicating whether the action should also
            be added to the Plugins toolbar. Defaults to True.
        :type add_to_toolbar: bool

        :param status_tip: Optional text shown in the status bar when the
            mouse pointer hovers over the action.
        :type status_tip: str

        :param whats_this: Optional "What's This?" help text for the action.
        :type whats_this: str

        :param parent: Parent widget for the new action. Defaults None.
        :type parent: QWidget

        :returns: The action that was created. Note that the action is also
            added to self.actions list (unload() iterates over it).
        :rtype: QAction
        """

        icon = QIcon(icon_path)
        action = QAction(icon, text, parent)
        action.triggered.connect(callback)
        action.setEnabled(enabled_flag)

        if status_tip is not None:
            action.setStatusTip(status_tip)

        if whats_this is not None:
            action.setWhatsThis(whats_this)

        if add_to_toolbar:
            # Adds plugin icon to Plugins toolbar
            self.iface.addToolBarIcon(action)

        if add_to_menu:
            self.iface.addPluginToMenu(
                self.menu,
                action)

        self.actions.append(action)

        return action

    
    def populate_combobox_with_rasters(self, combobox):
        """Refill *combobox* with every raster layer in the current project.

        Existing entries are removed first. Each item displays the layer name
        and carries the QgsRasterLayer object as its item data.
        """
        combobox.clear()
        project_layers = QgsProject.instance().mapLayers()
        rasters = [lyr for lyr in project_layers.values() if isinstance(lyr, QgsRasterLayer)]
        for raster in rasters:
            combobox.addItem(raster.name(), raster)
        return None
    
    
    def load_raster_from_combobox(self, combobox):
        """Return the raster layer stored as data of the selected combobox item.

        :returns: The QgsRasterLayer, or None when nothing usable is selected.
        """
        candidate = combobox.currentData()
        if not candidate or not isinstance(candidate, QgsRasterLayer):
            return None
        return candidate
    
    
    def browse_raster_file(self):
        """Let the user pick a GeoTIFF and load it into the project.

        :returns: The loaded QgsRasterLayer, or None when the dialog is
            cancelled or the file is not a valid raster (an error is shown in
            the message bar in that case).
        """
        file_path, _ = QFileDialog.getOpenFileName(
            self.iface.mainWindow(), "Select Raster File", "", "Raster Files (*.tif *.tiff)")
        if not file_path:
            # Dialog cancelled
            return None

        # basename() rather than split('/') so platform-specific separators
        # are handled as well.
        raster_layer = QgsRasterLayer(file_path, os.path.basename(file_path))
        if not raster_layer.isValid():
            self.iface.messageBar().pushMessage("Error", "Invalid raster file selected", level=Qgis.Critical)
            return None

        QgsProject.instance().addMapLayer(raster_layer)
        return raster_layer

    
    def initGui(self):
        """Register the toolbar action and wire the dialog widgets to their handlers."""
        dialog = self.dlg

        # Toolbar entry that opens the plugin (add_action adds no menu entry by default).
        self.add_action(
            ':/plugins/geosimulation_land_changes/plugin_icon.png',
            text=self.tr(u'<b>Geosimulation Land Changes</b>'),
            callback=self.run,
            parent=self.iface.mainWindow())

        # First-start flag for this session; run() is expected to clear it.
        self.first_start = True

        # Epoch 1 / epoch 2 inputs: list the project's rasters now; the tool
        # buttons let the user browse for a file instead.
        for combo in (dialog.comboBox_1, dialog.comboBox_2):
            self.populate_combobox_with_rasters(combo)
        dialog.toolButton_1.clicked.connect(lambda: self.handle_raster_input(self.dlg.comboBox_1))
        dialog.toolButton_2.clicked.connect(lambda: self.handle_raster_input(self.dlg.comboBox_2))

        # Window size should be odd; the handler only warns about even values.
        dialog.spinBox.valueChanged.connect(self.check_odd_value)

        # Validation against a reference raster is optional: its widgets stay
        # disabled until checkBox_1 is ticked.
        dialog.checkBox_1.stateChanged.connect(self.handle_checkbox_1_state)
        dialog.toolButton_3.setEnabled(False)
        dialog.comboBox_3.setEnabled(False)

        # checkBox_2 controls whether the saved prediction is shown on the map.
        dialog.checkBox_2.stateChanged.connect(self.handle_checkbox_2_state)

        # Output path picker, restricted to GeoTIFF.
        file_widget = dialog.mQgsFileWidget
        file_widget.setStorageMode(QgsFileWidget.SaveFile)
        file_widget.setFilter("Raster Files (*.tif *.tiff)")

        # Run button, progress bar and close button.
        dialog.pushButton_1.clicked.connect(self.load_data_and_calculate_prediction)
        progress = dialog.progressBar
        progress.setRange(0, 100)
        progress.setValue(0)
        dialog.pushButton_2.clicked.connect(self.close_plugin)


    def unload(self):
        """Remove the plugin's menu entries and toolbar icons from the QGIS GUI."""
        for action in self.actions:
            # Must use the same menu title that add_action() registers under
            # (self.menu, i.e. '&Geosimulation Land Changes'); a different
            # title would leave menu entries behind.
            self.iface.removePluginMenu(self.menu, action)
            self.iface.removeToolBarIcon(action)

    
    def handle_raster_input(self, combobox):
        """Load a raster for *combobox*, from a file if chosen, else its current entry.

        If the user selects a valid file, *combobox* is reset to contain only
        that layer. Otherwise (dialog cancelled or invalid file), the layer
        already selected in *combobox* is used, when there is one. In both
        cases the layer is styled with apply_color_ramp and a success message
        is shown.
        """
        browsed = self.browse_raster_file()
        if browsed:
            combobox.clear()
            combobox.addItem(browsed.name(), browsed)
            combobox.setCurrentIndex(0)
            self.apply_color_ramp(browsed)
            self.iface.messageBar().pushMessage("Success", f"Loaded raster from file: {browsed.name()}", level=Qgis.Success)
            return

        # No new file: fall back to whatever the combobox already holds.
        if combobox.count() <= 0:
            return
        existing = self.load_raster_from_combobox(combobox)
        if not existing:
            return
        self.iface.messageBar().pushMessage("Success", f"Loaded raster: {existing.name()}", level=Qgis.Success)
        self.apply_color_ramp(existing)


    def apply_color_ramp(self, raster_layer):
        """Style *raster_layer* with a fixed palette keyed by class value.

        Band 1 is rendered through a discrete colour ramp shader with one entry
        per distinct pixel value: values 1-15 get the colours below, any other
        value is drawn white. NaN pixels get no entry. The layer is added to the
        project when no layer with the same name is present, and the map canvas
        is zoomed to its extent.

        :param raster_layer: Layer to style.
        :type raster_layer: QgsRasterLayer
        """
        if not raster_layer.isValid():
            self.iface.messageBar().pushMessage("Error", "Invalid raster layer.", level=Qgis.Critical)
            return

        raster_array = self.load_raster_data_as_array(raster_layer)
        if raster_array is None:
            # load_raster_data_as_array has already reported the read failure
            return

        unique_values = np.unique(raster_array)
        if unique_values.dtype.kind == 'f':
            # NaN has no integer label and cannot match a palette entry
            unique_values = unique_values[~np.isnan(unique_values)]

        # Create a color ramp shader
        color_ramp_shader = QgsColorRampShader()
        color_ramp_shader.setColorRampType(QgsColorRampShader.Discrete)

        # Fixed colour per class value
        color_map = {
            1: QColor(0, 0, 0),         # Black
            2: QColor(138, 69, 19),     # Saddle Brown
            3: QColor(210, 105, 30),    # Chocolate
            4: QColor(255, 0, 0),       # Red
            5: QColor(255, 128, 0),     # Orange
            6: QColor(255, 255, 0),     # Yellow
            7: QColor(128, 255, 0),     # Light Green
            8: QColor(128, 128, 0),     # Olive
            9: QColor(0, 0, 255),       # Blue
            10: QColor(0, 90, 156),     # Blue Dodger
            11: QColor(0, 255, 255),    # Cyan
            12: QColor(128, 0, 128),    # Purple
            13: QColor(179, 0, 255),    # Dark Violet
            14: QColor(255, 0, 170),    # Magenta
            15: QColor(255, 0, 204)     # Light Pink
        }

        # One ramp item per distinct value; unknown values are drawn white.
        color_ramp_items = []
        for value in unique_values:
            color = color_map.get(value, QColor(255, 255, 255))
            # Hand plain Python numbers to the C++ API instead of NumPy scalars
            color_ramp_items.append(QgsColorRampShader.ColorRampItem(float(value), color, str(int(value))))

        color_ramp_shader.setColorRampItemList(color_ramp_items)

        # Create a raster shader and set the color ramp
        raster_shader = QgsRasterShader()
        raster_shader.setRasterShaderFunction(color_ramp_shader)

        # Apply renderer to band 1 of the raster layer
        renderer = QgsSingleBandPseudoColorRenderer(raster_layer.dataProvider(), 1, raster_shader)
        raster_layer.setRenderer(renderer)
        raster_layer.triggerRepaint()

        # Add the layer to the QGIS project if no layer with this name exists yet
        if not QgsProject.instance().mapLayersByName(raster_layer.name()):
            QgsProject.instance().addMapLayer(raster_layer)

        self.iface.mapCanvas().setExtent(raster_layer.extent())
        self.iface.mapCanvas().refresh()

    
    def handle_checkbox_1_state(self, state):
        """Enable or disable the reference-raster inputs to follow checkBox_1.

        When checkBox_1 is checked: toolButton_3 and comboBox_3 are enabled,
        ``self.reference_required`` is set, toolButton_3 is connected (once) to
        the raster browser for comboBox_3, and comboBox_3 is refilled with the
        project's rasters. When unchecked: the widgets are disabled, the flag
        is cleared and that connection is removed.

        :param state: Qt check state from ``stateChanged`` (unused; the
            checkbox is queried directly).
        """
        checked = self.dlg.checkBox_1.isChecked()
        self.dlg.toolButton_3.setEnabled(checked)
        self.dlg.comboBox_3.setEnabled(checked)
        self.reference_required = checked

        # Keep the exact callable that was connected. A signal can only be
        # disconnected from the same object (a fresh lambda is "not connected"
        # and disconnect() raises), and connecting a new lambda on every check
        # would stack duplicate handlers.
        slot = getattr(self, '_reference_browse_slot', None)
        if checked:
            if slot is None:
                def slot():
                    self.handle_raster_input(self.dlg.comboBox_3)
                self._reference_browse_slot = slot
                self.dlg.toolButton_3.clicked.connect(slot)
            # Populate comboBox_3 with available rasters
            self.populate_combobox_with_rasters(self.dlg.comboBox_3)
        elif slot is not None:
            self.dlg.toolButton_3.clicked.disconnect(slot)
            self._reference_browse_slot = None

    
    def handle_checkbox_2_state(self, state):
        """Copy checkBox_2's checked state into ``self.show_processed_output``.

        :param state: Qt check state from ``stateChanged`` (unused; the
            checkbox is queried directly).
        """
        self.show_processed_output = bool(self.dlg.checkBox_2.isChecked())

    
    def check_odd_value(self, value):
        """Warn in the message bar when the window size *value* is even.

        The value itself is not corrected; the user is only asked to change it.
        """
        is_even = value % 2 == 0
        if not is_even:
            return
        self.iface.messageBar().pushMessage("Warning", "Only odd values are allowed. Please change the value to an odd number.", level=Qgis.Warning)


    def load_raster_data_as_array(self, raster_layer):
        """Read band 1 of *raster_layer* into a NumPy array through GDAL.

        :param raster_layer: Layer whose ``source()`` is a path GDAL can open.
        :returns: The band-1 pixel array, or None (after an error in the
            message bar) when GDAL cannot open the source.
        """
        source_path = raster_layer.source()
        dataset = gdal.Open(source_path)
        if dataset is not None:
            return dataset.GetRasterBand(1).ReadAsArray()
        self.iface.messageBar().pushMessage("Error", "Failed to open raster dataset", level=Qgis.Critical)
        return None
    

    def calculate_cross_tabulation(self, raster_layer_1, raster_layer_2):
        raster1_array = self.load_raster_data_as_array(raster_layer_1)
        raster2_array = self.load_raster_data_as_array(raster_layer_2)

        unique_values_1 = np.unique(raster1_array)
        unique_values_2 = np.unique(raster2_array)

        cross_tabulation = np.zeros((len(unique_values_1), len(unique_values_2)))

        for i in range(len(unique_values_1)):
            for j in range(len(unique_values_2)):
                cross_tabulation[i, j] = np.count_nonzero((raster1_array == unique_values_1[i]) & (raster2_array == unique_values_2[j]))
        
        return cross_tabulation, unique_values_1, unique_values_2
    

    def calculate_probability_matrix(self, cross_tabulation):
        # Initialize the probability matrix with the same shape as cross_tabulation
        probability_matrix = np.zeros_like(cross_tabulation, dtype=float)

        # Iterate through each row to calculate probabilities
        for i in range(cross_tabulation.shape[0]):
            total_pixels = np.sum(cross_tabulation[i, :])  # Sum of the counts in the current row
            if total_pixels > 0:
                probability_matrix[i, :] = cross_tabulation[i, :] / total_pixels  # Calculate probabilities

        return probability_matrix
    

    def apply_cellular_automata_markov_chain(self, current_state, probability_matrix, x, y, matrix_size):
        offset = matrix_size // 2
        center_value = current_state[x, y]
        window_matrix = current_state[x-offset:x+offset+1, y-offset:y+offset+1]
        submatrix = window_matrix[~np.isnan(window_matrix)]  # remove NaN values

        if np.isnan(center_value):
            return center_value # Preserve NoData values
        
        unique, counts = np.unique(submatrix, return_counts=True)
        count_dict = dict(zip(unique, counts))

        # Ensure center_value is an integer and within the bounds
        center_value = int(center_value)
        if center_value < 0 or center_value >= probability_matrix.shape[0]:
            return center_value
        
        transition_probabilities = probability_matrix [center_value, :]
        weighted_sums = np.zeros_like(transition_probabilities)

        for cls in unique:
            cls = int(cls)  # Convert to integer
            if 0 <= cls < probability_matrix.shape[1]:  # Check bounds
                weighted_sums[cls] += count_dict[cls] * transition_probabilities[cls]

        new_class = np.argmax(weighted_sums) if weighted_sums.sum() > 0 else center_value
        
        return new_class
    

    def prediction(self, raster_current_state, probability_matrix, window_size):
        current_state_array = self.load_raster_data_as_array(raster_current_state)
        predict_matrix = np.copy(current_state_array)
        offset = window_size // 2

        for x in range(offset, current_state_array.shape[0] - offset):
            for y in range(offset, current_state_array.shape[1] - offset):
                predict_matrix[x, y] = self.apply_cellular_automata_markov_chain(current_state_array, probability_matrix, x, y, window_size)

        return predict_matrix
    

    def get_raster_attributes(self, raster_layer):
        """Return basic GDAL metadata of *raster_layer*.

        :returns: ``(width, height, projection, geotransform)`` as reported by
            GDAL. When GDAL cannot open the source, an error is shown in the
            message bar and ``(None, None, None, None)`` is returned.
        """
        dataset = gdal.Open(raster_layer.source())
        if dataset is None:
            self.iface.messageBar().pushMessage("Error", "Failed to open raster dataset", level=Qgis.Critical)
            # Return a 4-tuple (rather than falling through to attribute access
            # on None) so callers can still unpack the result.
            return None, None, None, None

        width = dataset.RasterXSize
        height = dataset.RasterYSize
        crs = dataset.GetProjection()
        transform = dataset.GetGeoTransform()

        return width, height, crs, transform
    

    def compare_rasters(self, raster_layer_1, raster_layer_2):
        """Check that two rasters share width, height, CRS and geotransform.

        Shows an error in the message bar when they differ.

        :returns: True when all four attributes are equal, False otherwise
            (previously always None, so callers could not act on a mismatch).
        """
        attributes_1 = self.get_raster_attributes(raster_layer_1)
        attributes_2 = self.get_raster_attributes(raster_layer_2)
        if tuple(attributes_1) == tuple(attributes_2):
            return True
        self.iface.messageBar().pushMessage("Error", "Two images are not same size, CRS, or transform", level=Qgis.Critical)
        return False


    def save_predict_matrix_to_raster(self, predict_matrix, raster_current_state, file_path):
        """Write *predict_matrix* to a single-band Float32 GeoTIFF at *file_path*.

        Size, projection and geotransform are copied from *raster_current_state*.
        The written file is then loaded as a QgsRasterLayer, stored in
        ``self.predict_raster`` and, when ``self.show_processed_output`` is set,
        styled with apply_color_ramp and added to the project.

        :returns: The loaded QgsRasterLayer (which may be invalid), or None when
            the reference raster cannot be opened or the output cannot be created.
        """
        reference_dataset = gdal.Open(raster_current_state.source())
        if reference_dataset is None:
            self.iface.messageBar().pushMessage("Error", "Failed to open reference raster dataset", level=Qgis.Critical)
            return None

        # Grid and georeferencing of the output come from the reference raster
        raster_x_size = reference_dataset.RasterXSize
        raster_y_size = reference_dataset.RasterYSize
        crs = reference_dataset.GetProjection()
        transform = reference_dataset.GetGeoTransform()
        reference_dataset = None  # everything needed has been read; close it

        driver = gdal.GetDriverByName('GTiff')
        out_dataset = driver.Create(file_path, raster_x_size, raster_y_size, 1, gdal.GDT_Float32)
        if out_dataset is None:
            self.iface.messageBar().pushMessage("Error", f"Failed to create raster file: {file_path}", level=Qgis.Critical)
            return None

        # Write the prediction matrix and georeference it like the reference
        out_band = out_dataset.GetRasterBand(1)
        out_band.WriteArray(np.array(predict_matrix))
        out_dataset.SetGeoTransform(transform)
        out_dataset.SetProjection(crs)
        out_dataset.FlushCache()

        # Release the band as well as the dataset before QGIS reopens the file:
        # newer GDAL Python bindings let a Band keep its Dataset alive, so
        # dropping only the dataset reference may leave the file open.
        out_band = None
        out_dataset = None

        # Store the raster layer in the instance variable
        self.predict_raster = QgsRasterLayer(file_path, os.path.basename(file_path))
        if self.predict_raster.isValid():
            if self.show_processed_output:
                # Apply color ramp to the raster layer
                self.apply_color_ramp(self.predict_raster)

                # Add layers to QGIS project
                QgsProject.instance().addMapLayer(self.predict_raster)
                self.iface.messageBar().pushMessage("Success", f"Predicted raster displayed in QGIS: {file_path}", level=Qgis.Success)
            else:
                self.iface.messageBar().pushMessage("Success", f"Predicted raster saved successfully to: {file_path}", level=Qgis.Success)
        else:
            self.iface.messageBar().pushMessage("Error", "Failed to load the processed raster file into QGIS.", level=Qgis.Critical)

        return self.predict_raster


    def calculate_confusion_matrix(self, predict_matrix, reference_layer):
        prediction_array = np.array(predict_matrix)
        reference_array = self.load_raster_data_as_array(reference_layer)

        unique_pred_values = np.unique(prediction_array)
        unique_ref_values = np.unique(reference_array)
        confusion_matrix = np.zeros((len(unique_pred_values), len(unique_ref_values)))

        for i in range(len(unique_pred_values)):
            for j in range(len(unique_ref_values)):
                confusion_matrix[i, j] = np.count_nonzero((prediction_array == unique_pred_values[i]) & (reference_array == unique_ref_values[j]))
        
        return confusion_matrix, unique_pred_values, unique_ref_values
    

    def calculate_accuracy_and_kappa(self, confusion_matrix):
        # Check if confusion matrix is empty
        if confusion_matrix.size == 0:
            return 0, 0

        # Calculate Overall Accuracy
        total = np.sum(confusion_matrix)
        correct = np.trace(confusion_matrix) #sum of diagonal elements
        overall_accuracy = correct / total if total > 0 else 0

        # Calculate sum of rows dan columns
        row_sums = np.sum(confusion_matrix, axis=1)
        col_sums = np.sum(confusion_matrix, axis=0)
        
        # Calculate Pairwise Product
        pairwise_product = row_sums * col_sums
        expected = np.sum(pairwise_product)

        square_total = total ** 2

        # Calculate Kappa
        numerator = (total * correct) - expected
        denominator = square_total - expected

        if denominator == 0:
            return 0 # Return 0 if no variation in the data

        kappa =  numerator / denominator

        return overall_accuracy, kappa
    

    def categorize_kappa(self, kappa):
        """Map a kappa value to a verbal agreement category.

        The bands (< 0, < 0.20, < 0.40, < 0.60, < 0.80, otherwise) and labels
        appear to follow the Landis & Koch (1977) scale. A NaN kappa is
        reported as undefined instead of falling through to the top band.
        """
        if np.isnan(kappa):
            return "Undefined agreement"
        if kappa < 0:
            return "Less than chance agreement"
        elif kappa < 0.20:
            return "Slight agreement"
        elif kappa < 0.40:
            return "Fair agreement"
        elif kappa < 0.60:
            return "Moderate agreement"
        elif kappa < 0.80:
            return "Substantial agreement"
        else:
            return "Almost perfect agreement"
        
    
    def save_metadata_to_txt(self, raster_layer_1, raster_layer_2, window_size, cross_tabulation, probability_matrix, predict_raster, predict_matrix=None, reference_layer=None, confusion_matrix=None, overall_accuracy=None, kappa=None, kappa_category=None, file_path=None):
        """Write a plain-text report describing one prediction run.

        The report lists both input epochs (name, size, CRS, geotransform), the
        window size, the cross tabulation and probability matrix, the prediction
        raster and, when *reference_layer* is given, the reference raster,
        confusion matrix, overall accuracy and kappa. A success message is shown
        in the message bar afterwards.

        :param cross_tabulation: Counts from calculate_cross_tabulation; rows and
            columns are labelled with the unique values of the two input rasters,
            recomputed here with np.unique.
        :param probability_matrix: Row-normalised *cross_tabulation*.
        :param predict_raster: Saved prediction layer (its source is read with GDAL).
        :param predict_matrix: Prediction array; required with *reference_layer*.
        :param reference_layer: Optional validation raster.
        :param confusion_matrix: Required with *reference_layer*.
        :param overall_accuracy: Required with *reference_layer*.
        :param kappa: Required with *reference_layer*.
        :param kappa_category: Label for *kappa*.
        :param file_path: Output text file path (written as UTF-8).
        """
        report_width = 80  # fixed report width used for centring and rules
        title = "METADATA REPORT"

        with open(file_path, 'w', encoding='utf-8') as f:

            def write_raster_info(name_label, name, attributes):
                """Write the name, size, CRS and transform lines of one raster."""
                width, height, crs, transform = attributes
                f.write(f"{name_label}: {name}\n")
                f.write(f"Width: {width}, Height: {height}\n")
                f.write(f"Coordinate Reference System: {crs}\n")
                f.write(f"Transform: {transform}\n")

            def write_matrix(corner, row_labels, col_labels, matrix, format_row):
                """Write a tab-separated matrix with integer row/column labels."""
                f.write(corner + "\t" + "\t".join(map(str, col_labels.astype(int))) + "\n")
                for i, label in enumerate(row_labels.astype(int)):
                    f.write(f"{label}\t" + "\t".join(format_row(matrix[i])) + "\n")

            def format_counts(row):
                return map(str, row.astype(int))

            def format_probabilities(row):
                return (f"{prob:.3f}" for prob in row)

            padding = (report_width - len(title)) // 2
            f.write(" " * padding + title + "\n")
            f.write("=" * report_width + "\n")
            f.write(f"Processing Date and Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write("\n")

            f.write("INFORMATION ABOUT THE INPUT RASTER DATA\n")
            f.write("\n")
            f.write("Raster Data Epoch 1\n")
            write_raster_info("Epoch 1 File Name", raster_layer_1.name(), self.get_raster_attributes(raster_layer_1))
            f.write("\n")
            f.write("Raster Data Epoch 2\n")
            write_raster_info("Epoch 2 File Name", raster_layer_2.name(), self.get_raster_attributes(raster_layer_2))
            f.write("\n")
            f.write("\n")

            f.write("INFORMATION ABOUT PREDICTION PARAMETERS\n")
            f.write("\n")
            f.write(f"Window Size: {window_size}\n")
            f.write("\n")

            # Matrix labels: sorted unique class values of each epoch, derived
            # the same way calculate_cross_tabulation derives them.
            unique_values_1 = np.unique(self.load_raster_data_as_array(raster_layer_1))
            unique_values_2 = np.unique(self.load_raster_data_as_array(raster_layer_2))

            f.write("Cross Tabulation\n")
            write_matrix("Epoch 1 \\ Epoch 2", unique_values_1, unique_values_2, cross_tabulation, format_counts)
            f.write("\n")

            f.write("Probability Matrix\n")
            write_matrix("Epoch 1 \\ Epoch 2", unique_values_1, unique_values_2, probability_matrix, format_probabilities)
            f.write("\n")
            f.write("\n")

            f.write("INFORMATION ABOUT PREDICTION RASTER\n")
            f.write("\n")
            write_raster_info("Prediction File Name", os.path.basename(predict_raster.source()), self.get_raster_attributes(predict_raster))
            f.write("\n")
            f.write("\n")

            if reference_layer is not None:
                f.write("INFORMATION ABOUT VALIDATE PREDICTION RESULT\n")
                f.write("\n")
                f.write("Raster Reference Information\n")
                write_raster_info("Reference File Name", reference_layer.name(), self.get_raster_attributes(reference_layer))
                f.write("\n")

                # Labels derived the same way calculate_confusion_matrix derives them
                p_unique_values = np.unique(np.array(predict_matrix))
                r_unique_values = np.unique(self.load_raster_data_as_array(reference_layer))

                f.write("Confusion Matrix\n")
                write_matrix("Predicted \\ Reference", p_unique_values, r_unique_values, confusion_matrix, format_counts)
                f.write("\n")

                # Write OA, Kappa, and Kappa Category
                f.write(f"\nOverall Accuracy: {overall_accuracy:.3f}\n")
                f.write(f"\nKappa: {kappa:.3f}\n")
                f.write(f"Kappa Category: {kappa_category}\n")
                f.write("\n")
            else:
                f.write("No reference raster selected, skip validate the prediction result\n")
                f.write("\n")
            f.write("=" * report_width + "\n")  # closing rule

        self.iface.messageBar().pushMessage("Success", f"Metadata saved to: {file_path}", level=Qgis.Success)


    def load_data_and_calculate_prediction(self):
        """Run the CA-Markov land-change prediction configured in the dialog.

        Workflow:
          1. Load the rasters chosen in ``comboBox_1`` and ``comboBox_2`` (the
             second one is used as the current state) and abort if either is
             missing or invalid.
          2. Check that an output path whose directory exists was chosen in
             ``mQgsFileWidget`` -- before any heavy computation.
          3. Cross-tabulate the two rasters and derive the probability matrix.
          4. Build the initial prediction from the current-state raster and
             refine every interior cell with the cellular-automata rule, using
             the neighbourhood size from ``spinBox``.
          5. Save the prediction raster (kept in ``self.predict_raster``).
          6. If ``self.reference_required`` is set, compare the prediction with
             the reference raster from ``comboBox_3`` (confusion matrix,
             overall accuracy, kappa).
          7. Write a ``<output name>_metadata.txt`` report next to the raster.

        Problems are reported through the QGIS message bar. ``pushButton_1`` is
        disabled while the workflow runs and is re-enabled on every exit path,
        including early returns and exceptions.
        """
        # Disable the button so the workflow cannot be started twice.
        self.dlg.pushButton_1.setEnabled(False)
        try:
            progress_bar = self.dlg.progressBar
            progress_bar.setRange(0, 100)
            progress_bar.setValue(0)

            # Load raster data
            raster_layer_1 = self.load_raster_from_combobox(self.dlg.comboBox_1)
            raster_layer_2 = self.load_raster_from_combobox(self.dlg.comboBox_2)

            # ``is None`` guard: the loader's failure contract is not visible
            # here, and calling isValid() on None would raise AttributeError.
            if (raster_layer_1 is None or not raster_layer_1.isValid()
                    or raster_layer_2 is None or not raster_layer_2.isValid()):
                self.iface.messageBar().pushMessage("Error", "One or both raster files are invalid!", level=Qgis.Critical)
                return
            progress_bar.setValue(1)

            # Validate the output location before the expensive steps, so the
            # user does not wait for an error they could have fixed up front.
            file_path = self.dlg.mQgsFileWidget.filePath()
            if not file_path:
                self.iface.messageBar().pushMessage("Error", "Please specify a file path to save the predicted raster.", level=Qgis.Critical)
                return
            # isdir("") is False, so a bare file name without a directory part
            # is rejected here as well.
            if not os.path.isdir(os.path.dirname(file_path)):
                self.iface.messageBar().pushMessage("Error", "Directory does not exist.", level=Qgis.Critical)
                return

            # Compare the dimensions, CRS, or transformations of the two rasters.
            # NOTE(review): the return value is not used, so a mismatch does not
            # stop the workflow -- confirm against compare_rasters.
            self.compare_rasters(raster_layer_1, raster_layer_2)
            progress_bar.setValue(2)

            cross_tabulation, _, _ = self.calculate_cross_tabulation(raster_layer_1, raster_layer_2)
            progress_bar.setValue(3)

            probability_matrix = self.calculate_probability_matrix(cross_tabulation)
            progress_bar.setValue(5)

            # Apply Cellular Automata Markov Chain, starting from the current
            # state (the second raster).
            window_size = self.dlg.spinBox.value()
            predict_matrix = self.prediction(raster_layer_2, probability_matrix, window_size)
            self._refine_with_cellular_automata(
                predict_matrix, probability_matrix, window_size,
                progress_start=5, progress_end=90)

            # Save the prediction matrix to a raster file
            self.predict_raster = self.save_predict_matrix_to_raster(predict_matrix, raster_layer_2, file_path)
            progress_bar.setValue(90)

            # With checkBox_1 unchecked, the dialog is closed right away.
            # Metadata (and validation, if required) is still produced below.
            # NOTE(review): the old comment said "finish the process" here --
            # confirm whether an early return was intended.
            if not self.dlg.checkBox_1.isChecked():
                self.dlg.close()

            metadata_file_path = self._metadata_path_for(file_path)

            # Extra keyword arguments for the metadata report; they are filled
            # only when the prediction is validated against a reference raster.
            validation_kwargs = {}
            if self.reference_required:
                reference_layer = self.load_raster_from_combobox(self.dlg.comboBox_3)
                if reference_layer is None:
                    self.iface.messageBar().pushMessage("Error", "Reference raster file is required!", level=Qgis.Critical)
                    return
                if not reference_layer.isValid():
                    self.iface.messageBar().pushMessage("Error", "reference raster files are invalid!", level=Qgis.Critical)
                    return

                self.compare_rasters(reference_layer, raster_layer_2)

                confusion_matrix, _, _ = self.calculate_confusion_matrix(predict_matrix, reference_layer)
                overall_accuracy, kappa = self.calculate_accuracy_and_kappa(confusion_matrix)
                validation_kwargs = {
                    'predict_matrix': predict_matrix,
                    'reference_layer': reference_layer,
                    'confusion_matrix': confusion_matrix,
                    'overall_accuracy': overall_accuracy,
                    'kappa': kappa,
                    'kappa_category': self.categorize_kappa(kappa),
                }
                progress_bar.setValue(95)

            # Save Metadata
            self.save_metadata_to_txt(
                raster_layer_1,
                raster_layer_2,
                window_size,
                cross_tabulation,
                probability_matrix,
                self.predict_raster,
                file_path=metadata_file_path,
                **validation_kwargs
            )
            if not self.reference_required:
                self.iface.messageBar().pushMessage("Info", "No reference raster selected, skip validate the prediction result.", level=Qgis.Info)

            progress_bar.setValue(100)
            self.dlg.close()
        finally:
            # Previously the button stayed disabled after any early return
            # (invalid raster, missing path, ...) or exception.
            self.dlg.pushButton_1.setEnabled(True)

    def _refine_with_cellular_automata(self, predict_matrix, probability_matrix, window_size,
                                       progress_start=5, progress_end=90):
        """Apply the CA-Markov rule in place to every interior cell, reporting progress."""
        half = window_size // 2
        rows = range(half, predict_matrix.shape[0] - half)
        cols = range(half, predict_matrix.shape[1] - half)

        # Only interior cells are visited (a ``half``-wide border is skipped).
        # Counting just those lets the bar actually reach ``progress_end``.
        total_steps = len(rows) * len(cols)
        if total_steps == 0:
            return

        span = progress_end - progress_start
        done = 0
        last_value = None
        for x in rows:
            for y in cols:
                # In-place update: later cells read neighbours that were
                # already updated during this same pass.
                predict_matrix[x, y] = self.apply_cellular_automata_markov_chain(
                    predict_matrix, probability_matrix, x, y, window_size)
            done += len(cols)
            # Refresh the bar once per row, and only when its integer value
            # changes, instead of once per cell.
            value = progress_start + (done * span) // total_steps
            if value != last_value:
                self.dlg.progressBar.setValue(value)
                last_value = value

    @staticmethod
    def _metadata_path_for(raster_path):
        """Return the companion report path, e.g. ``out.tif`` -> ``out_metadata.txt``."""
        # splitext (rather than str.replace('.tif', ...)) always yields a path
        # that differs from the raster's, so the report can never overwrite it.
        root, _extension = os.path.splitext(raster_path)
        return root + '_metadata.txt'


    def run(self):
        """Show the plugin dialog modally and block until the user closes it.

        ``self.dlg`` must already exist; it is not created here. On the first
        call ``first_start`` is cleared, and ``is_processing`` is always reset
        to ``False`` before the dialog is shown. Returns ``None``.
        """
        # NOTE(review): the original comment said the dialog is created once
        # inside this branch, but only the flag is cleared here; the dialog is
        # presumably built elsewhere (e.g. initGui) -- confirm.
        if self.first_start:
            self.first_start = False

        # No prediction is running when the dialog is (re)opened.
        self.is_processing = False

        self.dlg.show()
        # Run the modal event loop. The accept/reject result is not used; the
        # work is presumably triggered from the dialog's pushButton_1 instead.
        self.dlg.exec_()

    def close_plugin(self):
        """Close the dialog, asking for confirmation while a prediction is running.

        When ``self.is_processing`` is false, the dialog closes immediately.
        Otherwise a Yes/No question box (defaulting to No) is shown, and the
        dialog is closed only if the user answers Yes.
        """
        if not self.is_processing:
            self.dlg.close()
            return

        # The plugin object is a plain Python class, not a QWidget, so it
        # cannot parent the message box (passing ``self`` makes PyQt raise a
        # TypeError). Parent the box to the plugin dialog instead.
        reply = QMessageBox.question(
            self.dlg,
            'Confirm',
            'Are you sure you want to cancel the operation?',
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No,
        )
        if reply == QMessageBox.Yes:
            self.dlg.close()
