# -*- coding: utf-8 -*-

'''
/***************************************************************************
 DrainageBasinGeomorphology
                                 A QGIS plugin
 This plugin provides tools for geomorphological analysis in drainage basins.
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2025-03-22
        copyright            : (C) 2025 by João Vitor Pimenta
        email                : jvpjoaopimenta@gmail.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
'''

__author__ = 'João Vitor Pimenta'
__date__ = '2025-03-22'
__copyright__ = '(C) 2025 by João Vitor Pimenta'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from qgis.core import QgsProcessingException
from collections import Counter
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from osgeo import gdal, ogr
import numpy as np
import csv
import itertools
import bisect
import os

def verifyLibs():
    """
    Check that the third-party libraries used by this module are importable.

    :raises QgsProcessingException: if NumPy or Plotly cannot be imported.
    """
    try:
        import numpy  # noqa: F401
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise QgsProcessingException('Numpy library not found, please install it and try again.') from err
    try:
        import plotly  # noqa: F401
    except ImportError as err:
        raise QgsProcessingException('Plotly library not found, please install it and try again.') from err

def loadDEM(demLayer):
    """
    Read the first band of a DEM layer into memory.

    :param demLayer: QGIS raster layer; its data source URI (up to the first
        '|') is opened with GDAL.
    :returns: (demArray, noData, gt, proj, rows, cols) where demArray is the
        band as an array, noData the band NoData value (or None), gt the
        GDAL geotransform, proj the projection string and rows/cols the
        array shape.
    :raises QgsProcessingException: if GDAL cannot open the raster.
    """
    DEMpath = demLayer.dataProvider().dataSourceUri().split('|')[0]
    ds = gdal.Open(DEMpath)
    # Without gdal.UseExceptions(), a failed open returns None instead of raising.
    if ds is None:
        raise QgsProcessingException(f'Could not open the DEM raster at {DEMpath}.')
    band = ds.GetRasterBand(1)
    demArray = band.ReadAsArray()
    noData = band.GetNoDataValue()
    gt = ds.GetGeoTransform()
    proj = ds.GetProjection()
    rows, cols = demArray.shape
    # Release the GDAL dataset; the array returned by ReadAsArray is independent of it.
    ds = None
    return demArray, noData, gt, proj, rows, cols

