#!/usr/bin/python
#coding=utf-8
"""
/***************************************************************************
        begin                : 2021-11
        copyright            : (C) 2024 by Giacomo Titti, Bologna, November 2024
        email                : giacomotitti@gmail.com
 ***************************************************************************/

/***************************************************************************
    Copyright (C) 2024 by Giacomo Titti, Bologna, November 2024

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
 ***************************************************************************/
"""

# Module metadata: author, date and copyright notice of this script.
__author__ = 'Giacomo Titti'
__date__ = '2024-11-01'
__copyright__ = '(C) 2024 by Giacomo Titti'

import sys
sys.setrecursionlimit(10000)
from qgis.core import (QgsProcessing,
                       QgsProcessingException,
                       QgsProcessingMultiStepFeedback,
                       QgsProcessingParameterNumber,
                       QgsProcessingParameterFileDestination,
                       QgsProcessingParameterVectorLayer,
                       QgsVectorLayer,
                       QgsProcessingParameterField,
                       QgsProcessingParameterFolderDestination,
                       QgsProcessingParameterField,
                       QgsProcessingContext,
                       QgsProcessingParameterEnum
                       )
from qgis.core import *
from qgis import *
import tempfile
from sz_module.scripts.utils import SZ_utils
from sz_module.scripts.algorithms import CV_utils,GAM_utils
from sz_module.utils import log
from sz_module.test.utils import load_test_input


