# -*- coding: utf-8 -*-

'''
/***************************************************************************
 DrainageBasinGeomorphology
                                 A QGIS plugin
 This plugin provides tools for geomorphological analysis in drainage basins.
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2025-03-22
        copyright            : (C) 2025 by João Vitor Pimenta
        email                : jvpjoaopimenta@gmail.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
'''

__author__ = 'João Vitor Pimenta'
__date__ = '2025-03-22'
__copyright__ = '(C) 2025 by João Vitor Pimenta'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

from qgis.core import QgsProcessingException
from collections import Counter
import plotly.graph_objects as go
import numpy as np
from osgeo import gdal, ogr
import csv
import itertools

def verifyLibs():
    """Make sure the third-party libraries needed by the hypsometric tools import.

    Libraries are checked in order (numpy first, then plotly); the first one
    that fails to import aborts the check.

    Note: this module also imports numpy and plotly.graph_objects at module
    scope, so a missing library would normally fail at module import time,
    before this function can run.

    Raises:
        QgsProcessingException: if numpy or plotly cannot be imported.
    """
    requiredLibraries = (
        ('numpy', 'Numpy library not found, please install it and try again.'),
        ('plotly', 'Plotly library not found, please install it and try again.'),
    )
    for moduleName, missingMessage in requiredLibraries:
        try:
            __import__(moduleName)
        except ImportError:
            raise QgsProcessingException(missingMessage)

def loadDEM(demLayer):
    """Read the first band of a DEM layer into memory with GDAL.

    Args:
        demLayer: QGIS raster layer; its provider URI (up to any '|' suffix)
            is opened with ``gdal.Open``.

    Returns:
        tuple: ``(demArray, noData, gt, proj, rows, cols)``. These are:
        the band-1 values as a NumPy array, the band nodata value (-9999
        when the band defines none), the GDAL geotransform, the projection
        string, and the array's row and column counts.

    Raises:
        QgsProcessingException: if GDAL cannot open the raster source.
    """
    # Keep only the part of the provider URI before any '|' suffix.
    DEMpath = demLayer.dataProvider().dataSourceUri().split('|')[0]
    ds = gdal.Open(DEMpath)
    if ds is None:
        # Without GDAL exceptions enabled, gdal.Open signals failure by returning None.
        raise QgsProcessingException('Could not open the DEM raster: ' + DEMpath)

    band = ds.GetRasterBand(1)
    demArray = band.ReadAsArray()
    bandNoData = band.GetNoDataValue()
    noData = bandNoData if bandNoData is not None else -9999
    gt = ds.GetGeoTransform()
    proj = ds.GetProjection()
    rows, cols = demArray.shape
    # Drop the dataset reference so GDAL can close the file.
    ds = None
    return demArray, noData, gt, proj, rows, cols

def _rasterizeBasinMask(basin, gt, proj, cols, rows):
    """Burn the basin polygon into an in-memory Byte raster aligned with the DEM grid.

    Returns the mask band as a (rows, cols) array in which cells covered by
    the polygon are 1.
    """
    ogrGeom = ogr.CreateGeometryFromWkb(basin.geometry().asWkb())

    vectorDrive = ogr.GetDriverByName('Memory')
    vectorDriveSource = vectorDrive.CreateDataSource('wrk')
    vectorLayer = vectorDriveSource.CreateLayer('lyr', None, ogr.wkbUnknown)
    featureDef = vectorLayer.GetLayerDefn()
    ogrFeat = ogr.Feature(featureDef)
    ogrFeat.SetGeometry(ogrGeom)
    vectorLayer.CreateFeature(ogrFeat)
    ogrFeat = None

    rasterDrive = gdal.GetDriverByName('MEM')
    maskDS = rasterDrive.Create('', cols, rows, 1, gdal.GDT_Byte)
    maskDS.SetGeoTransform(gt)
    maskDS.SetProjection(proj)
    gdal.RasterizeLayer(maskDS, [1], vectorLayer, burn_values=[1])

    return maskDS.GetRasterBand(1).ReadAsArray()