def EAVAboveProcessing(demArray,noData,gt,proj,cols,rows,basin,distanceContour,baseLevel,useOnlyDEMElev,useMinDEMElev,feedback):
    """
    Compute the elevation - area - volume relation of one drainage basin.

    The basin geometry is rasterized onto the DEM grid (same geotransform and
    projection) and the valid DEM cells inside it are grouped by elevation.
    Distinct elevations are ordered from highest to lowest. For each one,
    the cumulative area is the area of all cells at or above it, and the
    cumulative volume is the trapezoidal integral of that area over the
    elevation differences.

    :param demArray: 2-D array read from the DEM band.
    :param noData: NoData value of the band; may be None or NaN.
    :param gt: GDAL geotransform of the DEM (pixel size taken from gt[1], gt[5]).
    :param proj: projection string of the DEM.
    :param cols: number of DEM columns.
    :param rows: number of DEM rows.
    :param basin: QGIS feature whose geometry delimits the basin.
    :param distanceContour: elevation spacing used to resample the curve, or
        None to keep one point per distinct DEM elevation.
    :param baseLevel: lowest elevation of the curve, or None to stop at the
        lowest DEM elevation inside the basin.
    :param useOnlyDEMElev: when truthy, distanceContour is ignored.
    :param useMinDEMElev: when truthy, baseLevel is ignored.
    :param feedback: QGIS processing feedback, used to push warnings.
    :returns: (elevations, cumulativeAreas, cumulativeVolumes) as lists ordered
        from highest to lowest elevation, or (None, None, None) when the
        basin contains no valid DEM cell.
    :raises QgsProcessingException: if distanceContour is 0 or negative.
    """
    basinGeom = basin.geometry()
    wkb = basinGeom.asWkb()
    ogrGeom = ogr.CreateGeometryFromWkb(wkb)

    # Put the basin polygon into an in-memory OGR layer so GDAL can rasterize it.
    rasterDrive = gdal.GetDriverByName('MEM')
    vectorDrive = ogr.GetDriverByName('Memory')
    vectorDriveSource = vectorDrive.CreateDataSource('wrk')
    vectorLayer = vectorDriveSource.CreateLayer('lyr', None, ogr.wkbUnknown)
    ogrFeat = ogr.Feature(vectorLayer.GetLayerDefn())
    ogrFeat.SetGeometry(ogrGeom)
    vectorLayer.CreateFeature(ogrFeat)
    ogrFeat = None

    # Byte mask aligned with the DEM grid: cells inside the basin are burned to 1.
    maskDS = rasterDrive.Create('', cols, rows, 1, gdal.GDT_Byte)
    maskDS.SetGeoTransform(gt)
    maskDS.SetProjection(proj)
    gdal.RasterizeLayer(maskDS, [1], vectorLayer, burn_values=[1])
    mask = maskDS.GetRasterBand(1).ReadAsArray()

    validMask = mask == 1
    # NaN never compares equal (not even to a NaN NoData value), so NaN cells
    # must be removed explicitly or they would poison the counting and sorting.
    if np.issubdtype(demArray.dtype, np.floating):
        validMask &= ~np.isnan(demArray)
    if noData is not None:
        validMask &= demArray != noData
    validDataInsideBasin = demArray[validMask].tolist()

    if not validDataInsideBasin:
        feedback.pushWarning('There is no valid raster data in the basin of id '+str(basin.id())+' and therefore it is not possible to calculate the elevation - area - volume.')
        return None, None, None

    # Distinct elevations, highest first, with their cell counts.
    counterValuesOrdered = sorted(Counter(validDataInsideBasin).items(), reverse=True)
    originalElevations = [elevation for elevation, _ in counterValuesOrdered]
    countElevations = [count for _, count in counterValuesOrdered]
    elevations = list(originalElevations)

    pixelArea = abs(gt[1]) * abs(gt[5])
    areas = np.array(countElevations) * pixelArea
    originalCumulativeAreas = np.cumsum(areas)
    cumulativeAreas = originalCumulativeAreas.copy()

    # Trapezoidal rule between consecutive distinct elevations.
    deltaElev = np.abs(np.diff(elevations))
    volumes = ((cumulativeAreas[1:] + cumulativeAreas[:-1]) / 2) * deltaElev
    originalCumulativeVolumes = np.concatenate(([0], np.cumsum(volumes)))
    cumulativeVolumes = originalCumulativeVolumes.copy()

    if useOnlyDEMElev:
        distanceContour = None
    if useMinDEMElev:
        baseLevel = None

    if distanceContour == 0:
        raise QgsProcessingException('The distance between contour lines cannot be 0.')
    if distanceContour is not None and distanceContour < 0:
        # np.arange with a negative step over (min, max) would be empty.
        raise QgsProcessingException('The distance between contour lines cannot be negative.')

    if baseLevel is not None:
        if baseLevel in elevations:
            # Do not insert it again: a duplicate would shift every value
            # below it out of alignment with its area and volume.
            elevationsWithBaseLevel = list(elevations)
        else:
            elevationsWithBaseLevel = sorted(elevations + [baseLevel], reverse=True)
            if distanceContour is None:
                # np.interp needs increasing x, hence the reversed views.
                cumulativeAreas = np.interp(elevationsWithBaseLevel, elevations[::-1], cumulativeAreas[::-1])
                cumulativeVolumes = np.interp(elevationsWithBaseLevel, elevations[::-1], cumulativeVolumes[::-1])

        # Keep only elevations >= baseLevel (the list is in descending order).
        keepCount = bisect.bisect_right([-e for e in elevationsWithBaseLevel], -baseLevel)
        elevations = elevationsWithBaseLevel[:keepCount]
        cumulativeAreas = cumulativeAreas[:keepCount]
        cumulativeVolumes = cumulativeVolumes[:keepCount]

    if distanceContour is not None:
        minElevation = min(elevations)
        maxElevation = max(elevations)

        elevationCurves = np.arange(minElevation, maxElevation, distanceContour)
        if maxElevation not in elevationCurves:
            elevationCurves = np.append(elevationCurves, maxElevation)

        interpAreas = np.interp(elevationCurves, originalElevations[::-1], originalCumulativeAreas[::-1])
        interpVolumes = np.interp(elevationCurves, originalElevations[::-1], originalCumulativeVolumes[::-1])

        elevations = elevationCurves[::-1].tolist()
        cumulativeAreas = interpAreas[::-1]
        cumulativeVolumes = interpVolumes[::-1]

    if baseLevel is not None and baseLevel < min(originalElevations):
        # Below the lowest DEM cell the area stays equal to the whole basin
        # area, so the volume grows linearly with the extra depth.
        constantAreaCut = originalCumulativeAreas[-1]
        minOriginalElevation = min(originalElevations)
        elevationsArray = np.array(elevations)
        deltaH = minOriginalElevation - elevationsArray[elevationsArray < minOriginalElevation]
        if deltaH.size:
            # Elevations are descending, so the cut points are the trailing ones.
            cumulativeVolumes[-deltaH.size:] = originalCumulativeVolumes[-1] + constantAreaCut * deltaH

    return elevations, cumulativeAreas.tolist(), cumulativeVolumes.tolist()

