import warnings
import numpy as np
import processing  # this is the qgis processing stuff
from datetime import datetime

from qgis.core import QgsProject
from qgis.core import QgsVectorLayer, QgsRasterLayer
from qgis.core import QgsField
from qgis.core import QgsFeature
from qgis.core import NULL
from qgis.core import QgsProcessing
from qgis.core import QgsProcessingAlgorithm
from qgis.core import QgsGeometry
from qgis.core import QgsCoordinateReferenceSystem
from qgis.core import Qgis
from qgis.core import QgsVectorLayerUtils

from .dataBridge import dataBridge


class preprocessorCalculator:
    """
    Turns the user's input layers into a single weighted "goodness" raster and
    extracts candidate points from it.

    Presumed call order (inferred from the method docstrings; confirm against the caller):
        importFromDataBridge -> makeRasters -> calculateCombinedRaster ->
        findMinimumDesirableRasterValue -> extractPointsFromRasterResults ->
        preparePointsLayerForRencatExport
    """

    def __init__(self):
        self._dataBridge = None  # set by importFromDataBridge
        self._combinedExtent = None  # extent used for every rasterization / calculation
        self._pixelSize = None  # grid resolution, used as cell width/height in georeferenced units

        self._layers = None  # raw layers from the input
        self._rasterLayers = []  # raster layers that are to be combined
        self._rasterStats = {}  # raster layer -> result dict of native:rasterlayerstatistics

        self._layerFields = None

        self._timestamp = None  # "%Y-%m-%d-%H%M" string used to name output layers

    def importFromDataBridge(self, db: dataBridge):
        """
        Basic function to get specifications from the user via the dataBridge.

        Stores the layers, combined extent and grid resolution, keeps a reference to
        the dataBridge for per-layer lookups, and stamps the current time for naming
        output layers.
        """

        self._layers = db.getLayers()
        self._combinedExtent = db.findSuperExtent()
        self._pixelSize = db.getGridResolution()

        self._dataBridge = db

        self._timestamp = datetime.now().strftime("%Y-%m-%d-%H%M")

    # ------------ rasterization helpers -----------

    def _wrapRasterOutput(self, source, layer_name, crs):
        """
        Wrap a raster processing output (path/URI string) as a QgsRasterLayer named
        `layer_name` and stamp it with `crs`.

        Setting the CRS is REALLY IMPORTANT because the raster layers have no built-in
        CRS information. Without it, the virtual raster layer downstream fails with a
        useless error message (this was found after about 10 hours of troubleshooting).
        """
        raster = QgsRasterLayer(source, layer_name)
        raster.setCrs(crs)
        return raster

    def _rasterLayerStatistics(self, raster_layer):
        """Run native:rasterlayerstatistics on band 1 and return the result dict (MIN, MAX, ...)."""
        return processing.run(
            "native:rasterlayerstatistics",
            {
                "INPUT": raster_layer,
                "BAND": 1,
                # no OUTPUT given: allow it to do a temporary file
            },
        )

    def _bufferAndRasterize(
        self, layer, radius, pixel_size, use_field, field_name, layer_name
    ):
        """
        Shared non-advanced path for point and line layers: fix geometries, buffer
        every feature by `radius`, then rasterize the buffers with makeRasters_polygon.
        """
        # step 1: clean up the geometry. If the geometry of a line layer is not "fixed" in
        # advance and has overlaps, the buffer algorithm appears to experience some kind of
        # segfault or other error that crashes QGIS. So we do that first just in case.
        fixed_layer = processing.run(
            "native:fixgeometries",
            {"INPUT": layer, "METHOD": 1, "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT},
        )["OUTPUT"]

        # step 2: run buffering
        buffered_layer = processing.run(
            "native:buffer",
            {
                "INPUT": fixed_layer,
                "DISTANCE": radius,
                "SEGMENTS": 5,  # this is the default anyway
                "END_CAP_STYLE": 0,  # rounded
                "JOIN_STYLE": 0,  # rounded
                "DISSOLVE_RESULT": 0,  # overlapping features will not be combined
                "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT,
            },
        )["OUTPUT"]

        # step 3: rasterize the buffer polygons
        return self.makeRasters_polygon(
            buffered_layer,
            pixel_size=pixel_size,
            use_field=use_field,
            field_name=field_name,
            layer_name=layer_name,
        )

    def makeRasters_point(
        self,
        layer: QgsVectorLayer,
        radius: float,
        pixel_size: float,
        use_field: bool = False,
        field_name: str = None,
        layer_name: str = "points",
        advanced=False,
    ):
        """
        Helper function for making a raster layer from a points layer.

        Non-advanced (the only mode the user can currently select): buffer each point by
        `radius` and rasterize the buffers binary (or with `field_name` values).
        Advanced: kernel density estimation heatmap. It cannot currently be enabled by
        the user, but it was implemented first and is retained for posterity.

        Returns:
            QgsRasterLayer named `layer_name`, in the CRS of `layer`.
        """
        if not advanced:  # use plain buffer method
            return self._bufferAndRasterize(
                layer, radius, pixel_size, use_field, field_name, layer_name
            )

        # check whether there's a specific field whose values we want to use
        if not use_field:
            field_name = None
        kde = processing.run(
            "qgis:heatmapkerneldensityestimation",
            {
                "INPUT": layer,
                "RADIUS": radius,
                "WEIGHT_FIELD": field_name,
                "PIXEL_SIZE": pixel_size,
                "OUTPUT_VALUE": 1,  # scale the output values.
                "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT,
            },
        )
        return self._wrapRasterOutput(kde["OUTPUT"], layer_name, layer.crs())

    def makeRasters_line(
        self,
        layer,
        radius: float,
        pixel_size: float,
        use_field: bool = False,
        field_name: str = None,
        layer_name: str = "line",
        advanced=False,
    ):
        """
        Helper function for making a raster layer from a line layer.

        Non-advanced (the only mode the user can currently select): buffer each line by
        `radius` and rasterize the buffers binary (or with `field_name` values).
        Advanced: native line density heatmap. It cannot currently be enabled by the
        user, but it is retained for posterity/future development.

        Returns:
            QgsRasterLayer named `layer_name`, in the CRS of `layer`.
        """
        if not advanced:
            return self._bufferAndRasterize(
                layer, radius, pixel_size, use_field, field_name, layer_name
            )

        # check whether there's a specific field whose values we want to use
        if not use_field:
            field_name = None
        density = processing.run(
            "native:linedensity",
            {
                "INPUT": layer,
                "WEIGHT": field_name,
                "RADIUS": radius,
                "PIXEL_SIZE": pixel_size,
                "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT,
            },
        )
        return self._wrapRasterOutput(density["OUTPUT"], layer_name, layer.crs())

    def makeRasters_polygon(
        self,
        layer,
        pixel_size: float,
        use_field: bool = False,
        field_name: str = None,
        layer_name: str = "polygon",
    ):
        """
        Helper function for rasterizing a polygon layer over the combined extent.

        Cells inside a polygon get 1 (or the polygon's `field_name` value when
        `use_field` is true). Cells outside keep the INIT value of 1e-6.

        Returns:
            QgsRasterLayer (float32) named `layer_name`, in the CRS of `layer`.
        """

        fixed_layer = processing.run(
            "native:fixgeometries",
            {"INPUT": layer, "METHOD": 1, "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT},
        )["OUTPUT"]

        # check whether there's a specific field whose values we want to use
        if not use_field:
            field_name = None

        rasterized = processing.run(
            "gdal:rasterize",
            {
                "INPUT": fixed_layer,
                "FIELD": field_name,  # None if not provided -> BURN value is used
                "BURN": 1,
                "USE_Z": False,
                "UNITS": 1,  # 1 is georeferenced units. 0 is pixel units.
                "WIDTH": pixel_size,
                "HEIGHT": pixel_size,
                "EXTENT": self.getExtent(),
                "NODATA": 0,  # what is the value for not having data
                "OPTIONS": "",
                "DATA_TYPE": 5,  # float32 is 5
                # pre-initialize the output with a specific value; presumably non-zero so
                # that empty cells are not treated as NODATA (0) — confirm intent
                "INIT": 1e-6,
                "INVERT": False,
                "EXTRA": "",
                "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT,
            },
        )
        return self._wrapRasterOutput(rasterized["OUTPUT"], layer_name, layer.crs())

    def _rasterizeVectorLayer(self, layer, db, pixel_size):
        """
        Dispatch one vector layer to the rasterizer matching its geometry type, using
        the per-layer radius/field settings from the dataBridge.

        Raises:
            ValueError: if the geometry type is not Point, Line or Polygon.
        """
        name = layer.name()
        kind = layer.geometryType().name
        if kind not in ("Point", "Line", "Polygon"):
            # this aborts the whole run; the layer is not silently skipped
            raise ValueError(
                f"Layer {name} uses an unexpected geometry type: {kind}."
            )

        use_field = db.getIfUsingFieldByLayerName(name)
        field_name = db.getFieldofLayerByLayerName(name)

        if kind == "Polygon":  # polygon: just rasterize it
            return self.makeRasters_polygon(
                layer,
                pixel_size=pixel_size,
                use_field=use_field,
                field_name=field_name,
                layer_name=name,
            )

        rasterizer = self.makeRasters_point if kind == "Point" else self.makeRasters_line
        return rasterizer(
            layer=layer,
            radius=db.getLayerRadiusByLayerName(name),
            pixel_size=pixel_size,
            use_field=use_field,
            field_name=field_name,
            layer_name=name,
            advanced=False,
        )

    def makeRasters(self):
        """
        Convert the layers into raster layers using techniques appropriate to their geometry type.
        Under normal usage:
            point and line layers will be buffered and then rasterized binary (inside resulting buffer or not).
            polygon layers will be rasterized binary.
        Advanced options (not currently in the gui but implemented):
            Point layers will be converted into heatmaps using KDE
            Lines will be converted into heatmaps using Line Density.
            Polygon layers will be rasterized.

        Raster inputs are used as-is. Every resulting raster is appended to the raster
        layer list and its band-1 statistics are recorded for the final calculation.

        Raises:
            ValueError: if a vector layer has an unsupported geometry type.
        """
        db = self.getDataBridge()
        pixel_size = self.getPixelSize()

        for layer in self.getRawLayers():
            if isinstance(layer, QgsRasterLayer):
                # raster layers don't need anything that we won't do later
                raster_layer = layer
            else:
                raster_layer = self._rasterizeVectorLayer(layer, db, pixel_size)

            self.addRasterLayer(raster_layer)
            # collect the raster stats for when we have to calculate the final result
            self._rasterStats[raster_layer] = self._rasterLayerStatistics(raster_layer)

    # ------------ combination -----------

    def constructNormalizedWeights(self, exclude=None):
        """
        Helper function for constructRasterCalculationExpression. Calculates the
        1-normalized weight for each raster layer from the user-supplied weights (1-10).

        Parameters:
            exclude: optional iterable of layer names to leave out. The remaining
                weights are normalized among themselves so that they still sum to 1.

        Returns:
            dict of form {layer name: normalized weight}; empty if no layers remain.

        Raises:
            ValueError: if the remaining weights are all zero (cannot be normalized).
        """
        excluded = set(exclude or ())
        rawWeights = {
            rasterlayer.name(): self._dataBridge.getInterLayerWeightByLayerName(
                rasterlayer.name()
            )
            for rasterlayer in self.getRasterLayers()
            if rasterlayer.name() not in excluded
        }
        if not rawWeights:
            return {}

        wNorm = float(
            np.linalg.norm(np.array(list(rawWeights.values()), dtype=float), ord=1)
        )
        if wNorm == 0:
            # dividing by zero would put NaN/inf weights into the raster expression
            raise ValueError(
                "All inter-layer weights are zero, so they cannot be normalized. "
                "Give at least one layer a non-zero weight."
            )
        return {name: weight / wNorm for name, weight in rawWeights.items()}

    def constructRasterCalculationExpression(self):
        """
        Constructs the string for the virtual raster calculator algorithm's expression parameter. Helper function for calculateCombinedRaster().
        The expression, operating cell-wise, works out to:
            sum of (min-max normalized layer value at that cell, multiplied by the normalized weight of that layer) over all usable layers

        Layers whose raster holds a single value (MIN == MAX) are dropped with a warning,
        and the weights are re-normalized over the remaining layers.

        Raises:
            ValueError: if no usable raster layer remains.
        """

        # do some error checking on whether one of your raster layers has all the same value - if it does, you
        # need to throw it out because otherwise you get a divide by 0 error and all values are float min.
        rasterignore = []
        for rasterlayer in self.getRasterLayers():
            stats = self._rasterStats[rasterlayer]
            if stats["MIN"] == stats["MAX"]:
                warnings.warn(f"Layer {rasterlayer.name()} has a raster containing only a single value, which contaminates results, so this layer will be ignored. If this layer started as a points or line layer, it might be worth changing your radius values. If it started as a polygon layer, make sure that you have distinct values within distinct polygons or that it has an irregular shape.")
                rasterignore.append(rasterlayer.name())

        kept = [
            rasterlayer
            for rasterlayer in self.getRasterLayers()
            if rasterlayer.name() not in rasterignore
        ]
        if not kept:
            raise ValueError("No acceptable raster layers were produced. This can happen if all layers' rasters had a single value in the area of interest.")

        # normalize over the kept layers only, so the weights of the terms actually
        # present still sum to 1 and the combined score stays on a 0..1 scale
        normedWeights = self.constructNormalizedWeights(exclude=rasterignore)

        terms = [
            '("{0}@1"- {1})/(({2}-{1}))*{3}'.format(
                rasterlayer.name(),
                self._rasterStats[rasterlayer]["MIN"],
                self._rasterStats[rasterlayer]["MAX"],
                normedWeights[rasterlayer.name()],
            )
            for rasterlayer in kept
        ]
        return "(" + "+".join(terms) + ")"

    def calculateCombinedRaster(self):
        """
        Having rasterized the layers of interest, creates a raster layer that contains the output function on those layers.

        The output layer will be named "ReNCAT-preprocessed-<timestamp>".

        Raises:
            ValueError: if no usable raster layer exists, or the first raster layer's
                CRS map units are neither meters nor feet.
        """

        expr = self.constructRasterCalculationExpression()

        # the expression guarantees at least one raster layer exists
        layercrs = self.getRasterLayers()[0].crs()
        unitName = Qgis.DistanceUnit(layercrs.mapUnits()).name
        if unitName not in ("Meters", "Feet"):
            raise ValueError(f"crs units not handled! units are {unitName}")
        pixel_size = self.getPixelSize()

        layer_name = f"ReNCAT-preprocessed-{self._timestamp}"

        combined = processing.run(
            "native:virtualrastercalc",
            {
                # Pass the layer objects themselves: passing their data source URIs
                # instead broke the algorithm with a useless error message.
                "LAYERS": self.getRasterLayers(),
                "EXPRESSION": expr,
                "EXTENT": self.getExtent(),
                "CELL_SIZE": pixel_size,
                # CRS left unset so the routine uses the first reference layer's CRS
                "LAYER_NAME": layer_name,
            },
        )

        return combined["OUTPUT"]

    # ------------ result filtering / export -----------

    def findMinimumDesirableRasterValue(self, lay: QgsRasterLayer, db: dataBridge):
        """
        Determines the minimum desirable raster value. This is read in from user inputs.

        Parameters:
            lay: the combined raster layer (only read when the user asked for "MAX").
            db: unused; the dataBridge stored by importFromDataBridge is consulted instead.

        Returns:
            None if there is no minimum and all points should be returned ("ALL"),
            the combined raster's maximum for "MAX", otherwise the user-supplied value.
        """
        returnInfo = self._dataBridge.getReturnValueInfo()

        if returnInfo == "MAX":
            # this operation is being performed not on the individual inputs' raster layer but rather the
            # combined raster layer
            return self._rasterLayerStatistics(lay)["MAX"]
        if returnInfo == "ALL":
            return None
        return returnInfo

    def makeRasterLayerByDesiredValue(self, lay: QgsRasterLayer, value: float):
        """
        Based on the input value, filters the raster layer so that all cells of greater
        or equal value to that value are retained (other cells become 0).

        value may be None, in which case nothing is done and `lay` itself is returned.

        Returns:
            QgsRasterLayer named "<lay name>_filtered", in the CRS of `lay`, or `lay`.
        """

        if value is None:  # all points should be returned, not a subset.
            return lay

        # raster calculator references are double-quoted "layer@band" tokens
        ref = f'"{lay.name()}@1"'
        filtered = processing.run(
            "native:rastercalc",
            {
                "LAYERS": [lay],
                "EXPRESSION": f"({ref}>= {value}) * {ref}",
                "EXTENT": lay.extent(),
                # CELLSIZE left unset to use the cell size of the input layer
                "CRS": lay.crs(),  # this is redundant but included for completeness
                "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT,
            },
        )

        # filtered["OUTPUT"] is a string because of raster layer output properties
        return self._wrapRasterOutput(
            filtered["OUTPUT"], lay.name() + "_filtered", lay.crs()
        )

    def extractPointsFromRasterResults(self, lay: QgsRasterLayer, value: float):
        """
        Using the "results" layer with the goodness of the different raster areas,
        creates a set of points (as vector) from those pixels, with the pixel values
        stored in a "value" attribute.

        Returns:
            points vector layer named "ReNCAT-preprocessed-points-<timestamp>" with the
            points whose value is >= the input value. If value is None, returns all points.
        """

        points = processing.run(
            "native:pixelstopoints",
            {
                "INPUT_RASTER": lay,
                "RASTER_BAND": 1,
                "FIELD_NAME": "value",
                "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT,
            },
        )["OUTPUT"]

        if value is not None:
            points = processing.run(
                "native:extractbyexpression",
                {
                    "INPUT": points,
                    "EXPRESSION": f'"value">= {value}',
                    "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT,
                },
            )["OUTPUT"]

        points.setName(f"ReNCAT-preprocessed-points-{self._timestamp}")
        return points  # this is a vector layer

    def _addCoordinateField(self, lay, field_name, formula):
        """Return a copy of `lay` with a new decimal field `field_name` computed from `formula`."""
        return processing.run(
            "native:fieldcalculator",
            {
                "INPUT": lay,
                "FIELD_NAME": field_name,
                "FIELD_TYPE": 0,  # 0 = decimal (double)
                # FIELD_LENGTH left at its default
                "FIELD_PRECISION": 6,
                # NEW_FIELD left at its default
                "FORMULA": formula,
                "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT,
            },
        )["OUTPUT"]

    def addXYToPoints(self, lay: QgsVectorLayer):
        """
        Adds x/y coordinates as columns to all output points. This will be in the CRS of the input layer.

        Parameters:
            lay: QgsVectorLayer

        Returns:
            QgsVectorLayer of points named "ReNCAT-preprocessed-xy-<timestamp>", with
            the attribute table now containing additional x and y columns.
        """
        withY = self._addCoordinateField(lay, "y", "y(@geometry)")
        withXY = self._addCoordinateField(withY, "x", "x(@geometry)")
        withXY.setName(f"ReNCAT-preprocessed-xy-{self._timestamp}")
        return withXY

    def convertLayerToGeographic(self, lay: QgsVectorLayer):
        """
        Converts the layer to EPSG:4326 if not already in that CRS.

        Parameters:
            lay: QgsVectorLayer

        Returns:
            QgsVectorLayer in EPSG:4326 (`lay` itself if it already is).
        """

        if lay.crs().authid() == "EPSG:4326":
            return lay
        reprojected = processing.run(
            "native:reprojectlayer",
            {
                "INPUT": lay,
                "TARGET_CRS": "EPSG:4326",
                "OUTPUT": QgsProcessing.TEMPORARY_OUTPUT,
            },
        )
        return reprojected["OUTPUT"]

    def preparePointsLayerForRencatExport(self, lay: QgsVectorLayer):
        """
        Perform processing of points layer to munge for I/O output: reproject to
        EPSG:4326 and read back every feature's coordinates.

        Returns:
            tuple of lists, of form (feature ids, longitudes (x), latitudes (y))

        Raises:
            ValueError: if the x/y fields could not be read from the layer.
        """
        # reproject
        reprojectedLayer = self.convertLayerToGeographic(lay)

        # add coordinates in the reprojected layer
        coordinatePoints = self.addXYToPoints(reprojectedLayer)

        # get the data we want to export; getValues returns a (values, ok) tuple
        ids = [feature.id() for feature in coordinatePoints.getFeatures()]
        xvals, xok = QgsVectorLayerUtils.getValues(coordinatePoints, fieldOrExpression="x")
        yvals, yok = QgsVectorLayerUtils.getValues(coordinatePoints, fieldOrExpression="y")
        if not (xok and yok):
            # otherwise the coordinate lists would silently be empty/misaligned with ids
            raise ValueError(
                "Could not read the x/y coordinate fields from the exported points layer."
            )

        return (ids, xvals, yvals)

    # ------------ setters -----------
    def addRasterLayer(self, lay):
        """Append a raster layer to the list of layers to be combined."""
        self._rasterLayers.append(lay)

    # -----------getters -------------

    def getRawLayers(self):
        """
        Returns the list of raw layers, which may be raster or vector.
        """
        return self._layers

    def getExtent(self):
        """Returns the combined extent read from the dataBridge."""
        return self._combinedExtent

    def getRasterLayers(self):
        """
        Returns the list of raster layers for final numerical processing
        """
        return self._rasterLayers

    def getPixelSize(self):
        """Returns the grid resolution read from the dataBridge."""
        return self._pixelSize

    def getDataBridge(self):
        """Returns the dataBridge stored by importFromDataBridge (None before that)."""
        return self._dataBridge

    def getPointsReturnInfo(self):
        """Delegates to the dataBridge's getPointsReturnInfo."""
        return self.getDataBridge().getPointsReturnInfo()
