import itertools
import math
from typing import Callable, Union, Tuple

import numpy as np
from qgis.core import (
  QgsCoordinateReferenceSystem,
  QgsReferencedRectangle,
  QgsReferencedPointXY,
  QgsPointXY,
  QgsCoordinateTransform,
  QgsProject,
  QgsRectangle,
  QgsCoordinateTransformContext
  )


class HrTile:
  """A regular rectangular grid defined by an origin and a cell size.

  The cell index along x is floor((x - X_ORIG) / X_UNIT), and likewise along
  y. Units may be negative, in which case indices grow toward decreasing
  coordinates. Every public lookup returns a list; for scalar inputs the list
  holds exactly one element.

  Cells can also be addressed by a string code. By default the code is
  "<xidx>_<yidx>"; custom (de)coding functions may be supplied instead.
  """

  X_ORIG = None
  Y_ORIG = None
  X_UNIT = None
  Y_UNIT = None
  CODING_FUNC = None  # (x, y) -> cell code
  REV_CODING_FUNC = None  # cell code -> (xc, yc), the cell center

  def __init__(
      self,
      xunit: float,
      yunit: float,
      xorig: float,
      yorig: float,
      coding_func: Callable[[float, float], str] = None,
      rev_coding_func: Callable[[str], Tuple[float, float]] = None
  ):
    """Create a grid.

    Parameters:
    xunit (float): Cell width along x (may be negative).
    yunit (float): Cell height along y (may be negative).
    xorig (float): x coordinate where cell index 0 starts.
    yorig (float): y coordinate where cell index 0 starts.
    coding_func (callable): Optional. Maps a coordinate (x, y) to a cell code.
      Defaults to "<xidx>_<yidx>".
    rev_coding_func (callable): Optional. Maps a cell code to the cell center
      (xc, yc). Defaults to the inverse of the default coding.
    """
    assert xunit is not None, "The xunit must be given."
    assert yunit is not None, "The yunit must be given."

    self.X_ORIG = xorig
    self.Y_ORIG = yorig
    self.X_UNIT = xunit
    self.Y_UNIT = yunit

    def coding(long: float, lat: float) -> str:
      xidx, yidx = self.cellXyIdx(long, lat)[0]
      return f"{xidx}_{yidx}"

    def rev_coding(code: str) -> Tuple[float, float]:
      xy_idx = code.split("_")
      # cellXyCenter returns a one-element list; return its single (x, y).
      return self.cellXyCenter(int(xy_idx[0]), int(xy_idx[1]))[0]

    self.CODING_FUNC = coding if coding_func is None else coding_func
    self.REV_CODING_FUNC = rev_coding if rev_coding_func is None else rev_coding_func

  #-------------------------------

  def unitLength(self) -> Tuple[float, float]:
    """Return the cell size as (X_UNIT, Y_UNIT)."""
    return (self.X_UNIT, self.Y_UNIT)

  def origin(self) -> Tuple[float, float]:
    """Return the grid origin as (X_ORIG, Y_ORIG)."""
    return (self.X_ORIG, self.Y_ORIG)

  #-------------------------------

  def cellXyIdx(
    self,
    code_or_x: Union[str, float],
    y: float = None
    ) -> list[Tuple[int, int]]:
    """
    Returns the cell indices (x, y) for a cell code or a coordinate.

    Parameters:
    code_or_x (str, float): A cell code, or the x coordinate.
    y (float): Optional. The y coordinate (required when code_or_x is a float).

    Returns:
    list[Tuple[int, int]]: A one-element list with the cell index.

    Raises:
    TypeError: If the value is not of a supported type.
    """
    if isinstance(code_or_x, str):
      return self._cellXyIdx_str(code_or_x)
    elif isinstance(code_or_x, float):
      if isinstance(y, float):
        return self._cellXyIdx_float(code_or_x, y)
      else:
        raise TypeError("The float y value must be given.")
    else:
      raise TypeError(f"Unsupported type: {type(code_or_x).__name__}. It must be either str or float.")

  def _cellXyIdx_float(self, x: float, y: float) -> list[Tuple[int, int]]:
    # floor, not int(): int() truncates toward zero, which would put
    # coordinates on the negative side of the origin into the wrong cell.
    xidx = math.floor((x - self.X_ORIG) / self.X_UNIT)
    yidx = math.floor((y - self.Y_ORIG) / self.Y_UNIT)
    return [(xidx, yidx)]

  def _cellXyIdx_str(self, code: str) -> list[Tuple[int, int]]:
    # Decode to the cell center, then locate the cell containing it.
    x, y = self.REV_CODING_FUNC(code)
    return self.cellXyIdx(x, y)

  #-------------------------------

  def cellXyCenter(
    self,
    code_or_idx_or_x: Union[str, int, float],
    idx_or_y: Union[int, float] = None,
    ) -> list[Tuple[float, float]]:
    """
    Returns the center coordinates (x, y) of a cell.

    Parameters:
    code_or_idx_or_x (str, int or float): A cell code, the x cell index (int),
      or an x coordinate (float).
    idx_or_y (int or float): Optional. The y cell index or y coordinate; its
      type must match code_or_idx_or_x.

    Returns:
    list[Tuple[float, float]]: A one-element list with the cell center.

    Raises:
    TypeError: If the value is not of a supported type.
    """
    if isinstance(code_or_idx_or_x, str):
      return self._cellXyCenter_str(code_or_idx_or_x)
    elif isinstance(code_or_idx_or_x, int):
      if isinstance(idx_or_y, int):
        return self._cellXyCenter_int(code_or_idx_or_x, idx_or_y)
      else:
        raise TypeError("Integer y index must be given.")
    elif isinstance(code_or_idx_or_x, float):
      if isinstance(idx_or_y, float):
        return self._cellXyCenter_float(code_or_idx_or_x, idx_or_y)
      else:
        raise TypeError("Float y value must be given.")
    else:
      raise TypeError(
        f"Unsupported type: {type(code_or_idx_or_x).__name__}." +
        "The first argument must be either str, int or float."
        )

  def _cellXyCenter_float(self, x: float, y: float) -> list[Tuple[float, float]]:
    # Snap the coordinate to its cell, then take that cell's center.
    xidx, yidx = self._cellXyIdx_float(x, y)[0]
    return self._cellXyCenter_int(xidx, yidx)

  def _cellXyCenter_int(self, xidx: int, yidx: int) -> list[Tuple[float, float]]:
    return [(
      self.X_ORIG + self.X_UNIT * xidx + self.X_UNIT / 2,
      self.Y_ORIG + self.Y_UNIT * yidx + self.Y_UNIT / 2
      )]

  def _cellXyCenter_str(self, code: str) -> list[Tuple[float, float]]:
    return [self.REV_CODING_FUNC(code)]

  #-------------------------------

  def cellMeshCode(
    self,
    idx_or_x: Union[int, float],
    idx_or_y: Union[int, float]
    ) -> list[str]:
    """
    Returns the mesh code of a cell given by its index or by a coordinate.

    Parameters:
    idx_or_x (int or float): The x cell index (int) or an x coordinate (float).
    idx_or_y (int or float): The y cell index or y coordinate; its type must
      match idx_or_x.

    Returns:
    list[str]: A one-element list with the mesh code.

    Raises:
    TypeError: If the value is not of a supported type.
    """
    if isinstance(idx_or_x, int):
      if isinstance(idx_or_y, int):
        return self._cellMeshCode_int(idx_or_x, idx_or_y)
      else:
        raise TypeError("Integer y index must be given.")
    elif isinstance(idx_or_x, float):
      if isinstance(idx_or_y, float):
        return self._cellMeshCode_float(idx_or_x, idx_or_y)
      else:
        raise TypeError("Float y value must be given.")
    else:
      raise TypeError("Unsupported type: value must be either int or float.")

  def _cellMeshCode_int(self, xidx: int, yidx: int) -> list[str]:
    # CODING_FUNC takes coordinates, so encode the cell center.
    xyc = self.cellXyCenter(xidx, yidx)[0]
    return [self.CODING_FUNC(xyc[0], xyc[1])]

  def _cellMeshCode_float(self, x: float, y: float) -> list[str]:
    return [self.CODING_FUNC(x, y)]

  #-------------------------------

  def cellRect(
    self,
    code_or_idx_or_x: Union[str, int, float],
    idx_or_y: Union[int, float] = None
    ) -> list[dict]:
    """
    Returns the bounding box of a cell.

    Parameters:
    code_or_idx_or_x (str, int or float): A cell code, the x cell index (int),
      or an x coordinate (float).
    idx_or_y (int or float): Optional. The y cell index or y coordinate.

    Returns:
    list[dict]: A one-element list holding a dict with keys "xmin", "ymin",
      "xmax" and "ymax". The min/max values are ordered even for negative units.

    Raises:
    TypeError: If the value is not of a supported type.
    """
    xyc = self.cellXyCenter(code_or_idx_or_x, idx_or_y)[0]
    return [{
      "xmin": xyc[0] - 0.5 * abs(self.X_UNIT),
      "ymin": xyc[1] - 0.5 * abs(self.Y_UNIT),
      "xmax": xyc[0] + 0.5 * abs(self.X_UNIT),
      "ymax": xyc[1] + 0.5 * abs(self.Y_UNIT)
    }]
                