def calculateHypsometricCurve(demArray,noData,gt,proj,cols,rows,basin,absoluteValues,distanceContour,areaBelow,useOnlyDEMElev,feedback):
    """Compute the hypsometric curve of a single drainage basin.

    The basin polygon is rasterized onto the DEM grid. Every valid DEM cell
    inside it is counted per distinct elevation, and the counts are turned
    into areas with the pixel size from ``gt``. Areas are accumulated from
    the highest elevation downward (area above each elevation). When
    ``areaBelow`` is True they are accumulated from the lowest elevation
    upward instead.

    Args:
        demArray: DEM band values (2-D array, as returned by ``loadDEM``).
        noData: DEM value marking cells to ignore.
        gt: GDAL geotransform of the DEM; the pixel area is
            ``abs(gt[1]) * abs(gt[5])``.
        proj: projection string of the DEM dataset.
        cols, rows: DEM raster size in pixels.
        basin: QGIS feature holding the basin polygon. Presumably it is in
            the DEM's CRS (TODO confirm); no reprojection is done here.
        absoluteValues: if True, return absolute elevations and cumulative
            areas; otherwise both are normalized to [0, 1].
        distanceContour: optional elevation step. When given, areas are
            linearly interpolated at elevations from the minimum, in steps
            of this size, plus the maximum.
        areaBelow: if True, accumulate the area below each elevation
            instead of above it.
        useOnlyDEMElev: if True, ignore ``distanceContour`` and use the
            distinct DEM elevations directly.
        feedback: processing feedback object; used for ``pushWarning``.

    Returns:
        tuple: ``(heights, areas)``. Each is a list whose first item is a
        column title and whose remaining items are the curve values.
        Returns ``(None, None)`` (after pushing a warning) when the basin
        holds no valid DEM cells, or when relative values are requested for
        a basin with a single elevation.

    Raises:
        QgsProcessingException: if ``distanceContour`` is used and is not
            greater than zero.
    """
    basinId = str(basin.id())

    mask = _rasterizeBasinMask(basin, gt, proj, cols, rows)
    validMask = (mask == 1) & (demArray != noData)
    # NaN never compares equal to anything, so NaN cells (including a NaN
    # nodata value) slip through the test above; drop them explicitly.
    if np.issubdtype(demArray.dtype, np.floating):
        validMask &= ~np.isnan(demArray)
    validDataInsideBasin = demArray[validMask]

    if validDataInsideBasin.size == 0:
        feedback.pushWarning('There is no valid raster data in the basin of id '+basinId+' and therefore it is not possible to calculate the elevation - area - volume.')
        return None, None

    # Distinct elevations (ascending) and the number of pixels at each one.
    uniqueElevations, pixelCounts = np.unique(validDataInsideBasin, return_counts=True)
    ascending = areaBelow is True
    if not ascending:
        # Accumulate from the top down so each value is the area above that elevation.
        uniqueElevations = uniqueElevations[::-1]
        pixelCounts = pixelCounts[::-1]

    elevations = uniqueElevations.tolist()
    pixelArea = abs(gt[1]) * abs(gt[5])
    cumulativeAreas = np.cumsum(pixelCounts * pixelArea)

    minElevation = min(elevations)
    maxElevation = max(elevations)

    if useOnlyDEMElev is True:
        distanceContour = None

    if distanceContour is not None:
        if distanceContour <= 0:
            raise QgsProcessingException('The distance between contour lines must be greater than zero.')
        elevationCurves = np.arange(minElevation, maxElevation, distanceContour)
        if maxElevation not in elevationCurves:
            elevationCurves = np.append(elevationCurves, maxElevation)

        # np.interp requires increasing sample points, so always interpolate
        # on the ascending ordering, then restore the requested ordering.
        if ascending:
            interpAreas = np.interp(elevationCurves, elevations, cumulativeAreas)
        else:
            interpAreas = np.interp(elevationCurves, elevations[::-1], cumulativeAreas[::-1])
            elevationCurves = elevationCurves[::-1]
            interpAreas = interpAreas[::-1]

        elevations = elevationCurves.tolist()
        cumulativeAreas = interpAreas

    if absoluteValues is True:
        # The titles say m / m2; the actual unit is that of the DEM geotransform.
        elevationsList = list(elevations)
        areasList = cumulativeAreas.tolist()
        elevationsList.insert(0,'Absolute elevation (m) basin id '+basinId)
        areasList.insert(0,'Absolute area (m2) basin id '+basinId)
        return elevationsList, areasList

    elevationRange = maxElevation - minElevation
    if elevationRange == 0:
        # A flat basin has no relative height scale (0/0 would give NaN).
        feedback.pushWarning('The basin of id '+basinId+' has a single elevation value and therefore its relative hypsometric curve cannot be calculated.')
        return None, None

    relativeHeights = (np.asarray(elevations, dtype=float) - minElevation) / elevationRange

    # With two or more distinct elevations the cumulative areas always span a non-zero range.
    minArea = cumulativeAreas.min()
    maxArea = cumulativeAreas.max()
    relativeAreas = (cumulativeAreas - minArea) / (maxArea - minArea)

    relativeHeightsList = relativeHeights.tolist()
    relativeAreasList = relativeAreas.tolist()
    relativeHeightsList.insert(0,'Relative height (h/H) basin id '+basinId)
    relativeAreasList.insert(0,'Relative area (a/A) basin id '+basinId)

    return relativeHeightsList, relativeAreasList

