# -*- coding: utf-8 -*-

"""
/***************************************************************************
 QLearn
                                 A QGIS plugin
 QLearn performs automatic training of a neural network model on raster data
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2025-02-01
        copyright            : (C) 2025 by Adam B
        email                : adam {at} ardanika {dot} com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Plugin metadata; values mirror the author/date/copyright in the module docstring.
__author__ = 'Adam B'
__date__ = '2025-02-01'
__copyright__ = '(C) 2025 by Adam B'

# This will get replaced with a git SHA1 when you do a git archive.
# NOTE(review): git only expands the `$Format:%H$` placeholder for files that
# have the `export-subst` attribute in .gitattributes — confirm it is set.
# Keep the literal text exactly as-is or the substitution will not happen.
__revision__ = '$Format:%H$'

from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessingAlgorithm,
                       QgsProcessingParameterFileDestination,
                       QgsProcessingParameterBoolean,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterFile,
                       QgsProcessingParameterNumber, 
                       QgsProcessingParameterString)

from .QLearnDataset import QDataset
from .QLearnTrain import QUNetTrainer
from .QLearnRasterSelectWidget import *
import torch
import cProfile, os, datetime, argparse, shlex


class QLearnTrainingAlgorithm(QgsProcessingAlgorithm):
    """
    QGIS Processing algorithm that trains a UNet model on raster pairs.

    The user picks input/target raster pairs, the core hyper-parameters
    (epochs, learning rate, NODATA value, input normalisation, training type)
    and an output ``.pth`` location. Less common settings are supplied as a
    command-line style "Extra arguments" string that is parsed with argparse
    (see ``_build_extra_args_parser``). Training can optionally resume from an
    existing checkpoint and can optionally be profiled with cProfile.

    All Processing algorithms must extend QgsProcessingAlgorithm.
    """

    # Constants used to refer to parameters and outputs. They are used when
    # calling the algorithm from another algorithm, a model, or the QGIS
    # console, so their string values must stay stable.
    RASTER_PAIRS = 'RASTER_PAIRS'
    OUTPUT_MODEL = 'OUTPUT_MODEL'
    ARGS_NODATA = 'ARGS_NODATA'
    ARGS_NORMALIZE = 'ARGS_NORMALIZE'
    ARGS_EPOCHS = 'ARGS_EPOCH'
    ARGS_LR = 'ARGS_LEARNINGRATE'
    INPUT_MODEL = 'INPUT_MODEL'
    ARGS_TRAINTYPE = 'TRAINING_TYPE'
    ARGS_EXTRA = 'ARGS_EXTRA'

    # Options of the ARGS_TRAINTYPE enum; the selected index maps into this list.
    training_types = ["classification", "regression"]

    def flags(self):
        """Return the processing flags; currently identical to the base class."""
        return super().flags()

    def initAlgorithm(self, config):
        """
        Define the inputs and output of the algorithm.

        :param config: optional configuration passed by QGIS (unused).
        """
        # RasterPairParameter is provided by the QLearnRasterSelectWidget module.
        self.addParameter(
            RasterPairParameter(
                self.RASTER_PAIRS,
                self.tr('Select raster pairs'))
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.ARGS_EPOCHS,
                self.tr("Number of Epochs"),
                defaultValue=10,
                minValue=1,  # zero or negative epoch counts are meaningless
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.ARGS_NODATA,
                self.tr("NODATA Value"),
                defaultValue=-100
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.ARGS_LR,
                self.tr("Learning Rate"),
                type=QgsProcessingParameterNumber.Double,
                defaultValue=0.001
            )
        )

        self.addParameter(
            QgsProcessingParameterBoolean(
                self.ARGS_NORMALIZE,
                self.tr("Normalize input values"),
                defaultValue=True
            )
        )

        self.addParameter(
            QgsProcessingParameterEnum(
                self.ARGS_TRAINTYPE,
                self.tr("Model training type"),
                options=self.training_types,
                allowMultiple=False
            )
        )

        self.addParameter(
            QgsProcessingParameterFile(
                self.INPUT_MODEL,
                self.tr('Model to continue training on (will overwrite settings provided)'),
                fileFilter='PyTorch Model(*.pth)',
                optional=True
            )
        )

        self.addParameter(
            QgsProcessingParameterFileDestination(
                self.OUTPUT_MODEL,
                self.tr("Output model location"),
                fileFilter='PyTorch Model(*.pth)'
            )
        )

        self.addParameter(
            QgsProcessingParameterString(
                self.ARGS_EXTRA,
                self.tr("Extra arguments"),
                defaultValue="",
                optional=True
            )
        )

    @staticmethod
    def _str_to_bool(value):
        """
        Convert a command-line token such as "true"/"false"/"1"/"0" to a bool.

        argparse's ``type=bool`` calls ``bool(str)``, which is True for every
        non-empty string (so ``--profile False`` would *enable* profiling).
        This converter parses the text instead.

        :raises argparse.ArgumentTypeError: if the token is not a recognised
            boolean spelling (argparse turns this into an ArgumentError).
        """
        if isinstance(value, bool):
            return value
        token = str(value).strip().lower()
        if token in ("1", "true", "t", "yes", "y", "on"):
            return True
        if token in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")

    @classmethod
    def _build_extra_args_parser(cls):
        """Build the argparse parser for the free-form "Extra arguments" string."""
        parser = argparse.ArgumentParser(
            prog="QLearnTrain",  # don't depend on sys.argv, which QGIS controls
            description="Extra arguments for training",
            exit_on_error=False,
            # -h would print to stdout (invisible in QGIS) and call sys.exit();
            # instead, unrecognised options are reported together with the help text.
            add_help=False,
        )
        parser.add_argument("-b", "--batch_size", type=int, default=16, help="Batch size for training", required=False)
        parser.add_argument("-d", "--depth", type=int, default=4, help="Depth of the UNet Model", required=False)
        parser.add_argument("-c", "--channels", type=int, default=64, help="Number of input channels for UNet Model", required=False)
        parser.add_argument("-v", "--validation_split", type=float, default=0.2, help="Validation split for training", required=False)
        # Boolean flags: "-n false" / "-p true" are parsed properly; a bare
        # "-p" (no value) means True.
        parser.add_argument("-n", "--normalize_targets", type=cls._str_to_bool, nargs="?", const=True, default=True, help="Normalize target data (regression only)", required=False)
        parser.add_argument("-p", "--profile", type=cls._str_to_bool, nargs="?", const=True, default=False, help="Enable profiling", required=False)
        parser.add_argument("-ch", "--chunk_size", type=int, default=256, help="Chunk size for training", required=False)
        parser.add_argument("-w", "--weights", type=float, nargs="+", default=None, help="Class Weightings for Classification problems (number of values must be equal to number of classes)", required=False)
        parser.add_argument("-sm", "--save_mode", type=int, default=0, help="0=best, 1=last", required=False)
        parser.add_argument("-ep", "--end_patience", type=int, default=5, help="Early stopping patience in epochs (if no improvement in validation loss)", required=False)
        parser.add_argument("-r", "--rescale", type=float, default=1.0, help="Downscale images by n (e.g. 0.6x)", required=False)
        return parser

    def _parse_extra_args(self, extra_args, feedback):
        """
        Parse the "Extra arguments" string into an argparse Namespace.

        :param extra_args: the raw string from the ARGS_EXTRA parameter (may be
            empty or None).
        :param feedback: processing feedback used to report problems.
        :returns: the parsed Namespace, or None after reporting an error when
            the string cannot be tokenised, holds invalid values, or contains
            unrecognised options.
        """
        # Guard against None: older Pythons make shlex.split(None) read stdin.
        extra_args = extra_args or ""
        parser = self._build_extra_args_parser()
        try:
            # shlex.split raises ValueError on unbalanced quotes.
            tokens = shlex.split(extra_args)
            parser_args, unknown = parser.parse_known_args(tokens)
        except (argparse.ArgumentError, ValueError) as e:
            feedback.reportError("Error parsing extra arguments: " + str(e))
            return None
        except SystemExit:
            # Some argparse error paths still call parser.error() -> sys.exit()
            # even with exit_on_error=False; never let that abort QGIS.
            feedback.reportError("Error parsing extra arguments: " + extra_args + "\n" + parser.format_help())
            return None

        if unknown:
            feedback.reportError(
                "Error parsing extra arguments: unrecognized arguments: "
                + " ".join(unknown) + "\n" + parser.format_help()
            )
            return None

        feedback.pushInfo(f"Extra args: {parser_args}")
        return parser_args

    def processAlgorithm(self, parameters, context, feedback):
        """
        Read the parameters, optionally load a checkpoint, build the dataset and
        trainer, and run training.

        Profiling is enabled with ``-p`` in "Extra arguments". Results are
        written to ``profile_results/`` next to this file and can be viewed with
        snakeviz (``pip install snakeviz``, then ``snakeviz <file>.prof``).

        :returns: ``{OUTPUT_MODEL: path}`` on success, or ``{OUTPUT_MODEL: None}``
            if the extra arguments could not be parsed.
        :raises Exception: any error from dataset/trainer setup or training is
            reported through ``feedback`` and re-raised.
        """
        # Fetch input parameters
        current_model = self.parameterAsFile(parameters, self.INPUT_MODEL, context)
        model_save_loc = self.parameterAsFileOutput(parameters, self.OUTPUT_MODEL, context)
        n_epochs = self.parameterAsInt(parameters, self.ARGS_EPOCHS, context)
        nodata = self.parameterAsInt(parameters, self.ARGS_NODATA, context)
        training_type = self.parameterAsEnum(parameters, self.ARGS_TRAINTYPE, context)
        normalize_inputs = self.parameterAsBoolean(parameters, self.ARGS_NORMALIZE, context)
        learning_rate = self.parameterAsDouble(parameters, self.ARGS_LR, context)
        raster_pairs = self.parameterAsString(parameters, self.RASTER_PAIRS, context)
        extra_args = self.parameterAsString(parameters, self.ARGS_EXTRA, context)

        feedback.pushInfo(f"raster pairs: {raster_pairs}")

        parser_args = self._parse_extra_args(extra_args, feedback)
        if parser_args is None:
            return {self.OUTPUT_MODEL: None}  # Return None if parsing fails

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        args = {
            "CHUNK_SIZE": parser_args.chunk_size,
            "NODATA": nodata,
            "BATCH_SIZE": parser_args.batch_size,
            "LEARNING_RATE": learning_rate,
            "EPOCHS": n_epochs,
            "DEVICE": device,
            "TRAIN_TYPE": self.training_types[training_type],  # classification or regression
            "NORMALIZE_INPUTS": normalize_inputs,
            "NORMALIZE_TARGETS": parser_args.normalize_targets,  # only for regression
            "CLASS_REMAPPING": True,
            "VALIDATION_SPLIT": parser_args.validation_split,
            "M_DEPTH": parser_args.depth,
            "M_CHANNELS": parser_args.channels,
            "CLASS_WEIGHTS": parser_args.weights,           # CrossEntropyLoss class weightings (for classification only)
            "SAVE_MODE": parser_args.save_mode,             # 0 = save best, 1 = save last
            "END_PATIENCE": parser_args.end_patience,       # Early stopping patience in epochs (if no improvement in validation loss)
            "RESCALE": parser_args.rescale,                 # Scale the input and target images to a smaller size before training
        }

        feedback.pushInfo(f"Args: {args}")

        # Load checkpoint if retraining.
        checkpoint = None
        if current_model:
            # SECURITY: weights_only=False unpickles arbitrary Python objects,
            # which can execute code; only resume from trusted .pth files.
            # map_location lets a checkpoint saved on a GPU be resumed on a
            # CPU-only machine (and vice versa).
            checkpoint = torch.load(current_model, map_location=device, weights_only=False)
            feedback.pushInfo(f"Loaded checkpoint from {current_model}")

        profiler = None
        if parser_args.profile:
            profiler = cProfile.Profile()
            profiler.enable()

        try:
            dataset = QDataset(raster_pairs, context, feedback, args, checkpoint)
            trainer = QUNetTrainer(dataset, model_save_loc, feedback, args, checkpoint)
            trainer.train()
        except Exception as e:
            feedback.reportError(str(e))
            raise
        finally:
            # Always stop the profiler, even when setup or training fails, so it
            # is not left running and partial results are still saved.
            if profiler is not None:
                self.finish_profiling(profiler, feedback)

        return {self.OUTPUT_MODEL: model_save_loc}

    def finish_profiling(self, profiler, feedback):
        """
        Stop ``profiler`` and dump its raw stats to a timestamped ``.prof`` file
        in a ``profile_results`` directory next to this module.

        :param profiler: an enabled ``cProfile.Profile``.
        :param feedback: processing feedback used to report the output path.
        :returns: the path of the written ``.prof`` file.
        """
        # Disable first so the bookkeeping below is not itself profiled.
        profiler.disable()

        profile_dir = os.path.join(os.path.dirname(__file__), "profile_results")
        os.makedirs(profile_dir, exist_ok=True)

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        profile_file = os.path.join(profile_dir, f"profile_{timestamp}.prof")
        profiler.dump_stats(profile_file)
        feedback.pushInfo(f"Profiling results saved to {profile_file}")
        return profile_file

    def name(self):
        """
        Returns the algorithm name, used for identifying the algorithm. This
        string should be fixed for the algorithm, and must not be localised.
        The name should be unique within each provider. Names should contain
        lowercase alphanumeric characters only and no spaces or other
        formatting characters.

        NOTE: 'QLearnTrain' is mixed-case, contrary to the guidance above. It is
        kept as-is because changing an algorithm id would break any saved
        models or scripts that reference the current id.
        """
        return 'QLearnTrain'

    def displayName(self):
        """
        Returns the translated algorithm name, which should be used for any
        user-visible display of the algorithm name.
        """
        return self.tr(self.name())

    def group(self):
        """
        Returns the name of the group this algorithm belongs to. This string
        should be localised.
        """
        return self.tr(self.groupId())

    def groupId(self):
        """
        Returns the unique ID of the group this algorithm belongs to. This
        string should be fixed for the algorithm, and must not be localised.
        The group id should be unique within each provider. Group id should
        contain lowercase alphanumeric characters only and no spaces or other
        formatting characters.

        NOTE: 'Training' is capitalised; it is kept for compatibility with
        existing references to this group id.
        """
        return 'Training'

    def tr(self, string):
        """Translate ``string`` in the 'Processing' translation context."""
        return QCoreApplication.translate('Processing', string)

    def createInstance(self):
        """Return a fresh instance of this algorithm (required by QGIS)."""
        return QLearnTrainingAlgorithm()