def runEAVAbove(drainageBasinLayer,demLayer,pathCsv,pathHtml,distanceContour,baseLevel,useOnlyDEMElev,useMinDEMElev,feedback,decimalPlaces,useAllDecimalPlaces):
    """
    Build the elevation - area - volume graph (HTML) and table (CSV) for all basins.

    For each feature of drainageBasinLayer, EAVAboveProcessing is run on the DEM
    loaded from demLayer. Two Plotly traces (volume and area against elevation)
    are added to one figure, which is written to pathHtml and shown. The
    per-basin columns are then written side by side to pathCsv.

    :param decimalPlaces: number of decimals used for rounding and CSV output.
    :param useAllDecimalPlaces: when truthy, values are neither rounded nor
        truncated in the CSV.
    :returns: None. Returns early (without writing the CSV) if the task is canceled.
    """
    feedback.setProgress(0)
    total = drainageBasinLayer.featureCount()
    step = 100.0 / total if total else 0

    demArray, noData, gt, proj, rows, cols = loadDEM(demLayer)

    fig = go.Figure()
    listsWithData = []

    for idx, basin in enumerate(drainageBasinLayer.getFeatures()):
        if feedback.isCanceled():
            return

        basinId = basin.id()
        feedback.setProgressText(f'Basin id {basinId} processing starting...')
        elevations, cumulativeAreas, cumulativeVolumes = EAVAboveProcessing(
            demArray, noData, gt, proj, cols, rows, basin,
            distanceContour, baseLevel, useOnlyDEMElev, useMinDEMElev, feedback
        )

        # EAVAboveProcessing returns all three as None when the basin has no data.
        if elevations is None:
            continue

        if not useAllDecimalPlaces:
            elevations = [round(num, decimalPlaces) for num in elevations]
            cumulativeAreas = [round(num, decimalPlaces) for num in cumulativeAreas]
            cumulativeVolumes = [round(num, decimalPlaces) for num in cumulativeVolumes]

        # Headers belong only to the CSV columns; the plotted series stay numeric.
        listsWithData.append([f'Elevation basin id {basinId}'] + elevations)
        listsWithData.append([f'Area basin id {basinId}'] + cumulativeAreas)
        listsWithData.append([f'Volume basin id {basinId}'] + cumulativeVolumes)

        feedback.setProgressText(f'Basin id {basinId} processing completed')

        if feedback.isCanceled():
            return

        feedback.setProgressText(f'Basin id {basinId} graph starting...')

        fig.add_trace(go.Scatter(
            x=cumulativeVolumes,
            y=elevations,
            mode='lines',
            name=f'Volume - Elevation basin id {basinId}',
            yaxis='y',
            xaxis='x'
        ))

        fig.add_trace(go.Scatter(
            x=cumulativeAreas,
            y=elevations,
            mode='lines',
            name=f'Area - Elevation basin id {basinId}',
            yaxis='y2',
            xaxis='x2'
        ))

        feedback.setProgress(int((idx + 1) * step))
        feedback.setProgressText(f'Basin id {basinId} graph completed')

    # Volume uses the bottom x axis; area uses a reversed top x axis overlaid on it.
    fig.update_layout(
        title='Elevation - Area - Volume graph',
        xaxis=dict(title='Volume (m³)'),
        yaxis=dict(title='Elevation (m)'),
        xaxis2=dict(
            title='Area (m²)',
            overlaying='x',
            side='top',
            autorange='reversed'
        ),
        yaxis2=dict(
            title='Elevation (m)',
            overlaying='y',
            side='right',
            position=1
        ),
    )

    fig.write_html(pathHtml)
    fig.show()

    if feedback.isCanceled():
        return

    feedback.setProgressText('Graph completed for all basins')

    with open(pathCsv, 'w', newline='', encoding='utf-8') as archive:
        writer = csv.writer(archive)
        # Columns have different lengths; zip_longest pads with None (written as empty cells).
        writer.writerows(
            itertools.zip_longest(*[
                _formatCsvColumn(column, decimalPlaces, useAllDecimalPlaces)
                for column in listsWithData
            ])
        )


def _formatCsvColumn(column, decimalPlaces, useAllDecimalPlaces):
    """Return a CSV column: header unchanged, values fixed-decimal unless all decimals are kept."""
    header, *values = column
    if useAllDecimalPlaces:
        return [header] + values
    return [header] + [f"{value:.{decimalPlaces}f}" for value in values]