class CoreAlgorithmGAM_cv():
    """Shared implementation of the GAM fit / cross-validation processing tool.

    This class does not define the members it relies on: ``init`` and
    ``process`` call ``self.addParameter``, ``self.tr``, ``self.parameterAs*``
    and ``self.invalidSourceError`` and read the parameter-name attributes
    ``self.INPUT``, ``self.STRING*``, ``self.NUMBER*``, ``self.OUTPUT`` and
    ``self.OUTPUT3``. Presumably a QGIS processing algorithm supplies them and
    delegates to these methods (possibly as unbound calls) -- TODO confirm
    against the caller. Because of that, no helper methods are added on
    ``self``; helpers are local functions instead.
    """

    def init(self, config=None):
        """Register the input and output parameters of the algorithm.

        When the environment variable ``DEBUG`` equals ``'True'``, default
        values come from the ``"GAM_cv"`` fixture returned by
        ``load_test_input``. Otherwise the hard-coded defaults below are used.
        The parameters themselves (types, labels, options, order) are the same
        in both modes.

        :param config: accepted for interface compatibility; not used here.
        """
        # Explicit import: do not rely on a ``qgis`` star import leaking ``os``.
        import os

        if os.environ.get('DEBUG') == 'True':
            defaults = load_test_input("GAM_cv")
        else:
            defaults = {
                self.INPUT: None,
                self.STRING3: None,
                self.STRING: None,
                self.NUMBER1: 10,
                self.STRING8: None,
                self.STRING9: None,
                self.STRING1: None,
                self.STRING4: '',
                # Non-static enums take the option *index* as default. The
                # former label default 'linear scale' would be returned as-is
                # by parameterAsString() and break the index lookup in process().
                self.STRING7: 0,
                self.STRING2: '',
                self.STRING5: '',
                self.STRING6: None,
                self.NUMBER: 2,
                self.OUTPUT: None,
                self.OUTPUT3: None,
            }

        self.addParameter(QgsProcessingParameterVectorLayer(
            self.INPUT, self.tr('Input layer'),
            types=[QgsProcessing.TypeVectorPolygon],
            defaultValue=defaults[self.INPUT]))
        self.addParameter(QgsProcessingParameterField(
            self.STRING3, 'Linear independent variables',
            parentLayerParameterName=self.INPUT,
            defaultValue=defaults[self.STRING3], allowMultiple=True,
            type=QgsProcessingParameterField.Any, optional=True))
        self.addParameter(QgsProcessingParameterField(
            self.STRING, 'Ordinal independent variables',
            parentLayerParameterName=self.INPUT,
            defaultValue=defaults[self.STRING], allowMultiple=True,
            type=QgsProcessingParameterField.Any, optional=True))
        self.addParameter(QgsProcessingParameterNumber(
            self.NUMBER1, self.tr('Spline smoothing parameter'),
            type=QgsProcessingParameterNumber.Integer,
            defaultValue=defaults[self.NUMBER1]))
        self.addParameter(QgsProcessingParameterField(
            self.STRING8, 'Interacting variable A',
            parentLayerParameterName=self.INPUT,
            defaultValue=defaults[self.STRING8], allowMultiple=False,
            type=QgsProcessingParameterField.Any, optional=True))
        self.addParameter(QgsProcessingParameterField(
            self.STRING9, 'Interacting variable B',
            parentLayerParameterName=self.INPUT,
            defaultValue=defaults[self.STRING9], allowMultiple=False,
            type=QgsProcessingParameterField.Any, optional=True))
        self.addParameter(QgsProcessingParameterField(
            self.STRING1, 'Categorical independent variables',
            parentLayerParameterName=self.INPUT,
            defaultValue=defaults[self.STRING1], allowMultiple=True,
            type=QgsProcessingParameterField.Any, optional=True))
        self.addParameter(QgsProcessingParameterEnum(
            self.STRING4, 'Family', options=['binomial', 'gaussian'],
            allowMultiple=False, usesStaticStrings=False,
            defaultValue=defaults[self.STRING4]))
        self.addParameter(QgsProcessingParameterEnum(
            self.STRING7, 'Scale (for Gaussian Family only)',
            options=['linear scale', 'log scale'],
            allowMultiple=False, usesStaticStrings=False,
            defaultValue=defaults[self.STRING7], optional=True))
        self.addParameter(QgsProcessingParameterField(
            self.STRING2, 'Field of dependent variable (0 for absence, > 0 for presence)',
            parentLayerParameterName=self.INPUT,
            defaultValue=defaults[self.STRING2]))
        self.addParameter(QgsProcessingParameterEnum(
            self.STRING5, 'CV method',
            options=['random CV', 'spatial CV', 'temporal CV (Time Series Split)',
                     'temporal CV (Leave One Out)', 'space-time CV (Leave One Out)'],
            allowMultiple=False, usesStaticStrings=False,
            defaultValue=defaults[self.STRING5]))
        self.addParameter(QgsProcessingParameterField(
            self.STRING6, 'Time field (for temporal CV)',
            parentLayerParameterName=self.INPUT,
            defaultValue=defaults[self.STRING6], allowMultiple=False,
            type=QgsProcessingParameterField.Any, optional=True))
        self.addParameter(QgsProcessingParameterNumber(
            self.NUMBER,
            self.tr('K-fold CV: K=1 to fit, k>1 to cross-validate for spatial CV only'),
            minValue=1, type=QgsProcessingParameterNumber.Integer,
            defaultValue=defaults[self.NUMBER], optional=True))
        self.addParameter(QgsProcessingParameterFileDestination(
            self.OUTPUT, 'Output test/fit',
            fileFilter='GeoPackage (*.gpkg *.GPKG)',
            defaultValue=defaults[self.OUTPUT]))
        self.addParameter(QgsProcessingParameterFolderDestination(
            self.OUTPUT3, 'Outputs folder destination',
            defaultValue=defaults[self.OUTPUT3], createByDefault=True))

    def process(self, parameters, context, feedback, algorithm=None, classifier=None):
        """Run the GAM fit / cross-validation and load the resulting layers.

        The method runs these stages in order:

        1. Read the parameters registered by ``init``.
        2. Load the covariates with ``SZ_utils.load_cv``.
        3. Build the GAM terms with ``GAM_utils.GAM_formula``.
        4. Run ``CV_utils.cross_validation``.
        5. Write the GeoPackage at ``OUTPUT`` via ``SZ_utils.save``.
        6. Pass ``OUTPUT3`` as the output folder to ``SZ_utils.stamp_cv``
           (binomial family) or to ``SZ_utils.stamp_qq`` /
           ``SZ_utils.stamp_qq_fit`` (gaussian family).
        7. Schedule every valid sub-layer of the output GeoPackage to be
           loaded into the project on completion.

        Side effect: resolved values are also written back into
        ``parameters`` under helper keys (``'covariates'``, ``'field1'``,
        ``'time'``, ``'out'``, ...).

        :param parameters: processing parameter map (mutated, see above).
        :param context: ``QgsProcessingContext`` of the run.
        :param feedback: feedback object; wrapped in a
            ``QgsProcessingMultiStepFeedback``.
        :param algorithm: forwarded unchanged to ``CV_utils.cross_validation``.
        :param classifier: forwarded unchanged to ``CV_utils.cross_validation``.
        :returns: ``{'out': <output GeoPackage path>}``, or ``{}`` when the run
            is cancelled or ``SZ_utils.check_validity`` returns ``False``.
        :raises QgsProcessingException: if the input layer cannot be resolved
            or a parameter resolves to ``None``.
        :raises RuntimeError: if a temporal / space-time CV method is chosen
            without a time field.
        """
        self.f = tempfile.gettempdir()
        # Steps 0..6 are reported below, so the wrapper needs 7 child steps
        # (it used to be created with 1, pushing progress far beyond 100 %).
        feedback = QgsProcessingMultiStepFeedback(7, feedback)
        results = {}
        outputs = {}

        # Non-static enum parameters resolve to the option index as a string.
        family = {'0': 'binomial', '1': 'gaussian'}
        cv_method = {'0': 'random', '1': 'spatial', '2': 'temporal_TSS',
                     '3': 'temporal_LOO', '4': 'spacetime_LOO'}
        scale = {'0': 'linear_scale', '1': 'log_scale'}

        def required(reader, name):
            """Resolve parameter *name* with *reader*; raise if it is ``None``."""
            value = reader(parameters, name, context)
            if value is None:
                raise QgsProcessingException(self.invalidSourceError(parameters, name))
            return value

        # Validate the layer before dereferencing it: calling source.source()
        # first turned a missing layer into an AttributeError.
        source = required(self.parameterAsVectorLayer, self.INPUT)
        parameters['covariates'] = source.source()

        parameters['field1'] = required(self.parameterAsFields, self.STRING)
        parameters['field2'] = required(self.parameterAsFields, self.STRING1)
        parameters['field3'] = required(self.parameterAsFields, self.STRING3)
        parameters['fieldlsd'] = required(self.parameterAsString, self.STRING2)
        parameters['family'] = required(self.parameterAsString, self.STRING4)
        parameters['scale'] = required(self.parameterAsString, self.STRING7)
        parameters['var_interaction_A'] = required(self.parameterAsFields, self.STRING8)
        parameters['var_interaction_B'] = required(self.parameterAsFields, self.STRING9)
        parameters['num1'] = required(self.parameterAsInt, self.NUMBER1)
        parameters['cv_method'] = required(self.parameterAsString, self.STRING5)
        parameters['time'] = required(self.parameterAsString, self.STRING6)
        parameters['testN'] = required(self.parameterAsInt, self.NUMBER)
        parameters['out'] = required(self.parameterAsFileOutput, self.OUTPUT)
        parameters['folder'] = required(self.parameterAsString, self.OUTPUT3)

        # The scale parameter is optional; an empty value would make the
        # scale[...] lookup raise KeyError, so fall back to option 0
        # ('linear scale').
        if parameters['scale'] == '':
            parameters['scale'] = '0'

        SZ_utils.make_directory({'path': parameters['folder']})

        method = cv_method[parameters['cv_method']]
        if method in ('random', 'spatial'):
            # Non-temporal CV: any selected time field is ignored.
            parameters['time'] = None
        elif parameters['time'] == '':
            log("Time field is missing for temporal CV")
            raise RuntimeError("Time field is missing for temporal CV")

        var_a = parameters['var_interaction_A']
        var_b = parameters['var_interaction_B']
        if var_a and var_b:
            # Single-field selections: the interaction uses the first field of each.
            tensor = [var_a[0], var_b[0]]

            alg_params = {
                'linear': parameters['field3'],
                'continuous': parameters['field1'],
                'categorical': parameters['field2'],
                'tensor': tensor,
            }
            if SZ_utils.check_validity(alg_params) is False:
                return {}
        else:
            if var_a or var_b:
                feedback.pushInfo('Only one interacting variable was selected: '
                                  'the interaction (tensor) term is ignored.')
            tensor = []

        def all_terms():
            """Return a new list of every model term name (one per caller)."""
            return parameters['field3'] + parameters['field1'] + parameters['field2'] + tensor

        feedback.setCurrentStep(0)
        if feedback.isCanceled():
            return {}

        fam = family[parameters['family']]
        alg_params = {
            'INPUT_VECTOR_LAYER': parameters['covariates'],
            'nomi': all_terms(),
            'lsd': parameters['fieldlsd'],
            'family': fam,
            'time': parameters['time'],
            'scale': scale[parameters['scale']],
        }
        outputs['df'], outputs['crs'] = SZ_utils.load_cv(self.f, alg_params)

        feedback.setCurrentStep(1)
        if feedback.isCanceled():
            return {}

        alg_params = {
            'linear': parameters['field3'],
            'continuous': parameters['field1'],
            'categorical': parameters['field2'],
            'tensor': tensor,
            'nomi': all_terms(),
            'spline': parameters['num1'],
        }
        outputs['splines'], outputs['dtypes'] = GAM_utils.GAM_formula(alg_params)

        feedback.setCurrentStep(2)
        if feedback.isCanceled():
            return {}

        alg_params = {
            'testN': parameters['testN'],
            'fold': parameters['folder'],
            'nomi': all_terms(),
            'df': outputs['df'],
            'splines': outputs['splines'],
            'dtypes': outputs['dtypes'],
            'categorical': parameters['field2'],
            'linear': parameters['field3'],
            'continuous': parameters['field1'],
            'tensor': tensor,
            'family': fam,
            'cv_method': method,
            'time': parameters['time'],
        }
        outputs['prob'], outputs['test_ind'], outputs['gam'] = CV_utils.cross_validation(
            alg_params, algorithm, classifier)

        feedback.setCurrentStep(3)
        if feedback.isCanceled():
            return {}

        if parameters['testN'] > 0:
            alg_params = {
                'df': outputs['df'],
                'crs': outputs['crs'],
                'OUT': parameters['out'],
            }
            SZ_utils.save(alg_params)

        feedback.setCurrentStep(4)
        if feedback.isCanceled():
            return {}

        if fam == 'binomial':
            alg_params = {
                'test_ind': outputs['test_ind'],
                'df': outputs['df'],
                'OUT': parameters['folder'],
            }
            SZ_utils.stamp_cv(alg_params)
        elif fam == 'gaussian':
            alg_params = {
                'test_ind': outputs['test_ind'],
                'df': outputs['df'],
                'OUT': parameters['folder'],
            }
            outputs['error_train'] = SZ_utils.stamp_qq(alg_params)

            alg_params = {
                'df': outputs['df'],
                'OUT': parameters['folder'],
            }
            outputs['error_train'] = SZ_utils.stamp_qq_fit(alg_params)

        results['out'] = parameters['out']

        feedback.setCurrentStep(5)
        if feedback.isCanceled():
            return {}

        fileName = parameters['out']
        layer1 = QgsVectorLayer(fileName, "test", "ogr")
        provider = layer1.dataProvider()
        sub_layers = provider.subLayers() if provider is not None else []

        for sub_layer in sub_layers:
            # Sub-layer descriptions are '!!::!!'-separated; item 1 is used as
            # the layer name in the OGR "layername" URI below.
            name = sub_layer.split('!!::!!')[1]
            uri = "%s|layername=%s" % (fileName, name)
            feedback.pushDebugInfo(f'sub-layer name: {name}, uri: {uri}')
            sub_vlayer = QgsVectorLayer(uri, name, 'ogr')
            if not sub_vlayer.isValid():
                # Do not register a broken layer for loading; report and skip it.
                feedback.reportError(f'layer failed to load: {uri}')
                continue
            context.temporaryLayerStore().addMapLayer(sub_vlayer)
            context.addLayerToLoadOnCompletion(
                sub_vlayer.id(),
                QgsProcessingContext.LayerDetails('test', context.project(), 'LAYER1'))

        feedback.setCurrentStep(6)
        if feedback.isCanceled():
            return {}

        return results