from qgis.PyQt.QtCore import (QT_TRANSLATE_NOOP, QCoreApplication)
from qgis.core import (
  QgsCoordinateReferenceSystem,
  QgsProcessingAlgorithm,
  QgsProcessingParameterExtent,
  QgsProcessingParameterDistance,
  QgsProcessingParameterCrs, 
  QgsProcessingParameterFeatureSink,
  QgsProcessingParameterRasterDestination,
  QgsProcessingParameterNumber,
  QgsProcessingParameterBoolean,
  QgsRectangle,
  QgsReferencedRectangle,
  QgsProcessingParameterEnum
  )
from qgis import processing

from osgeo import gdal, osr
import numpy as np
import os

from ..algutil.hriskutil import HrUtil
from ..algutil.hrisktile import WebMercatorTile
from ..algutil.hriskpostprocessor import HrPostProcessor
from ..algutil.hriskvar import PostProcessors

class fetchdemrasterja(QgsProcessingAlgorithm):
  """Fetch DEM PNG tiles from cyberjapandata.gsi.go.jp and build elevation outputs.

  The tiles covering the (buffered) fetch extent are downloaded from the
  XYZ endpoint selected by GRID_SIZE, decoded from their RGB encoding into
  elevations, and mosaicked into a single-band Float32 GeoTIFF in EPSG:3857
  (OUTPUT_RASTER). When VECTORIZE is set, the raster is additionally clipped,
  converted to points, reprojected to TARGET_CRS and passed through
  ``hrisk:initelevationpoint`` to produce the OUTPUT vector layer.
  """

  PARAMETERS = {
    "FETCH_EXTENT": {
      "ui_func": QgsProcessingParameterExtent,
      "ui_args":{
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Extent for fetching data")
      }
    },
    "GRID_SIZE": {
      "ui_func": QgsProcessingParameterEnum,
      "ui_args":{
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Grid size to use"),
        "options": ["10m (DEM-10B)", "5m (DEM-5A)", "5m (DEM-5B)", "5m (DEM-5C)"],
        "defaultValue": 0
      }
    },
    "TARGET_CRS": {
      "ui_func": QgsProcessingParameterCrs,
      "ui_args": {
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Target CRS (Cartesian coordinates)")
      }
    },
    "BUFFER": {
      "ui_func": QgsProcessingParameterDistance,
      "ui_args": {
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Buffer of the fetch area (using Target CRS)"),
        "defaultValue": 0.0,
        "parentParameterName": "TARGET_CRS"
      }
    },
    "VECTORIZE": {
      "ui_func": QgsProcessingParameterBoolean,
      "ui_args": {
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Output elevation points as vector layer"),
        "defaultValue": False
      }
    },
    "MAX_DOWNLOAD": {
      "ui_func": QgsProcessingParameterNumber,
      "ui_args": {
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Maximum number of download"),
        "type": QgsProcessingParameterNumber.Integer,
        "defaultValue": 100
      }
    },
    "OUTPUT": {
      "ui_func": QgsProcessingParameterFeatureSink,
      "ui_args": {
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Elevation points")
      }
    },
    "OUTPUT_RASTER": {
      "ui_func": QgsProcessingParameterRasterDestination,
      "ui_args": {
        "description": QT_TRANSLATE_NOOP("fetchdemrasterja","Elevation raster" )
      }
    }
  }

  # XYZ URL templates; the index matches the GRID_SIZE option order
  FETCH_BASE_URLS = [
    "https://cyberjapandata.gsi.go.jp/xyz/dem_png/{z}/{x}/{y}.png",
    "https://cyberjapandata.gsi.go.jp/xyz/dem5a_png/{z}/{x}/{y}.png",
    "https://cyberjapandata.gsi.go.jp/xyz/dem5b_png/{z}/{x}/{y}.png",
    "https://cyberjapandata.gsi.go.jp/xyz/dem5c_png/{z}/{x}/{y}.png"
  ]
  FETCH_BASE_URL = None  # set per run from FETCH_BASE_URLS
  FETCH_CRS = QgsCoordinateReferenceSystem("EPSG:3857")
  FETCH_AREA_CRS = QgsCoordinateReferenceSystem("EPSG:3857")
  # zoom level for each entry of FETCH_BASE_URLS (same index)
  FETCH_ZOOMS = [14,15,15,15]
  FETCH_ZOOM = None  # set per run from FETCH_ZOOMS
  FETCH_TILE_PIXELS = 256  # width/height of one tile image in pixels
  FETCH_GEOM_TYPE = "raster"
  FETCH_TILE_FAMILY = "web_mercator"


  def __init__(self):
    super().__init__()
    self.UTIL = HrUtil(self)


  def initAlgorithm(self, config = None):
    """Declare the parameters, defaulting extent/CRS to the current map canvas."""
    (extent, target_crs) = self.UTIL.getExtentAndCrsUsingCanvas()
    self.UTIL.setDefaultValue("FETCH_EXTENT", extent)
    self.UTIL.setDefaultValue("TARGET_CRS", target_crs.authid())
    self.UTIL.initParameters()

  def processAlgorithm(self, parameters, context, feedback):
    """Download the DEM tiles covering the buffered extent and build the outputs.

    Returns:
      dict with "OUTPUT" (destination id of the point sink, or None when
      VECTORIZE is off) and "OUTPUT_RASTER" (path of the merged GeoTIFF,
      normalized to "/" separators).

    Raises:
      Exception: when no tile covers the fetch area, or when more tiles are
        required than MAX_DOWNLOAD allows.
    """
    self.UTIL.registerProcessingParameters(parameters, context, feedback)
    self.CURRENT_PROCESS = self.UTIL.parseCurrentProcess()

    # select the tile source (URL template and zoom) chosen in GRID_SIZE
    src_idx = self.parameterAsEnum(parameters, "GRID_SIZE", context)
    self.FETCH_BASE_URL = self.FETCH_BASE_URLS[src_idx]
    self.FETCH_ZOOM = self.FETCH_ZOOMS[src_idx]

    # get target x-y CRS, to apply the buffer and determine the fetch area
    target_crs = self.parameterAsCrs(parameters, "TARGET_CRS", context)

    # check whether the target CRS is x-y coordinates
    self.UTIL.checkCrsAsCartesian(target_crs)

    # get the extent in the target CRS and enlarge it by the buffer
    fetch_extent = self.parameterAsExtent(
      parameters, "FETCH_EXTENT", context, target_crs
    )
    buffer = self.parameterAsDouble(parameters, "BUFFER", context)
    fetch_area = QgsReferencedRectangle(
      QgsRectangle(
        fetch_extent.xMinimum() - buffer,
        fetch_extent.yMinimum() - buffer,
        fetch_extent.xMaximum() + buffer,
        fetch_extent.yMaximum() + buffer
      ),
      target_crs
    )

    tile = WebMercatorTile(zoom = self.FETCH_ZOOM)
    # materialize so that the emptiness check below works for any sequence type
    tiles_list = list(tile.cellXyIdx(fetch_area))

    # without this guard an empty list would yield a bogus raster at tile (0, 0)
    if not tiles_list:
      msg = self.tr("No tiles cover the fetch area.")
      feedback.reportError(msg)
      raise Exception(msg)

    n_tiles = len(tiles_list)
    fetch_args = {
      f"{i+1}/{n_tiles}": {
        "url": self.FETCH_BASE_URL.format(z = self.FETCH_ZOOM, x = tx, y = ty),
        "zoom": self.FETCH_ZOOM,
        "tx": tx,
        "ty": ty,
        "geom_type": self.FETCH_GEOM_TYPE
      }
      for i, (tx, ty) in enumerate(tiles_list)
    }

    if len(fetch_args) > self.parameterAsInt(parameters, "MAX_DOWNLOAD", context):
      feedback.reportError(self.tr("Too many downloads are required: ") + str(len(fetch_args)))
      raise Exception(self.tr("Too many downloads are required: ") + str(len(fetch_args)))

    fetch_results = self.UTIL.downloadFilesConcurrently(args = fetch_args)

    feedback.pushInfo(self.tr("Raster is being merged..."))
    ras_path = self.parameterAsOutputLayer(parameters, "OUTPUT_RASTER", context)
    self._mergeTilesToRaster(ras_path, tile, fetch_args, fetch_results, feedback)

    if self.parameterAsBoolean(parameters, "VECTORIZE", context):
      dest_id = self._vectorizeRaster(
        parameters, context, feedback, ras_path, fetch_area, target_crs
      )
    else:
      dest_id = None

    PostProcessors[ras_path] = HrPostProcessor(
      history=[self.CURRENT_PROCESS],
      color_args = {"coloring": "single_band_pseudo_color", "theme": "Greens", "opacity": 0.8},
      set_min_to_zero = True
    )

    self.UTIL.registerPostProcessAlgorithm(context, PostProcessors)

    return {"OUTPUT": dest_id, "OUTPUT_RASTER": os.path.normpath(ras_path).replace(os.path.sep, "/")}

  @staticmethod
  def _decodeDemPng(rgb, unit = 0.01):
    """Decode a DEM PNG pixel array into elevation values.

    Each pixel stores a 24-bit code x = 65536*R + 256*G + B in its first three
    bands. The value is x*unit for x < 2**23, (x - 2**24)*unit for x > 2**23,
    and NaN (no data) for x == 2**23.

    Args:
      rgb: array indexed as (band, row, col) with at least 3 bands, as
        returned by gdal's ReadAsArray for a multi-band image.
      unit: scale applied to the code (0.01 by default).

    Returns:
      float32 array of shape (row, col).
    """
    # integer arithmetic keeps the 24-bit code exact before scaling
    rgb = np.asarray(rgb, dtype = np.int64)
    code = 65536 * rgb[0] + 256 * rgb[1] + rgb[2]
    alt = np.where(code < 2**23, code, code - 2**24) * unit
    alt[code == 2**23] = np.nan
    return alt.astype(np.float32)

  def _mergeTilesToRaster(self, ras_path, tile, fetch_args, fetch_results, feedback):
    """Mosaic the downloaded tiles into a single-band Float32 GeoTIFF (EPSG:3857).

    The band is pre-filled with NaN, which is also declared as its NoData value.
    Tiles that failed to download or decode therefore stay NaN instead of GDAL's
    default 0, which would be indistinguishable from a real 0 elevation.
    """
    npx = self.FETCH_TILE_PIXELS
    tx_all = [args["tx"] for args in fetch_args.values()]
    ty_all = [args["ty"] for args in fetch_args.values()]
    txmin, txmax = min(tx_all), max(tx_all)
    tymin, tymax = min(ty_all), max(ty_all)

    ras_raw = gdal.GetDriverByName("GTiff").Create(
      ras_path,
      int(npx * (txmax - txmin + 1)),
      int(npx * (tymax - tymin + 1)),
      1, gdal.GDT_Float32
    )

    tile_length  = tile.unitLength()[0]
    pixel_length = tile_length / npx
    x_orig, y_orig = tile.origin()

    srs = osr.SpatialReference()
    srs.ImportFromEPSG(3857)
    ras_raw.SetProjection(srs.ExportToWkt())
    # north-up geotransform whose origin is the raster's top-left corner
    # (assumes tile.origin() is the top-left of the tile grid — TODO confirm)
    ras_raw.SetGeoTransform([
      x_orig + txmin * tile_length, pixel_length, 0,
      y_orig - tymin * tile_length, 0, -pixel_length
    ])

    band = ras_raw.GetRasterBand(1)
    band.SetNoDataValue(float("nan"))
    band.Fill(float("nan"))

    for key, file in fetch_results.items():
      # per the original handling, a failed download is delivered as an Exception object
      if isinstance(file, Exception):
        feedback.pushInfo(f"({key}) {str(file)}")
        continue
      try:
        ras = gdal.Open(file)
        alt_array = self._decodeDemPng(ras.ReadAsArray())
        band.WriteArray(
          alt_array,
          xoff = int((fetch_args[key]["tx"] - txmin) * npx),
          yoff = int((fetch_args[key]["ty"] - tymin) * npx)
        )
      except Exception:
        msg = f"({key}) " + self.tr("Importing from the downloaded file was failed.")
        feedback.pushInfo(msg)
      finally:
        ras = None # closing procedure

    # drop the band before the dataset so the file is flushed and closed now
    band.FlushCache()
    band = None
    ras_raw.FlushCache()
    ras_raw = None # closing procedure

  def _vectorizeRaster(self, parameters, context, feedback, ras_path, fetch_area, target_crs):
    """Convert the merged raster into elevation points written to the OUTPUT sink.

    Returns:
      The destination id of the OUTPUT sink.
    """
    feedback.pushInfo(self.tr("Raster is being vectorized..."))

    # clip the raster because it is too large as a point vector;
    # fall back to the whole raster if clipping fails
    try:
      dem_raster_clipped = processing.run(
        "gdal:cliprasterbyextent",
        {
          "INPUT": ras_path,
          "PROJWIN": fetch_area,
          "OUTPUT": "TEMPORARY_OUTPUT"
        },
        context = context,
        is_child_algorithm = True
      )["OUTPUT"]
    except Exception as e:
      feedback.pushInfo(f"gdal:cliprasterbyextent: {str(e)}")
      dem_raster_clipped = ras_path

    dem_raw = processing.run(
      "native:pixelstopoints",
      {
        "INPUT_RASTER": dem_raster_clipped,
        "RASTER_BAND": 1,
        "FIELD_NAME": "alti",
        "OUTPUT": "TEMPORARY_OUTPUT"
      },
      context = context,
      is_child_algorithm = True
    )["OUTPUT"]

    # CRS transform
    dem_transformed = processing.run(
      "native:reprojectlayer",
      {
        "INPUT": dem_raw,
        "TARGET_CRS": target_crs,
        "OUTPUT": "TEMPORARY_OUTPUT"
      },
      context = context,
      is_child_algorithm = True
    )["OUTPUT"]

    # substitute self constant with the fetched vector layer
    dem_final = processing.run(
      "hrisk:initelevationpoint",{
        "INPUT": dem_transformed,
        "FIELD_USE_AS_HEIGHT": "alti",
        "TARGET_CRS": target_crs,
        "OVERWRITE": True,
        "OUTPUT": "TEMPORARY_OUTPUT"
      },
      context = context,
      is_child_algorithm = True
    )["OUTPUT"]

    dem_final_fts = context.getMapLayer(dem_final)
    new_fields = self.UTIL.newFieldsWithHistory(dem_final_fts.fields())
    (sink, dest_id) = self.parameterAsSink(
      parameters, "OUTPUT", context,
      new_fields, dem_final_fts.wkbType(), target_crs
    )
    self.UTIL.addFeaturesWithHistoryToSink(
      sink, dem_final_fts, new_fields,
      current_process=self.CURRENT_PROCESS
    )

    PostProcessors[dest_id] = HrPostProcessor(
      history=[self.CURRENT_PROCESS]
    )
    return dest_id

  def name(self):
    return self.__class__.__name__

  def displayName(self):
    return self.tr("Elevation points (Ja raster)")

  def group(self):
    return self.tr('Fetch geometries (Ja)')

  def groupId(self):
    return 'fetchgeomja'

  def createInstance(self):
    return fetchdemrasterja()

  # placing here is necessary, when employing pylupdate
  def tr(self, string):
    return QCoreApplication.translate(self.__class__.__name__, string)