# -*- coding: utf-8 -*-

'''
/***************************************************************************
 DrainageBasinGeomorphology
                                 A QGIS plugin
 This plugin provides tools for geomorphological analysis in drainage basins.
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2025-03-22
        copyright            : (C) 2025 by João Vitor Pimenta
        email                : jvpjoaopimenta@gmail.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
'''

__author__ = 'João Vitor Pimenta'
__date__ = '2025-03-22'
__copyright__ = '(C) 2025 by João Vitor Pimenta'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from qgis.core import QgsProcessingException
from collections import Counter
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from osgeo import gdal, ogr
import numpy as np
import csv
import itertools
import bisect
import os

def verifyLibs():
    """Check that the optional third-party libraries used by this module are importable.

    Raises:
        QgsProcessingException: if numpy or plotly cannot be imported. The
            original ImportError is chained as the cause for easier debugging.

    NOTE(review): numpy and plotly are also imported at module level in this
    file, so a missing library already fails when the module is imported;
    this check only helps if those module-level imports are made lazy.
    """
    try:
        import numpy  # noqa: F401  (import is the check itself)
    except ImportError as err:
        raise QgsProcessingException('Numpy library not found, please install it and try again.') from err
    try:
        import plotly  # noqa: F401  (import is the check itself)
    except ImportError as err:
        raise QgsProcessingException('Plotly library not found, please install it and try again.') from err

def loadDEM(demLayer):
    """Read band 1 of a DEM layer into memory together with its georeferencing.

    :param demLayer: raster layer whose data provider URI points to a GDAL-readable file
        (any '|'-separated provider options are stripped).
    :returns: tuple (demArray, noData, gt, proj, rows, cols) where demArray is band 1 as
        an array, noData is the band no-data value (-9999 when the band defines none),
        gt is the GDAL geotransform, proj the projection WKT, and rows/cols the array shape.
    :raises QgsProcessingException: if GDAL cannot open the raster.
    """
    DEMpath = demLayer.dataProvider().dataSourceUri().split('|')[0]
    ds = gdal.Open(DEMpath)
    # Without gdal.UseExceptions(), a failed open returns None instead of raising.
    if ds is None:
        raise QgsProcessingException('Could not open the DEM raster: ' + DEMpath)
    band = ds.GetRasterBand(1)
    demArray = band.ReadAsArray()
    noData = band.GetNoDataValue()
    if noData is None:
        # Fallback sentinel when the band has no declared no-data value.
        noData = -9999
    gt = ds.GetGeoTransform()
    proj = ds.GetProjection()
    rows, cols = demArray.shape
    ds = None  # release the GDAL dataset handle
    return demArray, noData, gt, proj, rows, cols

