from qgis.PyQt.QtCore import (QT_TRANSLATE_NOOP, QVariant, QCoreApplication)
from qgis.core import (
  QgsProcessing,
  QgsProcessingAlgorithm,
  QgsProcessingParameterFeatureSource,
  QgsProcessingParameterFeatureSink,
  QgsProcessingParameterBoolean,
  QgsProcessingParameterString,
  QgsFeatureRequest,
  QgsProcessingParameterField
  )
from qgis import processing
from ..algutil.hriskutil import HrUtil
from ..algutil.hriskpostprocessor import HrPostProcessor
from ..algutil.hriskvar import PostProcessors

class summarizerisk(QgsProcessingAlgorithm):
  """Processing algorithm that summarizes building-level health risks.

  The population field and the risk fields listed in ``AGGREGATES`` of the
  building layer are summed, together with the number of buildings and their
  total area, either over all buildings or per class of a grouping field
  delimited by user-given break-points (via ``native:aggregate``).
  """

  # parameter definitions; presumably read by HrUtil.initParameters()
  # (called from initAlgorithm) -- see algutil/hriskutil.py
  PARAMETERS = {
    "BUILDING": {
      "ui_func": QgsProcessingParameterFeatureSource,
      "ui_args":{
        "description": QT_TRANSLATE_NOOP("summarizerisk","Building layer, with population and level"),
        "types": [QgsProcessing.TypeVectorPolygon]
      }
    },
    "GROUPING": {
      "ui_func": QgsProcessingParameterBoolean,
      "ui_args":{
        "description": QT_TRANSLATE_NOOP("summarizerisk","Do the results need to be grouped?"),
        "defaultValue": True
      }
    },
    "GROUPING_FIELD": {
      "ui_func": QgsProcessingParameterField,
      "ui_args":{
        "description": QT_TRANSLATE_NOOP("summarizerisk","Field to group the data"),
        "parentLayerParameterName": "BUILDING",
        "defaultValue": "LDEN_LAEQ_max"
      }
    },
    "GROUP_BREAKPOINTS": {
      "ui_func": QgsProcessingParameterString,
      "ui_args":{
        "description": QT_TRANSLATE_NOOP("summarizerisk","Break-points of the grouping"),
        "defaultValue": "40, 45, 50, 55, 60, 65, 70"
      }
    },
    "POP_FIELD": {
      "ui_func": QgsProcessingParameterField,
      "ui_args":{
        "description": QT_TRANSLATE_NOOP("summarizerisk","Population field"),
        "parentLayerParameterName": "BUILDING",
        "defaultValue": "pop"
      }
    },
    "OUTPUT": {
      "ui_func": QgsProcessingParameterFeatureSink,
      "ui_args": {
        "description": QT_TRANSLATE_NOOP("summarizerisk","Building group with health risks")
      }
    }
  }

  # Base aggregation specs passed (after completion by cmptAggExpr) to
  # native:aggregate. "input" is a QGIS expression (a quoted field name),
  # "type" is the QVariant type id of the output field (6 = Double,
  # 2 = Int, 10 = String). Specs whose input field is missing from the layer
  # are skipped. nPop, nBldg, areaBldg and the grouping specs are added per
  # run in processAlgorithm; this list itself is never modified.
  AGGREGATES = [
    {"name": "nHA", "input": '"nHA"', "aggregate": "sum", "type": 6},
    {"name": "nHSD", "input": '"nHSD"', "aggregate": "sum", "type": 6},
    {"name": "nPatientIHD", "input": '"nPatientIHD"', "aggregate": "sum",   "type": 6},
    {"name": "nIncIHD", "input": '"nIncIHD"', "aggregate": "sum", "type": 6},
    {"name": "nDeathIHD", "input": '"nDeathIHD"', "aggregate": "sum", "type": 6}
  ]

  def __init__(self):
    super().__init__()
    self.UTIL = HrUtil(self)

  def initAlgorithm(self, config = None):
    """Register the algorithm parameters (delegated to HrUtil)."""
    self.UTIL.initParameters()

  # convert the data type number to string
  def getDtypeStr(self, dtype_num, long = False):
    """Map a QVariant type id to the type name used in this algorithm.

    Args:
      dtype_num: QVariant type id (``QVariant.Double`` or a plain int such as 6).
      long: if True, return the long name used as ``type_name`` for
        native:aggregate ("double precision", "integer").

    Returns:
      "double"/"double precision", "int"/"integer", "text", or "not num"
      for any other type.
    """
    if dtype_num == QVariant.Double:
      return "double precision" if long else "double"
    # 64-bit and unsigned integer fields (e.g. Integer64 fields) are treated
    # like Int so that they are summed instead of being skipped as non-numeric
    if dtype_num in (QVariant.Int, QVariant.UInt, QVariant.LongLong, QVariant.ULongLong):
      return "integer" if long else "int"
    if dtype_num == QVariant.String:
      return "text"
    return "not num"

  # compute the expression for aggregation
  def cmptAggExpr(self, feature, expr_idx, expr_label, feedback, agg_specs = None):
    """Build the AGGREGATES parameter list for ``native:aggregate``.

    Args:
      feature: vector layer whose fields are checked (only ``fields()`` is
        used; despite the name, callers pass a layer).
      expr_idx: expression returning the group index, or "NULL" for no grouping.
      expr_label: expression returning the group label, or "NULL".
      feedback: processing feedback used for info / warning messages.
      agg_specs: aggregation specs (dicts with "name", "input", "aggregate",
        "type"); defaults to ``self.AGGREGATES``. Never modified.

    Returns:
      list of dicts: the applicable specs, completed with the residual keys
      expected by native:aggregate. Specs whose input field is missing, or
      that sum a non-numeric field, are left out.
    """
    if agg_specs is None:
      agg_specs = self.AGGREGATES

    aggregates = []
    flds = feature.fields()
    fld_names = flds.names()

    for agg_spec in agg_specs:
      name = agg_spec["input"].strip('"')
      # check if the input field exists
      if name not in fld_names:
        continue

      # data types of input and output, both as short type names
      input_dtype = self.getDtypeStr(flds.field(flds.indexFromName(name)).type())
      output_dtype = self.getDtypeStr(agg_spec["type"])

      # work on a copy so that the caller's specification is never altered
      agg_dict = dict(agg_spec)

      # special methods: group_idx / group_label
      if agg_dict["aggregate"] in ("group_idx", "group_label"):
        group_expr = expr_idx if agg_dict["aggregate"] == "group_idx" else expr_label
        if group_expr == "NULL":
          feedback.pushInfo("No grouping is applied")
          continue
        feedback.pushInfo(f"Grouping using {agg_dict['input']} is applied")
        # every feature of a group shares the same label (it is the GROUP_BY
        # expression) and hence the same index, so the value computed for the
        # first feature represents the whole group
        agg_dict["input"] = group_expr
        agg_dict["aggregate"] = "first_value"
      elif agg_dict["aggregate"] == "sum":
        if input_dtype == "double":
          if output_dtype == "int":
            feedback.pushWarning(f"{agg_dict['name']} int = sum ({agg_dict['input']}) double")
          else:
            feedback.pushInfo(f"{agg_dict['name']} = sum ({agg_dict['input']})")
        elif input_dtype == "int":
          if output_dtype == "double":
            feedback.pushWarning(f"{agg_dict['name']} double = sum ({agg_dict['input']}) int")
          else:
            feedback.pushInfo(f"{agg_dict['name']} = sum ({agg_dict['input']})")
        else:
          # "text" or "not num": summing is meaningless
          feedback.pushWarning(f"{agg_dict['input']} is not a numeric field and skipped calculating {agg_dict['name']}")
          continue

      # set the residual parameters for the aggregation and add to the list
      aggregates.append(
        agg_dict | {
          "delimiter": ".",
          "length": 0,
          "precision": 0,
          "sub_type": 0,
          "type_name": self.getDtypeStr(agg_dict["type"], long = True)
        }
      )

    return aggregates

  # compute the expression for grouping
  def cmptGroupExpr(self, var_str, grouping_str):
    """Build CASE expressions classifying ``var_str`` by break-points.

    Args:
      var_str: name of the field to classify (quoted inside the expression).
      grouping_str: comma-separated break-points, e.g. "40, 45, 50";
        assumed to be in ascending order (not checked).

    Returns:
      tuple (exp_idx, exp_label): expressions giving the class index
      (-1 for NULL, 0..len(break-points)) and its text label
      ('NULL', '< b0', 'b0 - b1', ..., '>= bn').

    Raises:
      Exception: if a break-point cannot be parsed as a number.
    """
    breakpoints = []
    for token in grouping_str.split(","):
      try:
        breakpoints.append(float(token))
      except ValueError as err:
        raise Exception(self.tr("Invalid value for grouping")) from err

    exp_label = f"CASE WHEN \"{var_str}\" is NULL THEN 'NULL'"
    exp_idx = f"CASE WHEN \"{var_str}\" is NULL THEN -1"
    for i, grp in enumerate(breakpoints):
      exp_idx   += f" WHEN \"{var_str}\" < {grp} THEN {i}"
      if i == 0:
        exp_label += f" WHEN \"{var_str}\" < {grp} THEN '< {grp}'"
      else:
        exp_label += f" WHEN \"{var_str}\" < {grp} THEN '{breakpoints[i - 1]} - {grp}'"

    # str.split always yields at least one token, so breakpoints is non-empty
    exp_idx   += f" ELSE {len(breakpoints)} END"
    exp_label += f" ELSE '>= {breakpoints[-1]}' END"

    return exp_idx, exp_label

  def processAlgorithm(self, parameters, context, feedback):
    """Aggregate building-level values into summary rows.

    Steps: materialize BUILDING, keep only the fields needed for aggregation,
    add an auto-incremented "id" field and an "areaBldg" field
    (``area($geometry)``), then run ``native:aggregate`` over all buildings
    (GROUPING unchecked) or per class of GROUPING_FIELD defined by
    GROUP_BREAKPOINTS.

    Returns:
      dict: {"OUTPUT": id returned by HrUtil.outputVectorLayer}.

    Raises:
      Exception: if GROUP_BREAKPOINTS contains a non-numeric value.
    """
    self.UTIL.registerProcessingParameters(parameters, context, feedback)
    self.CURRENT_PROCESS = self.UTIL.parseCurrentProcess()

    # materialize the building source into a layer for the child algorithms
    bldg_risk = self.parameterAsSource(parameters, "BUILDING", context).materialize(QgsFeatureRequest(), feedback)

    # aggregation specs for this run only: the class-level AGGREGATES list is
    # left untouched, so re-running the same instance does not accumulate
    # duplicated specs
    pop_field = self.parameterAsString(parameters, "POP_FIELD", context)
    agg_specs = [{
      "name": "nPop",
      "input": f'"{pop_field}"',
      "aggregate": "sum",
      "type": 6
    }] + [dict(agg) for agg in self.AGGREGATES]

    # check if the results need to be grouped
    grouping = self.parameterAsBool(parameters, "GROUPING", context)
    if grouping:
      grouping_field = self.parameterAsString(parameters, "GROUPING_FIELD", context)
      # build (and validate) the grouping expressions before running any
      # child algorithm, so invalid break-points fail fast
      expr_idx, expr_label = self.cmptGroupExpr(
        grouping_field, self.parameterAsString(parameters, "GROUP_BREAKPOINTS", context)
      )
      # NOTE(review): these output field names contain literal double quotes
      # (e.g. '"LDEN_LAEQ_max_idx"'); kept as-is because downstream consumers
      # may depend on them -- confirm whether the quotes are intended.
      agg_specs = [
        {
          "name": f'"{grouping_field}_idx"',
          "input": f'"{grouping_field}"',
          "aggregate": "group_idx",
          "type": 2
        },{
          "name": f'"{grouping_field}"',
          "input": f'"{grouping_field}"',
          "aggregate": "group_label",
          "type": 10
        }
      ] + agg_specs
    else:
      # "NULL" as GROUP_BY makes native:aggregate put all features in one group
      expr_idx = expr_label = "NULL"

    bldg_rfld = processing.run(
      "native:retainfields",
      {
        "INPUT": bldg_risk,
        "FIELDS": sorted({spec["input"].strip('"') for spec in agg_specs}),
        "OUTPUT": "TEMPORARY_OUTPUT"
      },
      context = context,
      is_child_algorithm = True
    )["OUTPUT"]

    bldg_with_id = processing.run(
      "native:addautoincrementalfield",
      {
        "INPUT": bldg_rfld,
        "FIELD_NAME": "id",
        "OUTPUT": "TEMPORARY_OUTPUT"
      },
      context = context,
      is_child_algorithm = True
    )["OUTPUT"]

    bldg_with_area = processing.run(
      "native:fieldcalculator",
      {
        "INPUT": bldg_with_id,
        "FIELD_NAME": "areaBldg",
        "FIELD_TYPE": 0,
        "FORMULA": "area($geometry)",
        "OUTPUT": "TEMPORARY_OUTPUT"
      },
      context = context,
      is_child_algorithm = True
    )["OUTPUT"]

    bldg_with_area = context.getMapLayer(bldg_with_area)

    # these fields exist only on bldg_with_area, which is therefore the layer
    # checked by cmptAggExpr in both the grouped and ungrouped cases
    agg_specs.append({"name":"nBldg", "input": '"id"', "aggregate": "count", "type": 2})
    agg_specs.append({"name":"areaBldg", "input": '"areaBldg"', "aggregate": "sum", "type": 6})

    aggregates = self.cmptAggExpr(
      bldg_with_area, expr_idx, expr_label, feedback = feedback, agg_specs = agg_specs
    )

    # aggregate the data
    bldg_grouped = processing.run(
      "native:aggregate",
      {
        "INPUT": bldg_with_area,
        "GROUP_BY": expr_label,
        "AGGREGATES": aggregates,
        "OUTPUT": "TEMPORARY_OUTPUT"
      },
      context=context,
      feedback=feedback,
      is_child_algorithm=True
    )["OUTPUT"]

    bldg_grouped = context.getMapLayer(bldg_grouped)

    fields_with_values = {
      "HISTORY": {
        "type": QVariant.String,
        "value": self.CURRENT_PROCESS,
        "append": True
        }
    }

    dest_id = self.UTIL.outputVectorLayer(
      vector_layer= bldg_grouped,
      param_sink = "OUTPUT",
      fields_with_values= fields_with_values
    )

    PostProcessors[dest_id] = HrPostProcessor(history = [self.CURRENT_PROCESS])
    self.UTIL.registerPostProcessAlgorithm(context, PostProcessors)

    return {"OUTPUT": dest_id}


  def name(self):
    """Algorithm id: the class name."""
    return self.__class__.__name__

  def displayName(self):
    """Translated name shown in the Processing toolbox."""
    return self.tr("Summarize Risks")

  def group(self):
    """Translated name of the toolbox group."""
    return self.tr("Evaluate health risk")

  def groupId(self):
    """Id of the toolbox group."""
    return "healthrisk"

  def createInstance(self):
    """Return a fresh instance, as required by the Processing framework."""
    return summarizerisk()

  # placing here is necessary, when employing pylupdate
  def tr(self, string):
    """Translate ``string`` using the class name as context."""
    return QCoreApplication.translate(self.__class__.__name__, string)