class HrQgsTile(HrTile):
  """HrTile whose coordinates are expressed in a QGIS coordinate reference system.

  Adds QgsReferencedRectangle inputs (all cells covered by an extent),
  QgsReferencedRectangle / extent-string outputs, and point reprojection.
  """

  CRS = None  # Coordinate Reference System for the tile
  QGS_PROJECT = None  # Project read from qgs_project_file; None -> QgsProject.instance()

  def __init__(
      self,
      crs: QgsCoordinateReferenceSystem,
      xunit: float,
      yunit: float,
      xorig: float = None,
      yorig: float = None,
      coding_func: Callable[[float, float], str] = None,
      rev_coding_func: Callable[[str], Tuple[float, float]] = None,
      qgs_project_file: str = None
  ):
    """Create a tile grid in ``crs``.

    Parameters:
    crs (QgsCoordinateReferenceSystem): CRS of the grid coordinates.
    xunit, yunit (float): Cell size along x / y (may be negative).
    xorig, yorig (float): Optional. Grid origin. For a geographic CRS a missing
      value defaults to -180.0 (x) / 90.0 (y); otherwise no default is applied.
    coding_func, rev_coding_func (callable): Optional. See HrTile.
    qgs_project_file (str): Optional. Path of a QGIS project whose coordinate
      transform settings transformPoint uses. When omitted, transformPoint uses
      QgsProject.instance().

    Raises:
    OSError: If qgs_project_file cannot be read.
    """
    assert xunit is not None, "The xunit must be given."
    assert yunit is not None, "The yunit must be given."

    self.CRS = crs

    if crs.isGeographic():
      # Each missing origin coordinate gets its own default.
      if xorig is None:
        xorig = -180.0
      if yorig is None:
        yorig = 90.0

    self.QGS_PROJECT = None
    if qgs_project_file is not None:
      project = QgsProject()
      # QgsProject.read() signals failure through its boolean return value.
      if not project.read(qgs_project_file):
        raise OSError(f"Failed to read the QGIS project file: {qgs_project_file}")
      self.QGS_PROJECT = project

    super().__init__(
      xunit = xunit,
      yunit = yunit,
      xorig = xorig,
      yorig = yorig,
      coding_func = coding_func,
      rev_coding_func = rev_coding_func
      )

  #-------------------------------

  def cellXyIdx(
    self,
    code_or_x_or_rect: Union[str, float, QgsReferencedRectangle],
    y: float = None,
    only_edge: bool = False
    ) -> list[Tuple[int, int]]:
    """
    Returns the cell indices (x, y) for a cell code, a coordinate or an extent.

    Parameters:
    code_or_x_or_rect (str, float, QgsReferencedRectangle): The value to convert to cell indices.
    y (float): Optional. The y coordinate (when code_or_x_or_rect is a float).
    only_edge (bool): Optional. For an extent: if True, return only the cells
      holding the extent's corners (duplicates occur when the extent spans a
      single row or column); otherwise return every covered cell.

    Returns:
    list[Tuple[int, int]]: The cell indices.

    Raises:
    TypeError: If the value is not of a supported type.
    """
    if not isinstance(code_or_x_or_rect, QgsReferencedRectangle):
      return super().cellXyIdx(code_or_x_or_rect, y)

    extent = code_or_x_or_rect
    assert extent.crs() == self.CRS, "The CRS of the extent must be the same as the tile CRS."

    # Reuse the base-class index formula for the minimum and maximum corners.
    xidx_lwr, yidx_lwr = self._cellXyIdx_float(extent.xMinimum(), extent.yMinimum())[0]
    xidx_upr, yidx_upr = self._cellXyIdx_float(extent.xMaximum(), extent.yMaximum())[0]
    xidx_edge = (xidx_lwr, xidx_upr)
    yidx_edge = (yidx_lwr, yidx_upr)

    if only_edge:
      return list(itertools.product(xidx_edge, yidx_edge))
    # min/max because a negative unit reverses the index order.
    xidxs = range(min(xidx_edge), max(xidx_edge) + 1)
    yidxs = range(min(yidx_edge), max(yidx_edge) + 1)
    return list(itertools.product(xidxs, yidxs))

  #-------------------------------

  def cellXyCenter(
    self,
    code_idx_x_or_rect: Union[str, int, float, QgsReferencedRectangle],
    idx_or_y: Union[int, float] = None,
    only_edge: bool = False,
    as_QgsPoint: bool = False,
    destination_crs: QgsCoordinateReferenceSystem = None,
    ) -> list[Union[Tuple[float, float], QgsReferencedPointXY]]:
    """
    Returns the center coordinates (x, y) for a cell code, index, coordinate or extent.

    Parameters:
    code_idx_x_or_rect (str, int, float or QgsReferencedRectangle): The cell code, cell index (x), coordinate (x) or extent.
    idx_or_y (int or float): Optional. The y index or y coordinate.
    only_edge (bool): Optional. For an extent: if True, return only the centers
      of the cells holding the extent's corners.
    as_QgsPoint (bool): Optional. If True, return QgsReferencedPointXY objects.
    destination_crs (QgsCoordinateReferenceSystem): Optional. Reproject the
      centers from the tile CRS to this CRS.

    Returns:
    list: (x, y) tuples, or QgsReferencedPointXY objects when as_QgsPoint is True.

    Raises:
    TypeError: If the value is not of a supported type.
    """
    if not isinstance(code_idx_x_or_rect, QgsReferencedRectangle):
      xys = super().cellXyCenter(code_idx_x_or_rect, idx_or_y)
      return self._formatPoints(xys, as_QgsPoint, destination_crs)

    extent = code_idx_x_or_rect
    assert extent.crs() == self.CRS, "The CRS of the extent must be the same as the tile CRS."

    xc_lwr, yc_lwr = super().cellXyCenter(extent.xMinimum(), extent.yMinimum())[0]
    xc_upr, yc_upr = super().cellXyCenter(extent.xMaximum(), extent.yMaximum())[0]

    if only_edge:
      xys = list(itertools.product((xc_lwr, xc_upr), (yc_lwr, yc_upr)))
    else:
      # The half-unit margin on the stop value keeps the last center despite
      # floating-point error.
      xys = list(itertools.product(
        np.arange(xc_lwr, xc_upr + 0.5 * abs(self.X_UNIT), abs(self.X_UNIT)),
        np.arange(yc_lwr, yc_upr + 0.5 * abs(self.Y_UNIT), abs(self.Y_UNIT))
      ))

    return self._formatPoints(xys, as_QgsPoint, destination_crs)

  def _formatPoints(
    self,
    xys: list,
    as_QgsPoint: bool,
    destination_crs: QgsCoordinateReferenceSystem
    ) -> list:
    """Optionally reproject (x, y) pairs and/or wrap them as QgsReferencedPointXY.

    Returns xys unchanged when neither option is requested.
    """
    if not as_QgsPoint and destination_crs is None:
      return xys
    points = []
    for x, y in xys:
      pnt = QgsReferencedPointXY(QgsPointXY(float(x), float(y)), self.CRS)
      if destination_crs is not None:
        pnt = self.transformPoint(pnt, destination_crs=destination_crs)
      points.append(pnt if as_QgsPoint else (pnt.x(), pnt.y()))
    return points

  #-------------------------------

  def cellMeshCode(
    self,
    idx_or_x_or_rect: Union[int, float, QgsReferencedRectangle],
    idx_or_y: Union[int, float] = None
    ) -> list[str]:
    """
    Returns the mesh codes for the given cell indices, coordinates, or extent.

    Parameters:
    idx_or_x_or_rect (int, float or QgsReferencedRectangle): The x cell index, x coordinate or extent.
    idx_or_y (int or float): Optional. The y index or y coordinate.

    Returns:
    list[str]: One code for a cell or coordinate; one code per covered cell for an extent.

    Raises:
    TypeError: If the value is not of a supported type.
    """
    if isinstance(idx_or_x_or_rect, QgsReferencedRectangle):
      xys = self.cellXyCenter(idx_or_x_or_rect)
      return [self.cellMeshCode(xy[0], xy[1])[0] for xy in xys]
    return super().cellMeshCode(idx_or_x_or_rect, idx_or_y)

  #-------------------------------

  def cellRect(
    self,
    code_or_idx_or_x_or_rect: Union[str, int, float, QgsReferencedRectangle],
    idx_or_y: Union[int, float] = None,
    dissolve: bool = False,
    as_extent: bool = False
    ) -> list[Union[QgsReferencedRectangle, str]]:
    """
    Returns the rectangles of the cells addressed by a code, index, coordinate or extent.

    Parameters:
    code_or_idx_or_x_or_rect (str, float, int or QgsReferencedRectangle): The cell code, cell index (x), coordinate (x) or extent.
    idx_or_y (int or float): Optional. The y index or y coordinate.
    dissolve (bool): Optional. For an extent: if True, return a single rectangle
      bounding all covered cells instead of one rectangle per cell.
    as_extent (bool): Optional. If True, return strings of the form
      "xmin,xmax,ymin,ymax [authid]" instead of QgsReferencedRectangle objects.

    Returns:
    list[QgsReferencedRectangle] or list[str]

    Raises:
    TypeError: If the value is not of a supported type.
    """
    if isinstance(code_or_idx_or_x_or_rect, QgsReferencedRectangle):
      xycs = self.cellXyCenter(code_or_idx_or_x_or_rect)
      if dissolve:
        xcs = [xyc[0] for xyc in xycs]
        ycs = [xyc[1] for xyc in xycs]
        rect_dicts = [{
          "xmin": min(xcs) - 0.5 * abs(self.X_UNIT),
          "ymin": min(ycs) - 0.5 * abs(self.Y_UNIT),
          "xmax": max(xcs) + 0.5 * abs(self.X_UNIT),
          "ymax": max(ycs) + 0.5 * abs(self.Y_UNIT)
        }]
      else:
        rect_dicts = [{
          "xmin": xyc[0] - 0.5 * abs(self.X_UNIT),
          "ymin": xyc[1] - 0.5 * abs(self.Y_UNIT),
          "xmax": xyc[0] + 0.5 * abs(self.X_UNIT),
          "ymax": xyc[1] + 0.5 * abs(self.Y_UNIT)
        } for xyc in xycs]
    else:
      rect_dicts = super().cellRect(code_or_idx_or_x_or_rect, idx_or_y)

    if as_extent:
      return [
        f'{rect_dict["xmin"]},'+
        f'{rect_dict["xmax"]},'+
        f'{rect_dict["ymin"]},'+
        f'{rect_dict["ymax"]} '+
        f'[{self.CRS.authid()}]'
        for rect_dict in rect_dicts
      ]

    return [
      QgsReferencedRectangle(
        QgsRectangle(
          rect_dict["xmin"],
          rect_dict["ymin"],
          rect_dict["xmax"],
          rect_dict["ymax"]
        ),
        self.CRS
      ) for rect_dict in rect_dicts
    ]

  def transformPoint(
    self,
    pnt: QgsReferencedPointXY,
    source_crs: QgsCoordinateReferenceSystem = None,
    destination_crs: QgsCoordinateReferenceSystem = None
    ) -> QgsReferencedPointXY:
    """Reproject a point.

    Parameters:
    pnt (QgsReferencedPointXY): The point to transform.
    source_crs (QgsCoordinateReferenceSystem): Optional. CRS of pnt; defaults to pnt.crs().
    destination_crs (QgsCoordinateReferenceSystem): Optional. Target CRS; defaults to the tile CRS.

    Returns:
    QgsReferencedPointXY: The point in destination_crs. pnt itself is returned
      when the source and destination CRS are equal.
    """
    source_crs = pnt.crs() if source_crs is None else source_crs
    destination_crs = self.CRS if destination_crs is None else destination_crs
    if source_crs == destination_crs:
      return pnt
    project = self.QGS_PROJECT if self.QGS_PROJECT is not None else QgsProject.instance()
    # Transform directly. The project only supplies the transform context;
    # its own CRS does not need to be either endpoint.
    transform = QgsCoordinateTransform(source_crs, destination_crs, project)
    return QgsReferencedPointXY(transform.transform(pnt), destination_crs)
    

