# -*- coding: utf-8 -*-
"""
/***************************************************************************
 ZonalExactDialog
                                 A QGIS plugin
 Zonal Statistics of rasters using Exact Extract library
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                             -------------------
        begin                : 2024-02-11
        git sha              : $Format:%H$
        copyright            : (C) 2024 by Jakub Charyton
        email                : jakub.charyton@gmail.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

import os
from typing import Dict, List
from pathlib import Path

from qgis.PyQt import uic
from qgis.PyQt import QtWidgets, QtCore
from qgis.core import (
    QgsMapLayerProxyModel,
    QgsFieldProxyModel,
    QgsTask,
    QgsTaskManager,
    QgsMessageLog,
    QgsVectorLayer,
    QgsRasterLayer,
    QgsMapLayer,
    QgsProject,
    QgsFeatureRequest,
    QgsVectorFileWriter,
)

from .dialog_input_dto import DialogInputDTO
from .user_communication import UserCommunication, WidgetPlainTextWriter
from .task_classes import CalculateStatsTask, MergeStatsTask
from .widgets.codeEditor import CodeEditorUI
from .utils import extract_function_name

# Qt Designer form for the dialog. loadUiType() compiles the .ui file that sits
# next to this module into a class; ZonalExactDialog mixes it in so that
# setupUi() can create every widget declared in Designer.
_UI_FILE_PATH = os.path.join(os.path.dirname(__file__), "zonal_exact_dialog_base.ui")
FORM_CLASS, _ = uic.loadUiType(_UI_FILE_PATH)

# Template shown in the custom function code editor when no saved function is
# checked. It averages `values` weighted by `cov` (presumably the per-cell
# coverage fractions supplied by exactextract — confirm against the task code).
DEFAULT_CODE = """import numpy as np

def np_mean(values, cov):
    average_value=np.average(values, weights=cov)
    return average_value
