from qgis.PyQt.QtCore import (QT_TRANSLATE_NOOP)
from qgis.core import (
  QgsCoordinateReferenceSystem,
  QgsProcessingParameterExtent,
  QgsProcessingParameterDistance,
  QgsProcessingParameterCrs, 
  QgsProcessingParameterFeatureSink,
  QgsProcessingParameterRasterDestination,
  QgsProperty,
  QgsProcessingParameterString,
  QgsProcessingParameterNumber,
  QgsProcessingParameterBoolean,
  QgsCoordinateTransform,
  QgsProject
  )
from qgis import processing

from .fetchabstract import fetchabstract
from osgeo import gdal
import numpy as np

class fetchdemrasterja(fetchabstract):
  """Processing algorithm that fetches DEM raster tiles and builds elevation outputs.

  By default the tiles come from the GSI "dem_png" service
  (https://cyberjapandata.gsi.go.jp/xyz/dem_png/{z}/{x}/{y}.png). Each tile is
  decoded to elevations and written into a single-band Float32 raster. When
  VECTORISE is set, that raster is clipped to the fetch area, converted to
  points, transformed to the target CRS, and given Z values from the "alti"
  field.
  """

  # Parameter definitions, consumed by initParameters() of fetchabstract.
  # NOTE(review): several descriptions say "vector-tile map" although this
  # algorithm fetches raster (PNG) tiles; they are left unchanged because they
  # are translation source strings (QT_TRANSLATE_NOOP keys).
  PARAMETERS = {
    "FETCH_EXTENT": {
      "ui_func": QgsProcessingParameterExtent,
      "ui_args":{
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Extent for fetching data")
      }
    },
    "TARGET_CRS": {
      "ui_func": QgsProcessingParameterCrs,
      "ui_args": {
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Target CRS (Cartesian coordinates)")
      }
    },
    "BUFFER": {
      "ui_func": QgsProcessingParameterDistance,
      "ui_args": {
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Buffer of the fetch area (using Target CRS)"),
        "defaultValue": 0.0,
        "parentParameterName": "TARGET_CRS"
      }
    },
    "VECTORISE": {
      "ui_func": QgsProcessingParameterBoolean,
      "ui_args": {
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Output elevation points as vector layer"),
        "defaultValue": True
      }
    },
    "TILEMAP_URL": {
      "ui_func": QgsProcessingParameterString,
      "ui_args": {
        "optional": True,
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Base-URL of the vector-tile map"),
        "defaultValue": "https://cyberjapandata.gsi.go.jp/xyz/dem_png/{z}/{x}/{y}.png"
      }
    },
    "TILEMAP_CRS": {
      "ui_func": QgsProcessingParameterCrs,
      "ui_args": {
        "optional": True,
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","CRS of the vector-tile map"),
        "defaultValue": "EPSG:3857" # must be specified as string, because optional parameter cannot be set as QgsCoordinateReferenceSystem
      }
    },
    "TILEMAP_ZOOM": {
      "ui_func": QgsProcessingParameterNumber,
      "ui_args": {
        "optional": True,
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Zoom level of the vector-tile map"),
        "type": QgsProcessingParameterNumber.Integer,
        "defaultValue": 14
      }
    },
    "TILEMAP_PIXELS": {
      "ui_func": QgsProcessingParameterNumber,
      "ui_args": {
        "optional": True,
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Number of pixels in a row/column of the tile map"),
        "type": QgsProcessingParameterNumber.Integer,
        "defaultValue": 256
      }
    },
    "OUTPUT": {
      "ui_func": QgsProcessingParameterFeatureSink,
      "ui_args": {
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Elevation point (DEM)")
      }
    },
    "OUTPUT_RASTER": {
      "ui_func": QgsProcessingParameterRasterDestination,
      "ui_args": {
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Elevation raster (DEM)" )
      }
    }
  }

  def initAlgorithm(self, config=None):
    """Declare the algorithm's parameters (canvas-based defaults, then PARAMETERS).

    Args:
      config: algorithm configuration passed by QGIS; not used here.
    """
    self.initUsingCanvas()
    self.initParameters()

  def processAlgorithm(self, parameters, context, feedback):
    """Fetch the DEM tiles and produce the raster and (optionally) point outputs.

    Steps:
      1. Download the tiles covering the fetch area into the raster
         self.TILEMAP_ARGS["OUTPUT_RASTER"] (driven by fetchabstract; each tile
         is decoded by writeOutputRaster).
      2. Clip that raster to the fetch area. This is best-effort: on any error
         the unclipped raster is used and the error is reported to feedback.
      3. If VECTORISE is true, convert the clipped pixels to points with an
         "alti" field, transform them to the target CRS, set their Z value
         from "alti", and write them to the OUTPUT sink.

    Returns:
      dict with "OUTPUT" (sink destination id, or None when VECTORISE is
      false) and "OUTPUT_RASTER" (the full, unclipped DEM raster).
    """
    self.FETCH_TYPE = "raster"
    self.setFetchArea(parameters, context, feedback, QgsCoordinateReferenceSystem("EPSG:6668"))
    self.setTileMapArgs(
      parameters, context, feedback,
      additional_args={"N_RASTER_BANDS": 1, "RASTER_DTYPE": gdal.GDT_Float32}
    )

    self.fetchFeaturesFromTile(parameters, context, feedback)

    raster_full = self.TILEMAP_ARGS["OUTPUT_RASTER"]

    # Clip the raster to the fetch area: the whole tile mosaic is too large to
    # vectorise. Clipping is optional, so any failure (including the CRS
    # transform) falls back to the unclipped raster instead of aborting.
    try:
      transform = QgsCoordinateTransform(
        self.FETCH_AREA.crs(), self.TILEMAP_ARGS["CRS"], context.transformContext()
      )
      clip_area = transform.transformBoundingBox(self.FETCH_AREA)
      dem_raster_clipped = processing.run(
        "gdal:cliprasterbyextent",
        {
          "INPUT": raster_full,
          "PROJWIN": clip_area,
          "OUTPUT": "TEMPORARY_OUTPUT"
        }
      )["OUTPUT"]
    except Exception as err:
      feedback.reportError(
        self.tr("Clipping the DEM raster failed; the unclipped raster is used instead: {}").format(err)
      )
      dem_raster_clipped = raster_full

    dest_id = None
    if self.parameterAsBoolean(parameters, "VECTORISE", context):
      # one point per pixel, elevation stored in the "alti" field
      dem_raw = processing.run(
        "native:pixelstopoints",
        {
          "INPUT_RASTER": dem_raster_clipped,
          "RASTER_BAND": 1,
          "FIELD_NAME": "alti",
          "OUTPUT": "TEMPORARY_OUTPUT"
        }
      )["OUTPUT"]

      dem_transformed = self.transformToTargetCrs(parameters, context, feedback, dem_raw)

      # promote the points to 3D using the "alti" attribute as Z
      dem_z = processing.run(
        "native:setzvalue",
        {
          "INPUT": dem_transformed,
          "Z_VALUE": QgsProperty.fromExpression('"alti"'),
          "OUTPUT": "memory:dem"
        }
      )["OUTPUT"]

      (sink, dest_id) = self.parameterAsSink(
        parameters, "OUTPUT", context,
        dem_z.fields(), dem_z.wkbType(), dem_z.sourceCrs()
      )
      sink.addFeatures(dem_z.getFeatures())

    return {"OUTPUT": dest_id, "OUTPUT_RASTER": raster_full}

  @staticmethod
  def _decodeDemPng(rgb) -> np.ndarray:
    """Decode a GSI "dem_png" RGB tile array into elevations in metres.

    Encoding: x = 2**16 * R + 2**8 * G + B, with unit u = 0.01 m:
      x <  2**23 -> h = x * u
      x == 2**23 -> no data (NaN)
      x >  2**23 -> h = (x - 2**24) * u

    Args:
      rgb: array indexable by band, with rgb[0], rgb[1], rgb[2] being the
        R, G, B bands (e.g. the result of a GDAL Dataset.ReadAsArray()).

    Returns:
      float32 array of elevations with the shape of one band, NaN for no data.
    """
    bands = np.asarray(rgb, dtype=np.int64)
    # integer arithmetic keeps the 24-bit code exact before scaling
    code = (bands[0] << 16) + (bands[1] << 8) + bands[2]
    alt = np.where(code < 2 ** 23, code, code - 2 ** 24) * 0.01
    # only the exact code 2**23 is no data; other R == 128 codes are valid
    alt[code == 2 ** 23] = np.nan
    return alt.astype(np.float32)

  def writeOutputRaster(self, tx, ty, ras_tile) -> None:
    """Decode one tile and write it into band 1 of self.FETCH_FEATURE.

    Args:
      tx, ty: tile indices; the pixel offset in the output raster is
        (tile index - TILEMAP_ARGS["XMIN"/"YMIN"]) * TILEMAP_PIXELS.
      ras_tile: object providing ReadAsArray() with the tile's RGB bands
        (presumably a GDAL dataset of the downloaded PNG — confirm in
        fetchabstract), or None, in which case the tile area is filled with NaN.
    """
    n_pixels = self.TILEMAP_ARGS["TILEMAP_PIXELS"]
    if ras_tile is not None:
      alt = self._decodeDemPng(ras_tile.ReadAsArray())
    else:
      alt = np.full((n_pixels, n_pixels), np.nan, dtype=np.float32)

    x_origin = (tx - self.TILEMAP_ARGS["XMIN"]) * n_pixels
    y_origin = (ty - self.TILEMAP_ARGS["YMIN"]) * n_pixels

    self.FETCH_FEATURE.GetRasterBand(1).WriteArray(alt, xoff=x_origin, yoff=y_origin)

  def postProcessAlgorithm(self, context, feedback):
    """Return no additional results (overrides any base-class post-processing)."""
    return {}

  def displayName(self):
    """Return the translated name shown in the Processing toolbox."""
    return self.tr("Elevation points (Ja raster)")

  def group(self):
    """Return the translated toolbox group name."""
    return self.tr('Fetch geometries (Ja)')

  def groupId(self):
    """Return the untranslated toolbox group id."""
    return 'fetchgeomja'

  def createInstance(self):
    """Return a new instance of this algorithm (required by QGIS Processing)."""
    return fetchdemrasterja()