def EAVBelowProcessing(demArray,noData,gt,proj,cols,rows,basin,distanceContour,baseLevel,useOnlyDEMElev,useMaxDEMElev,feedback):
    """Compute the elevation - area - volume curve of a single drainage basin.

    The basin polygon is rasterized on the DEM grid. Every valid DEM cell inside
    it contributes its cell area to the cumulative area of all elevations greater
    than or equal to its own, so cumulativeArea(e) is the basin area at or below e.
    Volumes are integrated from the cumulative area curve with the trapezoidal rule.
    Areas/volumes are in (squared/cubed) units of the DEM geotransform.

    :param demArray: DEM band 1 (as returned by loadDEM).
    :param noData: DEM no-data value; matching cells are ignored (NaN/inf are also ignored).
    :param gt: GDAL geotransform of the DEM.
    :param proj: projection WKT of the DEM.
    :param cols: number of DEM columns (note: cols comes before rows here).
    :param rows: number of DEM rows.
    :param basin: QGIS feature whose geometry delimits the basin.
    :param distanceContour: if not None, resample the curve at this elevation step.
    :param baseLevel: if not None, cut the curve at this elevation; above the DEM maximum
        the volume is extrapolated with the constant total basin area.
    :param useOnlyDEMElev: when truthy, distanceContour is ignored (only DEM elevations used).
    :param useMaxDEMElev: when truthy, baseLevel is ignored (curve stops at the DEM maximum).
    :param feedback: processing feedback used to push warnings.
    :returns: (elevations, cumulativeAreas, cumulativeVolumes) as lists, or
        (None, None, None) when no valid DEM cell lies inside the basin.
    """
    ogrGeom = ogr.CreateGeometryFromWkb(basin.geometry().asWkb())

    # In-memory OGR layer containing only this basin, so GDAL can rasterize it.
    vectorSource = ogr.GetDriverByName('Memory').CreateDataSource('wrk')
    vectorLayer = vectorSource.CreateLayer('lyr', None, ogr.wkbUnknown)
    ogrFeat = ogr.Feature(vectorLayer.GetLayerDefn())
    ogrFeat.SetGeometry(ogrGeom)
    vectorLayer.CreateFeature(ogrFeat)
    ogrFeat = None

    # Byte mask aligned with the DEM grid: 1 where the basin covers the cell.
    maskDS = gdal.GetDriverByName('MEM').Create('', cols, rows, 1, gdal.GDT_Byte)
    maskDS.SetGeoTransform(gt)
    maskDS.SetProjection(proj)
    gdal.RasterizeLayer(maskDS, [1], vectorLayer, burn_values=[1])
    mask = maskDS.GetRasterBand(1).ReadAsArray()

    validMask = (mask == 1) & (demArray != noData)
    if np.issubdtype(demArray.dtype, np.floating):
        # NaN never compares equal to noData, so non-finite cells must be dropped explicitly.
        validMask &= np.isfinite(demArray)

    # Sorted distinct elevations and how many cells have each one.
    uniqueElevations, cellCounts = np.unique(demArray[validMask], return_counts=True)

    if uniqueElevations.size == 0:
        feedback.pushWarning('There is no valid raster data in the basin of id '+str(basin.id())+' and therefore it is not possible to calculate the elevation - area - volume.')
        return None, None, None

    originalElevations = uniqueElevations.tolist()
    elevations = list(originalElevations)

    cellArea = abs(gt[1]) * abs(gt[5])
    originalCumulativeAreas = np.cumsum(cellCounts * cellArea)
    cumulativeAreas = originalCumulativeAreas.copy()

    # Trapezoidal volume between consecutive distinct elevations.
    deltaElev = np.diff(originalElevations)
    volumes = ((originalCumulativeAreas[1:] + originalCumulativeAreas[:-1]) / 2) * deltaElev
    originalCumulativeVolumes = np.concatenate(([0], np.cumsum(volumes)))
    cumulativeVolumes = originalCumulativeVolumes.copy()

    if useOnlyDEMElev:
        distanceContour = None
    if useMaxDEMElev:
        baseLevel = None

    if baseLevel is not None:
        if baseLevel in elevations:
            # Already a DEM elevation: inserting it again would duplicate the value
            # and misalign elevations with their areas/volumes after the cut.
            elevationsWithBaseLevel = elevations
        else:
            elevationsWithBaseLevel = sorted(elevations + [baseLevel])
            if distanceContour is None:
                cumulativeAreas = np.interp(elevationsWithBaseLevel, elevations, cumulativeAreas)
                cumulativeVolumes = np.interp(elevationsWithBaseLevel, elevations, cumulativeVolumes)
            # With distanceContour set, only min/max of the cut elevations are used
            # below; areas/volumes are recomputed from the original curve.

        cut = bisect.bisect_right(elevationsWithBaseLevel, baseLevel)
        elevations = elevationsWithBaseLevel[:cut]
        cumulativeAreas = cumulativeAreas[:cut]
        cumulativeVolumes = cumulativeVolumes[:cut]

    if distanceContour is not None:
        minElevation = min(elevations)
        maxElevation = max(elevations)

        elevationCurves = np.arange(minElevation, maxElevation, distanceContour)
        # arange excludes the stop value; always close the curve at the top.
        if maxElevation not in elevationCurves:
            elevationCurves = np.append(elevationCurves, maxElevation)

        elevations = elevationCurves.tolist()
        cumulativeAreas = np.interp(elevationCurves, originalElevations, originalCumulativeAreas)
        cumulativeVolumes = np.interp(elevationCurves, originalElevations, originalCumulativeVolumes)

    if baseLevel is not None and baseLevel > originalElevations[-1]:
        # Above the DEM maximum the whole basin is flooded: area stays constant
        # (np.interp clamps it) and volume grows linearly with the extra height.
        maxDemElevation = originalElevations[-1]
        elevationsArray = np.asarray(elevations)
        deltaH = elevationsArray[elevationsArray > maxDemElevation] - maxDemElevation
        if deltaH.size:
            cumulativeVolumes[-deltaH.size:] = originalCumulativeVolumes[-1] + originalCumulativeAreas[-1] * deltaH

    return elevations, cumulativeAreas.tolist(), cumulativeVolumes.tolist()