class WebMercatorTile(HrQgsTile):
  """Square tile grid in Web Mercator (EPSG:3857) for a given zoom level.

  At zoom z the square world extent is split into 2**z x 2**z cells. The
  origin is the top-left corner of that extent and the y unit is negative, so
  y indices grow as y decreases.
  """

  # Equals 2 * pi * 6378137: the full width (and height) of the square
  # EPSG:3857 world extent.
  EQUATOR_LENGTH = 40075016.68557849

  def __init__(self, zoom: int, qgs_project_file: str = None):
    """
    Parameters:
    zoom (int): Zoom level, from 0 to 20 inclusive.
    qgs_project_file (str): Optional. Forwarded unchanged to HrQgsTile.
    """
    assert zoom is not None, "The zoom level must be given."
    assert zoom in range(0, 21), "The zoom level must be in the range of 0 to 20."

    web_mercator = QgsCoordinateReferenceSystem("EPSG:3857")
    cell_size = self.EQUATOR_LENGTH / 2 ** zoom
    half_world = self.EQUATOR_LENGTH / 2

    super().__init__(
      crs=web_mercator,
      xunit=cell_size,
      yunit=-cell_size,
      xorig=-half_world,
      yorig=half_world,
      qgs_project_file=qgs_project_file,
    )