def calculateHI(elevations,areas,basin):
    """Integrate a hypsometric curve with the trapezoidal rule.

    Args:
        elevations: list whose first item is a column title, followed by the
            curve heights (the integrand).
        areas: list whose first item is a column title, followed by the
            matching areas (the integration variable).
        basin: feature whose ``id()`` labels the result.

    Returns:
        list: ``['Hypsometric integral basin id <id>', integral]``, where the
        integral is a plain Python float.
    """
    elevationsWOTitle = elevations[1:]
    areasWOTitle = areas[1:]

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # use the new name when it exists and fall back on older NumPy.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    hypsometricIntegral = float(trapezoid(elevationsWOTitle, areasWOTitle))

    return ['Hypsometric integral basin id '+str(basin.id()), hypsometricIntegral]

def exportHypsometricCurves(listsWithData,path):
    """Write the given lists to a CSV file, one list per column.

    Args:
        listsWithData: sequence of lists; each list becomes a column, and
            its first item therefore ends up in the header row.
        path: destination CSV path (overwritten).

    Columns of different lengths are padded with empty cells
    (``zip_longest`` pads with None, which csv writes as '').
    """
    # Explicit UTF-8 instead of the platform-dependent default encoding.
    with open(path, 'w', newline='', encoding='utf-8') as archive:
        writer = csv.writer(archive)
        writer.writerows(itertools.zip_longest(*listsWithData))

def runHypsometricCurves(drainageBasinLayer,demLayer,pathCsv,pathHtml,absoluteValues,distanceContour,areaBelow,useOnlyDEMElev,feedback):
    """Compute, plot and export the hypsometric curves of every basin in a layer.

    For each basin feature, the curve and its hypsometric integral are
    computed. They are added as one Plotly line trace and as three CSV
    columns (heights, areas, integral). Once all basins are processed, the
    figure is shown, written to ``pathHtml``, and the CSV is written to
    ``pathCsv``. This happens for both absolute and relative values.
    Basins for which no curve can be computed are skipped. The function
    returns early, without writing the remaining outputs, whenever
    ``feedback.isCanceled()`` is true.
    """
    demArray,noData,gt,proj,rows,cols = loadDEM(demLayer)

    feedback.setProgress(0)
    total = drainageBasinLayer.featureCount()
    step = 100.0 / total if total else 0

    fig = go.Figure()

    listsWithData = []

    for idx, basin in enumerate(drainageBasinLayer.getFeatures()):
        if feedback.isCanceled():
            return
        basinId = str(basin.id())
        feedback.setProgressText('Basin id '+basinId+' processing starting...')

        heights, cumulativeAreas = calculateHypsometricCurve(demArray,noData,gt,proj,cols,rows,basin,absoluteValues,distanceContour,areaBelow,useOnlyDEMElev,feedback)
        if heights is None:
            # No curve for this basin (calculateHypsometricCurve warns before
            # returning None); skip it instead of crashing in calculateHI.
            feedback.setProgress(int((idx + 1) * step))
            continue
        hypsometricIntegral = calculateHI(heights,cumulativeAreas,basin)

        if feedback.isCanceled():
            return
        listsWithData.append(heights)
        listsWithData.append(cumulativeAreas)
        listsWithData.append(hypsometricIntegral)

        feedback.setProgressText('Basin id '+basinId+' processing completed')

        if feedback.isCanceled():
            return

        feedback.setProgressText('Basin id '+basinId+' graph starting...')
        # Item 0 of each list is its CSV column title; plot only the numeric values.
        fig.add_trace(go.Scatter(
                                x=cumulativeAreas[1:],
                                y=heights[1:],
                                mode='lines',
                                name='basin id '+basinId+' Integral = '+str(round(hypsometricIntegral[1],2))))

        feedback.setProgress(int((idx + 1) * step))
        feedback.setProgressText('Basin id '+basinId+' processing graph completed')

    if absoluteValues is True:
        xaxisTitle = 'Absolute area (m2)'
        yaxisTitle = 'Absolute elevation (m)'
    else:
        xaxisTitle = 'Relative area (a/A)'
        yaxisTitle = 'Relative height (h/H)'

    fig.update_layout(
        title='Graph comparing the hypsometric curves of the drainage basins',
        xaxis_title=xaxisTitle,
        yaxis_title=yaxisTitle
    )

    if feedback.isCanceled():
        return
    fig.show()
    fig.write_html(pathHtml)

    if feedback.isCanceled():
        return
    # Previously skipped for absolute values; the CSV is now written in both modes.
    exportHypsometricCurves(listsWithData,pathCsv)