def _buildEAVFigure(basinId, elevations, cumulativeAreas, cumulativeVolumes):
    """Build the plotly volume/area vs. elevation figure for one basin (basinId is a str)."""
    fig = make_subplots(specs=[[{"secondary_y": True}]])

    fig.add_trace(go.Scatter(
                            x=cumulativeVolumes,
                            y=elevations,
                            mode='lines',
                            name='Volume - Elevation basin id '+basinId
                            ),
                            secondary_y=False
                            )
    fig.add_trace(go.Scatter(x=cumulativeAreas,
                            y=elevations,
                            mode='lines',
                            name='Area - Elevation basin id '+basinId
                            ),
                            secondary_y=True
                            )

    # The area curve uses its own top x axis, reversed so it mirrors the volume curve.
    fig.data[1].update(xaxis='x2')

    fig.update_layout(
        title='Elevation - Area - Volume graph',
        xaxis=dict(title='Volume (m3)'),
        yaxis=dict(title='Elevation (m)'),
        xaxis2=dict(title='Area (m2)',
                    overlaying='x',
                    side='top',
                    autorange='reversed'),
        yaxis2=dict(
                    title='Elevation (m)',
                    overlaying='y',
                    side='right',
                    position=1
                    )
                        )
    return fig

def runEAVBelow(drainageBasinLayer,demLayer,pathCsv,pathHtml,distanceContour,baseLevel,useOnlyDEMElev,useMaxDEMElev,feedback):
    """Compute elevation - area - volume curves for every basin and export them.

    For each feature of drainageBasinLayer, the curve from EAVBelowProcessing is
    plotted to ``<pathHtml>/GRAPH_BASIN_ID_<id>.html`` (and shown with fig.show()),
    and three columns (elevation, area, volume, each headed by a label) are
    collected into one CSV written to pathCsv at the end.

    Basins without valid DEM data are skipped (EAVBelowProcessing emits the
    warning) instead of aborting the whole run.

    :param drainageBasinLayer: polygon layer with one basin per feature.
    :param demLayer: DEM raster layer, read once with loadDEM.
    :param pathCsv: output CSV file path.
    :param pathHtml: output folder for the HTML graphs (created if missing).
    :param distanceContour, baseLevel, useOnlyDEMElev, useMaxDEMElev: forwarded to EAVBelowProcessing.
    :param feedback: processing feedback (progress, messages, cancellation).
    :returns: None. On cancellation it returns early and the CSV is not written.
    """
    feedback.setProgress(0)
    total = drainageBasinLayer.featureCount()
    step = 100.0 / total if total else 0

    os.makedirs(pathHtml, exist_ok=True)

    demArray,noData,gt,proj,rows,cols = loadDEM(demLayer)

    # One list per CSV column: [header, value1, value2, ...].
    listsWithData = []

    for idx, basin in enumerate(drainageBasinLayer.getFeatures()):
        if feedback.isCanceled():
            return
        basinId = str(basin.id())
        feedback.setProgressText('Basin id '+basinId+' processing starting...')
        elevations, cumulativeAreas, cumulativeVolumes = EAVBelowProcessing(demArray,noData,gt,proj,cols,rows,basin,distanceContour,baseLevel,useOnlyDEMElev,useMaxDEMElev,feedback)

        if elevations is None:
            # No valid DEM cells in this basin: skip it but keep processing the others.
            feedback.setProgress(int((idx + 1) * step))
            continue

        # Headers are prepended to copies so the plotted series stay purely numeric.
        listsWithData.append(['Elevation basin id '+basinId] + elevations)
        listsWithData.append(['Area basin id '+basinId] + cumulativeAreas)
        listsWithData.append(['Volume basin id '+basinId] + cumulativeVolumes)

        feedback.setProgressText('Basin id '+basinId+' processing completed')

        if feedback.isCanceled():
            return

        feedback.setProgressText('Basin id '+basinId+' graph starting...')
        fig = _buildEAVFigure(basinId, elevations, cumulativeAreas, cumulativeVolumes)

        outputHTML = os.path.join(pathHtml, 'GRAPH_BASIN_ID_'+basinId+'.html')
        fig.write_html(outputHTML)

        feedback.setProgress(int((idx + 1) * step))
        feedback.setProgressText('Basin id '+basinId+' graph completed')

        fig.show()

    if feedback.isCanceled():
        return
    # zip_longest transposes the columns into rows; shorter columns are padded with empty cells.
    with open(pathCsv, 'w', newline='') as archive:
        writer = csv.writer(archive)
        writer.writerows(itertools.zip_longest(*listsWithData))

