# -*- coding: utf-8 -*-

"""
/***************************************************************************
 Karika
                                 A QGIS plugin
 Terrain generalization
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2018-03-05
        copyright            : (C) 2018 by Roman Geisthövel
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 3 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__    = 'Roman Geisthövel'
__date__      = '2018-03-05'
__copyright__ = '(C) 2018 by Roman Geisthövel'

#=============================================================================
import os
from functools import partial
from tempfile import gettempdir
from uuid import uuid4

from PyQt5.QtCore import QCoreApplication

from qgis.PyQt.QtGui import QIcon
from qgis.core import (QgsProcessingAlgorithm,
                       QgsProcessingParameterRasterLayer,
                       QgsProcessingParameterNumber,
                       QgsProcessingParameterEnum,
                       QgsProcessingException,
                       QgsRasterFileWriter,
                       QgsProcessingParameterRasterDestination)

import numpy as np

import karika.appinter as appi

#=============================================================================
def progress(fraction, feedback):
    """
    Forward a completion fraction to a QGIS feedback object.

    fraction -- completion ratio, reported as int(fraction * 100) percent
    feedback -- object providing isCanceled() and setProgress()

    Returns 1 if processing may continue, 0 if the user canceled; in the
    latter case no progress value is reported.
    """
    canceled = feedback.isCanceled()
    if not canceled:
        percent = int(fraction * 100)
        feedback.setProgress(percent)
    return 0 if canceled else 1


#=============================================================================
def smooth_lic(grid, cellsize, num_steps, progress):
    """
    Smooth *grid* by Line Integral Convolution along its gradient.

    Generalization by Line Integral Convolution of the grid values
    along the gradient of the grid, as described in the PhD thesis
    Geisthoevel (2017) "Automatic Swiss style rock depiction"
    https://doi.org/10.3929/ethz-b-000201368

    Every cell starts a streamline that is traced num_steps steps forward
    along the normalized gradient and num_steps steps backward (step length
    DS = 0.5 cells). The result is the mean of the cell's own value and all
    grid values sampled along both halves of its streamline. Once a
    streamline leaves the grid it stops moving and contributes nothing more
    for the rest of that sweep.

    grid      -- 2D array; it is not modified
    cellsize  -- sample spacing passed to np.gradient
    num_steps -- integration steps per direction (0 returns the input values)
    progress  -- callable taking a completion fraction in (0, 1]; a falsy
                 return value cancels the computation

    Returns the smoothed grid (same shape, floating point dtype of the
    gradient), or None if canceled.
    """
    src = np.asarray(grid)
    H, W = src.shape

    DS = 0.5  # Integration speed (step length in cells)

    # Normalized gradient, i.e. unit normals times DS. Flat cells keep a zero
    # step because their gradient components are already 0.
    vi, vj = np.gradient(src, cellsize)
    norm = np.hypot(vi, vj)
    np.copyto(norm, 1, where=np.equal(norm, 0))
    np.multiply(norm, 1. / DS, out=norm)
    np.divide(vi, norm, out=vi)
    np.divide(vj, norm, out=vj)
    del norm

    # Running sum along the streamlines, seeded with each cell's own value
    res = src.astype(vi.dtype)
    count = np.ones(src.shape, "u4")

    dead = np.empty(src.shape, bool)  # streamline has left the grid
    buf = np.empty(src.shape, bool)   # scratch mask

    total = 2. * num_steps
    for sweep in range(2):
        if sweep:  # Reverse gradient direction for backward integration
            np.negative(vi, out=vi)
            np.negative(vj, out=vj)
        pi, pj = np.mgrid[0:H, 0:W]
        fi = pi.astype(float)
        fj = pj.astype(float)
        dead.fill(False)
        for step in range(num_steps):
            di = vi[pi, pj]
            dj = vj[pi, pj]
            # Streamlines that already left the grid stay where they are
            np.copyto(di, 0, where=dead)
            np.copyto(dj, 0, where=dead)
            fi += di
            fj += dj
            np.trunc(fi, out=pi, casting="unsafe")
            np.trunc(fj, out=pj, casting="unsafe")
            # Flag out of bounds positions; the flag is sticky per sweep
            for idx, size in ((pi, H), (pj, W)):
                np.less(idx, 0, out=buf)
                np.logical_or(dead, buf, out=dead)
                np.greater_equal(idx, size, out=buf)
                np.logical_or(dead, buf, out=dead)
            # Park dead streamlines on a valid index so the lookups are safe;
            # their samples are zeroed and not counted, so no real grid cell
            # has to serve as a sentinel.
            np.copyto(pi, 0, where=dead)
            np.copyto(pj, 0, where=dead)
            vals = src[pi, pj]
            np.copyto(vals, 0, where=dead)
            res += vals
            np.logical_not(dead, out=buf)
            np.add(count, 1, where=buf, out=count)
            if not progress((sweep * num_steps + (step + 1)) / total):
                # Canceled
                return None
    np.divide(res, count, out=res)
    return res


#=============================================================================
class KarikaAlgorithm(QgsProcessingAlgorithm):
    """
    QGIS Processing algorithm "Karika": generalizes a terrain raster by Line
    Integral Convolution along its gradient (see smooth_lic()).
    """

    IN_RASTER  = "IN_RASTER"
    LIC_LENGTH = "LIC_LENGTH"
    Z_FACTOR   = "Z_FACTOR"
    BAND_NO    = "BAND_NO"
    OUT_RASTER = "OUT_RASTER"
    BORDER     = "BORDER"

    #-------------------------------------------------------------------------
    def initAlgorithm(self, config):
        """
        Here we define the inputs and output of the algorithm.

        Also builds self._borderMode, a list of (translated label, key)
        pairs whose index is the value of the BORDER enum parameter.
        """

        self.addParameter(QgsProcessingParameterRasterLayer(
                            self.IN_RASTER,
                            self.tr('Input raster')))

        self.addParameter(QgsProcessingParameterNumber(
                            self.LIC_LENGTH,
                            self.tr("Integration length"),
                            QgsProcessingParameterNumber.Integer,
                            5, False, 1))

        self.addParameter(QgsProcessingParameterNumber(
                            self.BAND_NO,
                            self.tr("Band number"),
                            QgsProcessingParameterNumber.Integer,
                            1, False, 1))

        self.addParameter(QgsProcessingParameterNumber(
                            self.Z_FACTOR,
                            self.tr("Z factor"),
                            QgsProcessingParameterNumber.Double,
                            1, False, 1e-10))

        self.addParameter(QgsProcessingParameterRasterDestination(
                            self.OUT_RASTER,
                            self.tr('Output raster'),
                            os.path.join(gettempdir(),
                                         "karika_{}.tif".format(uuid4()))))

        self._borderMode = [
            (self.tr("Keep border"),                "keep"),
            (self.tr("Crop border"),                "crop"),
            (self.tr("Fill border with NODATA"),    "fill")]

        self.addParameter(QgsProcessingParameterEnum(self.BORDER,
                            self.tr('Border mode'),
                            options=[_[0] for _ in self._borderMode],
                            allowMultiple=False,
                            defaultValue=0))

    #-------------------------------------------------------------------------
    def processAlgorithm(self, args, ctx, feedback):
        """
        Here is where the processing itself takes place.

        Reads the selected band of the input raster, scales it by the
        z factor, smooths it with smooth_lic(), applies the selected border
        mode and writes the result (only GeoTIFF output is supported).

        Returns {OUT_RASTER: output path}; the same dict is returned when
        the run is canceled or the output format is rejected.

        Raises QgsProcessingException if the input raster cannot be loaded.
        """
        Ras = appi.Raster

        log = feedback.setProgressText
        prog = partial(progress, feedback=feedback)

        out_raster = self.parameterAsOutputLayer(args, self.OUT_RASTER, ctx)
        res = {self.OUT_RASTER: out_raster}

        # Check output format
        out_fmt = QgsRasterFileWriter.driverForExtension(
            os.path.splitext(out_raster)[1])
        if not out_fmt or out_fmt.lower() != "gtiff":
            log("CRITICAL: Currently only GeoTIFF output format allowed, exiting!")
            return res

        if not prog(0):
            log("Canceled!")
            return res

        # Process args
        in_raster  = self.parameterAsRasterLayer(args, self.IN_RASTER,  ctx)
        int_length = self.parameterAsInt(args,         self.LIC_LENGTH, ctx)
        band_no    = self.parameterAsInt(args,         self.BAND_NO,    ctx)
        z_factor   = self.parameterAsDouble(args,      self.Z_FACTOR,   ctx)
        border     = self.parameterAsEnum(args,        self.BORDER,     ctx)

        if in_raster is None:
            raise QgsProcessingException(self.tr("Invalid input raster"))

        if not (1 <= band_no <= Ras.num_bands(in_raster)):
            log(("WARNING: invalid band number selected ({}), "
                  "using band 1 instead.").format(band_no))
            band_no = 1

        if z_factor <= 0:
            log(("WARNING: invalid z-factor selected ({}), "
                  "using 1 instead.").format(z_factor))
            z_factor = 1

        log(self.tr("Reading input raster ..."))
        grid = Ras.to_numpy(in_raster, band=band_no)

        if not prog(0):
            log("Canceled!")
            return res

        if z_factor != 1:
            # Scale into a new array: an in-place multiply fails for integer
            # rasters and would modify the array handed out by the reader.
            grid = grid * z_factor

        dy, dx = Ras.cellsize(in_raster)

        if dx != dy:
            cs = (dx + dy) * 0.5
            log(("WARNING: cell dimensions for X and Y differ ({}, {}), "
                  "using mean {} instead.").format(dx, dy, cs))
        else:
            cs = dx

        log(self.tr("Processing ..."))

        smoothed = smooth_lic(grid, cs, int_length, prog)

        if smoothed is None:
            log("Canceled or error")
            return res

        smoothed, save_opts = self._apply_border(
            smoothed, self._borderMode[border][1], int_length, in_raster)

        log(self.tr("Saving output raster ..."))
        Ras.numpy_to_file(smoothed, out_raster, src=str(in_raster.source()),
                          **save_opts)

        size = appi.Common.file_size(out_raster)
        if size >> 20:
            log("Size: %i MB" % (size >> 20))
        else:
            log("Size: %i bytes" % size)

        if not prog(1):
            log("Canceled!")
            return res

        log(self.tr("Done!\n"))

        return res

    #-------------------------------------------------------------------------
    @staticmethod
    def _apply_border(grid, mode, int_length, in_raster):
        """
        Apply border mode *mode* ("keep", "crop" or "fill") to 2D *grid*.

        The frame that is cropped ("crop") or set to NODATA ("fill") is
        int_length//2 + 2 cells wide on every side, so both modes discard
        exactly the same cells. If that frame would cover the whole raster,
        the grid is returned unchanged.

        Returns (grid, save_opts), where save_opts holds the extra keyword
        arguments for appi.Raster.numpy_to_file: a shifted "geo_transform"
        for "crop", the "nodata" value for "fill".
        """
        save_opts = {}
        H, W = grid.shape
        # Margin reached by streamlines, plus one safety cell on each side
        off = int_length // 2 + 1
        edge = off + 1
        if mode == "keep" or H <= 2 * edge or W <= 2 * edge:
            return grid, save_opts

        Ras = appi.Raster
        if mode == "crop":
            grid = grid[edge:H - edge, edge:W - edge]
            # Move the origin by `edge` columns and rows. Assumes a GDAL-style
            # geotransform (x0, dx, rx, y0, ry, dy) -- TODO confirm for appi.
            gt = Ras.geo_transform(in_raster)
            save_opts["geo_transform"] = (
                gt[0] + gt[1] * edge + gt[2] * edge, gt[1], gt[2],
                gt[3] + gt[4] * edge + gt[5] * edge, gt[4], gt[5])
        elif mode == "fill":
            nodata = Ras.nodata_value(in_raster)
            if nodata is None:
                nodata = -32768
            grid[:edge, :] = nodata
            grid[H - edge:, :] = nodata
            grid[:, :edge] = nodata
            grid[:, W - edge:] = nodata
            save_opts["nodata"] = nodata
        return grid, save_opts

    #-------------------------------------------------------------------------
    def icon(self):
        """Return the algorithm icon, loaded from svgIconPath()."""
        return QIcon(self.svgIconPath())

    #-------------------------------------------------------------------------
    def svgIconPath(self):
        """Return the path of the plugin's img_gear.svg icon file."""
        C = appi.Common
        return C.mkpath(C.folder(), "img_gear.svg")

    #-------------------------------------------------------------------------
    def name(self):
        """Return the unique algorithm id."""
        return 'karika'

    #-------------------------------------------------------------------------
    def displayName(self):
        """Return the user-visible algorithm name."""
        return "Karika"

    #-------------------------------------------------------------------------
    def group(self):
        """Return the translated name of the algorithm group."""
        return self.tr("Generalization")

    #-------------------------------------------------------------------------
    def groupId(self):
        """Return the id of the algorithm group."""
        return "generalization"

    #-------------------------------------------------------------------------
    def tr(self, string):
        """Translate *string* in the 'Processing' context."""
        return QCoreApplication.translate('Processing', string)

    #-------------------------------------------------------------------------
    def tags(self):
        """Return translated search tags for the algorithm."""
        return [self.tr(_) for _ in ("Raster", "Generalization", "Processing")]

    #-------------------------------------------------------------------------
    def createInstance(self):
        """Return a new instance of this algorithm."""
        return KarikaAlgorithm()

    #-------------------------------------------------------------------------
    def helpUrl(self):
        """Return a file URL pointing at the plugin's help/index.html."""
        return "file:///%s/help/index.html" % appi.Common.folder()