"""


class ZonalExactDialog(QtWidgets.QDialog, FORM_CLASS):
    """
    A dialog window for performing zonal statistics calculations using exactextract in QGIS.

    The dialog collects the input layers, the requested statistics and the output
    settings. It splits the input polygons into batches and schedules one
    ``CalculateStatsTask`` per batch as a subtask of a ``MergeStatsTask`` on the
    QGIS task manager. When the merge task completes, the output file is loaded
    into the current project.
    """

    def __init__(
        self,
        parent=None,
        uc: UserCommunication = None,
        iface=None,
        project: QgsProject = None,
        task_manager: QgsTaskManager = None,
    ):
        """
        Initialize the ZonalExactDialog class.

        Args:
            parent: The parent widget (default: None).
            uc: An instance of the UserCommunication class, used for message bar warnings (default: None).
            iface: The QGIS interface (default: None).
            project: The QGIS project that result layers are added to (default: None).
            task_manager: The QgsTaskManager that runs the calculation tasks (default: None).
        """
        super(ZonalExactDialog, self).__init__(parent)
        # Input values collected from the dialog for the current run
        self.dialog_input: DialogInputDTO = None
        # Calculation subtasks of the current run (references kept while they are queued)
        self.tasks = []
        # Intermediate results shared by the calculation subtasks and the merge task
        self.intermediate_result_list = []
        # Parent task that aggregates data from the child calculation tasks
        self.merge_task: MergeStatsTask = None
        self.output_attribute_layer = None
        self.calculated_stats_list = []
        self.temp_index_field = None
        self.input_vector = None
        self.features_count = None
        self.geospatial_output = False
        self.input_attributes_dict = {}
        # Maps custom function name -> source code; should mirror mCustomFunctionsComboBox
        self.custom_functions_dict: Dict[str, str] = {}
        # QGIS-provided objects
        self.uc = uc
        self.iface = iface
        self.project = project
        self.task_manager: QgsTaskManager = task_manager

        self.editor = CodeEditorUI(DEFAULT_CODE)
        self.editor.resize(600, 300)
        self.editor.setWindowTitle("Custom Function Code Editor")

        # Set up the user interface from Designer through FORM_CLASS.
        # After self.setupUi() you can access any designer object by doing
        # self.<objectname>, and you can use autoconnect slots - see
        # http://qt-project.org/doc/qt-4.8/designer-using-a-ui-file.html
        # #widgets-and-dialogs-with-auto-connect
        self.setupUi(self)
        self.populate_comboboxes()
        self.mRasterLayersList.setup(self.project)

        # help pages are markdown files shipped next to this module
        plugin_dir = os.path.dirname(__file__)
        self.helpTextBrowser.setSearchPaths([plugin_dir])
        self.helpTextBrowser.setSource(QtCore.QUrl("help.md"))
        self.helpTextBrowser.setOpenExternalLinks(True)
        self.libraryTextBrowser.setSearchPaths([plugin_dir])
        self.libraryTextBrowser.setSource(QtCore.QUrl("library.md"))
        self.libraryTextBrowser.setOpenExternalLinks(True)

        self.set_id_field()

        self.widget_console = WidgetPlainTextWriter(self.mPlainText)

        # set filters on combo boxes to get correct layer types
        self.mWeightsLayerComboBox.setFilters(QgsMapLayerProxyModel.RasterLayer)
        self.mVectorLayerComboBox.setFilters(QgsMapLayerProxyModel.PolygonLayer)
        # only integer fields can be picked as the ID field
        self.mFieldComboBox.setFilters(
            QgsFieldProxyModel.LongLong | QgsFieldProxyModel.Int
        )
        # bind the ID field combo box to the current vector layer
        if self.mVectorLayerComboBox.currentLayer():
            self.mFieldComboBox.setLayer(self.mVectorLayerComboBox.currentLayer())
        self.mVectorLayerComboBox.layerChanged.connect(self.set_field_vector_layer)
        # set temp_index_field class variable when user selects another index field
        if self.mFieldComboBox.currentField():
            self.temp_index_field = self.mFieldComboBox.currentField()
        self.mFieldComboBox.fieldChanged.connect(self.set_id_field)
        # select index 0, intended as "no weights layer"
        # (assumes the combo box allows an empty layer — set in the .ui file)
        self.mWeightsLayerComboBox.setCurrentIndex(0)

        self.mCalculateButton.clicked.connect(self.calculate)

        self.mAddModifyMetricButton.clicked.connect(self.edit_metric_function)
        self.editor.codeSubmitted.connect(self.modify_code)

    def populate_comboboxes(self):
        """
        Fill the Aggregates and Arrays combo boxes with the exactextract
        operation names offered by the dialog.
        """
        aggregates_stats_list = [
            "count",
            "majority",
            "max",
            "max_center_x",
            "max_center_y",
            "mean",
            "median",
            "min",
            "min_center_x",
            "min_center_y",
            "minority",
            "stdev",
            "sum",
            "variance",
            "variety",
            "weighted_mean",
            "weighted_sum",
            "weighted_variance",
        ]
        arrays_stats_list = [
            "cell_id",
            "frac",
            "center_x",
            "center_y",
            "coverage",
            "values",
            "weights",
            "weighted_frac",
        ]
        self.mAggregatesComboBox.addItems(aggregates_stats_list)
        self.mArraysComboBox.addItems(arrays_stats_list)

    def calculate(self):
        """
        Start a zonal statistics run from the current dialog state.

        Steps:
            1. Disable the calculate button.
            2. Read and validate the dialog input into ``self.dialog_input``.
            3. Split the vector layer into at most ``parallel_jobs`` batches.
            4. Schedule the calculation tasks.

        The button stays disabled while the tasks run. ``clean()`` re-enables it
        once the merge task completes or is terminated. If validation fails
        (``ValueError``), the error is logged, shown in the message bar and console,
        and the button is re-enabled immediately.
        """
        self.mCalculateButton.setEnabled(False)
        # forget the previous run's layer so only this run's layer is touched below
        self.input_vector = None
        task_scheduled = False
        try:
            self.get_input_values()  # loads input values from the dialog into self.dialog_input
            if self.dialog_input is None:
                return
            self.input_vector: QgsVectorLayer = self.dialog_input.vector_layer

            self.features_count = self.input_vector.featureCount()
            if self.features_count == 0:
                raise ValueError("Selected vector layer has no features")
            # ceil division: at most `parallel_jobs` batches are created, and the
            # batch size is never 0 (a zero step would make range() raise)
            parallel_jobs = max(self.dialog_input.parallel_jobs, 1)
            batch_size = max(
                (self.features_count + parallel_jobs - 1) // parallel_jobs, 1
            )

            # calculate using QgsTask and exactextract; postprocess() runs on completion
            self.process_calculations(self.input_vector, batch_size)
            task_scheduled = True
        except ValueError as exc:
            QgsMessageLog.logMessage(f"ERROR: {str(exc)}")
            self.uc.bar_warn(str(exc))
            self.widget_console.write_error(str(exc))
        finally:
            if self.input_vector:
                self.input_vector.removeSelection()  # remove selection of features after processing
            # when a task was queued, clean() re-enables the button after it finishes
            if not task_scheduled:
                self.mCalculateButton.setEnabled(True)

    def process_calculations(self, vector: QgsVectorLayer, batch_size: int):
        """
        Schedule the zonal statistics tasks for ``vector``.

        Creates one ``CalculateStatsTask`` for each batch of ``batch_size``
        features. Each is added as a subtask of a ``MergeStatsTask``, which is
        then queued on the task manager. The merge task's completion signal is
        connected to ``postprocess``, and its termination signal to a handler
        that reports the failure.

        Args:
            vector (QgsVectorLayer): The input vector layer for which to calculate zonal statistics.
            batch_size (int): The number of features to process in each batch.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"Batch size must be a positive integer, got {batch_size}")

        self.intermediate_result_list = []
        self.merge_task = MergeStatsTask(
            "Zonal ExactExtract task",
            QgsTask.CanCancel,
            result_list=self.intermediate_result_list,
            index_column=self.temp_index_field,
            prefix=self.dialog_input.prefix,
            geospatial_output=self.geospatial_output,
            output_file_path=self.dialog_input.output_file_path,
            source_columns=self.input_attributes_dict,
            source_crs=vector.crs(),
        )
        self.merge_task.taskChanged.connect(self.widget_console.write_info)
        self.merge_task.progressChanged.connect(self.update_progress_bar)
        # connect the completion handlers before the task is handed to the manager
        self.merge_task.taskCompleted.connect(self.postprocess)
        self.merge_task.taskTerminated.connect(self._on_merge_task_terminated)

        self.tasks = []

        # the requested statistics are identical for every batch
        stats_list = (
            self.dialog_input.aggregates_stats_list
            + self.dialog_input.arrays_stats_list
            + self.dialog_input.custom_functions_list
        )
        feature_ids = vector.allFeatureIds()
        # iterate over the actual ids: featureCount() is not guaranteed to be exact
        for i in range(0, len(feature_ids), batch_size):
            selection_ids = feature_ids[i : i + batch_size]
            # copy the batch into an in-memory layer that the subtask can use on its own
            temp_vector = vector.materialize(
                QgsFeatureRequest().setFilterFids(selection_ids)
            )
            calculation_subtask = CalculateStatsTask(
                f"calculation subtask {i}",
                flags=QgsTask.Silent,
                result_list=self.intermediate_result_list,
                polygon_layer=temp_vector,
                rasters=self.dialog_input.raster_layers_path,
                weights=self.dialog_input.weights_layer_path,
                stats=stats_list,
                include_cols=self.input_attributes_dict,
                geospatial_output=self.geospatial_output,
                strategy=self.dialog_input.strategy,
            )
            calculation_subtask.taskChanged.connect(self.widget_console.write_info)
            self.tasks.append(calculation_subtask)
            # the merge task only finishes once every calculation subtask has finished
            self.merge_task.addSubTask(
                calculation_subtask, [], QgsTask.ParentDependsOnSubTask
            )

        self.task_manager.addTask(self.merge_task)

    def postprocess(self):
        """
        Finish a successful run once the merge task has completed.

        Steps:
            1. For tabular (CSV) output, write the merged statistics to the output file.
            2. Load the output file as a vector layer.
            3. Drop the trailing ``layer``/``path`` fields, when present.
            4. Add the layer to the project.

        Errors are logged and shown in the console. The dialog state is always
        reset through ``clean()``.
        """
        try:
            output_file_path = self.dialog_input.output_file_path
            if not self.geospatial_output:
                calculated_stats = self.merge_task.calculated_stats
                message = f"Zonal ExactExtract task result shape: {str(calculated_stats.shape)}"
                QgsMessageLog.logMessage(message)
                self.widget_console.write_info(message)

                # save result based on user decided extension
                if output_file_path.suffix.lower() == ".csv":
                    calculated_stats.to_csv(output_file_path, index=False)

            # load output into QgsVectorLayer
            output_attribute_layer = QgsVectorLayer(
                str(output_file_path),
                Path(output_file_path).stem,
                "ogr",
            )

            # check the layer before touching its fields
            if not output_attribute_layer.isValid():
                message = f"Unable to load layer from {output_file_path}"
                QgsMessageLog.logMessage(message)
                self.widget_console.write_error(message)
                return

            self._remove_layer_path_fields(output_attribute_layer)

            self.widget_console.write_info("Finished calculating statistics")
            # Add the layer to the project
            self.project.addMapLayer(output_attribute_layer)
            self.output_attribute_layer = output_attribute_layer

        except Exception as exc:
            QgsMessageLog.logMessage(f"ERROR: {exc}")
            self.widget_console.write_error(str(exc))
        finally:
            self.clean()

    def _remove_layer_path_fields(self, layer: QgsVectorLayer):
        """
        Delete the last two fields of ``layer`` when they are named ``layer`` and ``path``.

        It is not visible here which step adds these columns (presumably the
        output writer in the task classes — verify). Layers with fewer than two
        fields are left untouched.
        """
        fields = layer.fields()
        total_fields = len(fields)
        if total_fields < 2:
            return
        if (
            fields.at(total_fields - 1).name() == "path"
            and fields.at(total_fields - 2).name() == "layer"
        ):
            layer.startEditing()
            # delete the higher index first so the lower index stays valid
            layer.deleteAttribute(total_fields - 1)  # delete path field
            layer.deleteAttribute(total_fields - 2)  # delete layer field
            if not layer.commitChanges():
                QgsMessageLog.logMessage(
                    f"ERROR: could not remove layer/path fields: {layer.commitErrors()}"
                )

    def _on_merge_task_terminated(self):
        """
        Report a terminated (failed or canceled) merge task and reset the dialog.

        The output file is deliberately not loaded. A terminated run may not have
        written it, and a file left at the same path by an earlier run would
        otherwise be presented as this run's result.
        """
        message = "Zonal ExactExtract task was terminated; no output was loaded"
        QgsMessageLog.logMessage(f"ERROR: {message}")
        self.widget_console.write_error(message)
        self.clean()

    def update_progress_bar(self, progress: float = None):
        """
        Update the progress bar with the parent (MergeStatsTask) task's progress.

        Args:
            progress: Progress value as emitted by ``progressChanged``. When
                omitted, it is read from ``self.merge_task``. Nothing happens
                when there is no merge task, e.g. for a late signal after ``clean()``.
        """
        if progress is None:
            if self.merge_task is None:
                return
            progress = self.merge_task.progress()
        # QgsTask progress is a float, but QProgressBar.setValue() only accepts int
        self.mProgressBar.setValue(int(round(progress)))

    def clean(self):
        """
        Reset per-run state to free memory after a run, re-enable the calculate
        button, and reset the progress bar.
        """
        self.dialog_input: DialogInputDTO = None
        self.tasks = []
        self.intermediate_result_list = []
        self.merge_task: MergeStatsTask = None
        self.calculated_stats_list = []
        self.mCalculateButton.setEnabled(True)

        self.mProgressBar.setValue(0)

    def get_input_values(self):
        """
        Read the dialog widgets, validate them, and store them in ``self.dialog_input``.

        Raises:
            ValueError: If ``control_input`` rejects the input. ``self.dialog_input``
                is not modified in that case.
        """
        raster_layers_path: List[str] = self.extract_layers_path(
            self.mRasterLayersList.checked_layers()
        )
        weights_layer_path: str = None
        weights_layer = self.mWeightsLayerComboBox.currentLayer()
        if weights_layer:
            weights_layer_path = weights_layer.dataProvider().dataSourceUri()
        vector_layer: QgsVectorLayer = self.mVectorLayerComboBox.currentLayer()
        parallel_jobs: int = self.mSubtasksSpinBox.value()
        output_file_text = self.mQgsOutputFileWidget.filePath()
        output_file_path: Path = Path(output_file_text) if output_file_text else None
        aggregates_stats_list: List[str] = self.mAggregatesComboBox.checkedItems()
        arrays_stats_list: List[str] = self.mArraysComboBox.checkedItems()
        prefix: str = self.mPrefixEdit.text()

        # source code of the checked custom functions; converted to callables later
        custom_functions: List[str] = [
            self.custom_functions_dict[function_name]
            for function_name in self.mCustomFunctionsComboBox.checkedItems()
        ]

        # raises ValueError when the input can't be processed
        self.control_input(
            raster_layers_path=raster_layers_path,
            vector_layer=vector_layer,
            output_file_path=output_file_path,
            aggregates_stats_list=aggregates_stats_list,
            arrays_stats_list=arrays_stats_list,
            custom_functions_list=custom_functions,
        )

        self.dialog_input = DialogInputDTO(
            raster_layers_path=raster_layers_path,
            weights_layer_path=weights_layer_path,
            vector_layer=vector_layer,
            parallel_jobs=parallel_jobs,
            output_file_path=output_file_path,
            aggregates_stats_list=aggregates_stats_list,
            arrays_stats_list=arrays_stats_list,
            prefix=prefix,
            custom_functions_str_list=custom_functions,
            strategy=self.mStrategyComboBox.currentText(),
        )

    def extract_layers_path(self, layers: List[QgsMapLayer]):
        """
        Return the data source URIs of the given map layers.

        Args:
            layers: List[QgsMapLayer] - A list of QGIS map layers.

        Returns:
            List[str] - The data source URI of each layer, in input order.
        """
        return [layer.dataProvider().dataSourceUri() for layer in layers]

    def control_input(
        self,
        raster_layers_path: List[str],
        vector_layer: QgsVectorLayer,
        output_file_path: Path,
        aggregates_stats_list: List[str],
        arrays_stats_list: List[str],
        custom_functions_list: List[str] = None,
    ):
        """
        Validate the dialog input and derive the output mode from the file extension.

        Checks performed:
            - Raster layers and a vector layer are selected.
            - An output file path is selected.
            - The output extension is CSV (case-insensitive) or an OGR-supported
              vector format.
            - For CSV output, an ID field is set and its values are unique.
            - At least one statistic (aggregate, array, or custom function) is selected.

        As a side effect, sets ``self.geospatial_output`` and ``self.input_attributes_dict``.

        Args:
            raster_layers_path: List[str] - Data source URIs of the raster layers.
            vector_layer: QgsVectorLayer - The vector layer.
            output_file_path: Path - The path to the output file (None when not selected).
            aggregates_stats_list: List[str] - The list of aggregates statistics.
            arrays_stats_list: List[str] - The list of arrays statistics.
            custom_functions_list: List[str] - Source code of the selected custom functions (default: None).

        Raises:
            ValueError: If any of the checks fails.
        """
        # check if both raster and vector layers are set
        if not raster_layers_path or not vector_layer:
            raise ValueError("You didn't select raster layer or vector layer")
        if not output_file_path:
            raise ValueError("You didn't select output file path")

        output_file_path_suffix = output_file_path.suffix.strip(".")
        if output_file_path_suffix.lower() == "csv":
            # tabular output: only the ID field is carried over from the input layer
            self.geospatial_output = False
            # NOTE(review): the ID field is registered with index 0 rather than its
            # position in the layer — confirm how the task classes use these values
            self.input_attributes_dict = {self.temp_index_field: 0}
        elif output_file_path_suffix in QgsVectorFileWriter.supportedFormatExtensions():
            # geospatial output: all input attributes are carried over
            self.geospatial_output = True
            fields = vector_layer.fields()
            self.input_attributes_dict = {
                name: fields.indexFromName(name) for name in fields.names()
            }
        else:
            raise ValueError(
                f"Output file extension {output_file_path_suffix} is not supported"
            )

        # the ID field is only required (and must be unique) for tabular output
        if not self.geospatial_output:
            if not self.temp_index_field:
                raise ValueError("You didn't select ID field")
            id_idx = vector_layer.fields().indexOf(self.temp_index_field)
            id_unique_values = vector_layer.uniqueValues(id_idx)
            if len(id_unique_values) < vector_layer.featureCount():
                err_msg = f"{self.temp_index_field} field values are not unique. Please select unique field as ID field."
                raise ValueError(err_msg)

        # custom functions alone are a valid selection of statistics
        if (
            not aggregates_stats_list
            and not arrays_stats_list
            and not custom_functions_list
        ):
            raise ValueError(
                "You didn't select anything from Aggregates, Arrays or custom functions"
            )

    def set_field_vector_layer(self):
        """
        Bind the ID field combo box to the newly selected vector layer.
        """
        selectedLayer = self.mVectorLayerComboBox.currentLayer()
        if selectedLayer:
            self.mFieldComboBox.setLayer(selectedLayer)

    def set_id_field(self):
        """
        Store the currently selected ID field name in ``self.temp_index_field``.
        """
        self.temp_index_field = self.mFieldComboBox.currentField()

    def edit_metric_function(self):
        """
        Open the code editor on the topmost checked custom function.

        Shows ``DEFAULT_CODE`` when no custom function is checked or the list is empty.
        """
        try:
            function_name = self.mCustomFunctionsComboBox.checkedItems()[0]
            code = self.custom_functions_dict[function_name]
        except IndexError:  # no item selected or list is empty
            code = DEFAULT_CODE
        # set editor to that code
        self.editor.set_code(code)
        self.editor.show()

    def modify_code(self, code: str):
        """
        Store submitted custom function code and make sure it is listed in the combobox.

        New functions are added to the combobox already checked. Existing entries
        keep their check state and only have their code replaced.

        Args:
            code: Source code submitted from the editor.
        """
        # get function name as string
        function_name = extract_function_name(code)
        # guard in case extract_function_name yields an empty value
        # (its failure behavior is not visible here — confirm in utils)
        if not function_name:
            message = "Could not find a function definition in the submitted code"
            QgsMessageLog.logMessage(f"ERROR: {message}")
            self.widget_console.write_error(message)
            return
        # modify the code in the dict
        self.custom_functions_dict[function_name] = code
        # if function name does not exist in combobox add function to combobox
        if self.mCustomFunctionsComboBox.findText(function_name) == -1:
            self.mCustomFunctionsComboBox.addItemWithCheckState(
                function_name, QtCore.Qt.Checked
            )
