# Copyright 2023 Bunting Labs, Inc.

import os
import http.client
import json
from osgeo import gdal, osr
import numpy as np
import ssl
import math

from qgis.core import QgsTask, QgsMapSettings, QgsMapRendererCustomPainterJob, \
    QgsCoordinateTransform, QgsProject, QgsRectangle, Qgis
from qgis.gui import QgsMapToolCapture
from qgis.PyQt.QtGui import QImage, QPainter, QColor
from qgis.PyQt.QtCore import QSize, pyqtSignal

class AutocompleteTask(QgsTask):
    """Background QGIS task that streams autocomplete points from the inference server.

    The task renders the raster layers around the two most recent vertices of
    the tracing tool into an image, uploads that image as a GeoTIFF together
    with the pixel coordinates of those two vertices, and converts every point
    streamed back by the server into map coordinates.
    """

    # Emitted for every streamed point as ((x, y), 1.0); x and y are in the
    # same coordinate space as the tracing tool's vertices.
    pointReceived = pyqtSignal(tuple)
    # Tuple for (error message, error link or None, error button text or None)
    errorReceived = pyqtSignal(tuple)

    def __init__(self, tracing_tool, vlayer, rlayers, project_crs):
        """Store the inputs needed by ``run``.

        Args:
            tracing_tool: Map tool providing ``vertices``, ``num_completions``,
                ``mode()`` and ``plugin`` (settings, iface, plugin_version).
            vlayer: Vector layer being edited (stored, not used by ``run``).
            rlayers: Raster layers rendered into the image sent to the server.
            project_crs: CRS used for rendering and reported to the server.
        """
        super().__init__(
            'Bunting Labs AI Vectorizer background task for ML inference',
            QgsTask.CanCancel
        )

        self.tracing_tool = tracing_tool
        self.vlayer = vlayer
        self.rlayers = rlayers
        self.project_crs = project_crs

    def run(self):
        """Render the rasters, query the inference server and emit the results.

        Every failure is reported through ``errorReceived`` rather than raised.

        Returns:
            True when the response stream was read to its end, False when an
            error occurred or the task was cancelled while streaming.
        """
        # Nothing to render without rasters, so bail out before any other work.
        if not self.rlayers:
            self.errorReceived.emit((
                'No raster layers are loaded. Load a GeoTIFF to use autocomplete.',
                None, None
            ))
            return False

        # The window is centred between the last two vertices, so both must exist.
        vertices = self.tracing_tool.vertices
        if len(vertices) < 2:
            self.errorReceived.emit((
                'Autocomplete needs at least two vertices to continue from.',
                None, None
            ))
            return False
        prev_vertex, last_vertex = vertices[-2], vertices[-1]

        # Read the canvas resolution once so every value derived from it agrees.
        canvas = self.tracing_tool.plugin.iface.mapCanvas()
        screen_units_per_pixel = canvas.extent().width() / canvas.width()

        # By default, we zoom out 2.5x from the user's perspective.
        proj_crs_units_per_screen_pixel = 2.5 * screen_units_per_pixel

        mapEpsgCode = self.project_crs.postgisSrid()

        # Only rasters in the project CRS that cover the latest vertex are
        # considered when capping the resolution.
        intersecting_layers = [
            rlayer for rlayer in self.rlayers
            if rlayer.crs() == self.project_crs and rlayer.extent().contains(last_vertex)
        ]

        # The resolution of a raster layer is the ground distance covered by one
        # pixel, so the highest resolution raster has the smallest
        # rasterUnitsPerPixelX. We don't want to upsample a raster, so the
        # rendering resolution is capped at that value.
        highest_res_at_pt = min(
            (rlayer.rasterUnitsPerPixelX() for rlayer in intersecting_layers),
            default=proj_crs_units_per_screen_pixel
        )

        dx = max(proj_crs_units_per_screen_pixel, highest_res_at_pt)
        dy = dx

        # Size of the rendered window in pixels. Validated explicitly (instead
        # of with an assert) so a bad setting is reported to the user; str()
        # tolerates a settings backend that hands back an int.
        window_size = str(self.tracing_tool.plugin.settings.value("buntinglabs-qgis-plugin/window_size_px", "1200"))
        if window_size not in ("1200", "2500"):  # Two allowed sizes
            self.errorReceived.emit((
                f'Unsupported autocomplete window size {window_size!r}; expected 1200 or 2500.',
                None, None
            ))
            return False

        # Size of the rectangle in the CRS coordinates
        img_width, img_height = int(window_size), int(window_size)
        x_size = img_width * dx
        y_size = img_height * dy

        if x_size <= 0 or y_size <= 0:
            self.errorReceived.emit((
                'Could not render an image from the rasters (this is a plugin bug!).',
                'https://github.com/BuntingLabs/buntinglabs-qgis-plugin/issues/new',
                'Report Bug'
            ))
            return False

        # i = y, j = x
        # note that negative i (or y) is up
        x0, y0 = prev_vertex
        x1, y1 = last_vertex
        cx, cy = (x0+x1)/2, (y0+y1)/2

        x_min = cx - x_size / 2
        x_max = cx + x_size / 2
        y_min = cy - y_size / 2
        y_max = cy + y_size / 2

        # create image
        # Format_RGB888 is 24-bit (8 bits each) for each color channel, unlike
        # Format_RGB32 which by default has 0xff on the alpha channel, and screws
        # up reading it into GDAL!
        img = QImage(QSize(img_width, img_height), QImage.Format_RGB888)

        # white is most canonically background
        color = QColor(255, 255, 255)
        img.fill(color.rgb())

        mapSettings = QgsMapSettings()

        mapSettings.setDestinationCrs(self.project_crs)
        mapSettings.setLayers(self.rlayers)

        rect = QgsRectangle(x_min, y_min, x_max, y_max)
        mapSettings.setExtent(rect)
        mapSettings.setOutputSize(img.size())

        p = QPainter()
        p.begin(img)
        try:
            p.setRenderHint(QPainter.Antialiasing)

            render = QgsMapRendererCustomPainterJob(mapSettings, p)
            render.start()
            render.waitForFinished()
        finally:
            # Always release the painter, even if rendering raised.
            p.end()

        try:
            # Convert QImage to np.array
            # NOTE: this reshape assumes the QImage buffer has no per-row padding.
            ptr = img.bits()
            ptr.setsize(img.height() * img.width() * 3)
            img_np = np.frombuffer(ptr, np.uint8).reshape((img.height(), img.width(), 3))

            # Convert the image to GeoTIFF bytes. y_max is passed in the y_min
            # slot (and vice versa) so that row 0 sits at the top of the window.
            tif_data = georeference_img_to_tiff(img_np, mapEpsgCode, x_min, y_max, x_max, y_min)

            # Pixel coordinates (row i, column j) of the two vertices.
            i0 = int((y0 - y_max) / dy) * -1
            j0 = int((x0 - x_min) / dx)

            i1 = int((y1 - y_max) / dy) * -1
            j1 = int((x1 - x_min) / dx)

        except Exception as e:
            self.errorReceived.emit((str(e), None, None))
            return False

        vector_payload = json.dumps({
            'coordinates': [[i0, j0], [i1, j1]]
        })

        options_payload = json.dumps({
            'num_completions': self.tracing_tool.num_completions,
            'qgis_version': Qgis.QGIS_VERSION,
            'plugin_version': self.tracing_tool.plugin.plugin_version,
            'proj_epsg': mapEpsgCode,
            'is_polygon': self.tracing_tool.mode() == QgsMapToolCapture.CapturePolygon,
            # Rasters can be at all sorts of resolutions, but the current zoom level of
            # the QGIS window gives us a hint as to the best zoom to autocomplete with.
            "resolution_units_per_pixel": screen_units_per_pixel,
            "proj_crs_units_per_screen_pixel": proj_crs_units_per_screen_pixel,
            "highest_res_at_pt": highest_res_at_pt,
            "dist_pixels_between_points": math.sqrt((i0-i1)**2 + (j0-j1)**2)
        })

        # Hand-built multipart/form-data body with three parts: image, vector, options.
        boundary = 'wL36Yn8afVp8Ag7AmP8qZ0SA4n1v9T'
        body = (
            '--' + boundary,
            'Content-Disposition: form-data; name="image"; filename="rendered.tif"',
            'Content-Type: application/octet-stream',
            '',
            tif_data,
            '--' + boundary,
            'Content-Disposition: form-data; name="vector"; filename="vector.json"',
            'Content-Type: application/json',
            '',
            vector_payload,
            '--' + boundary,
            'Content-Disposition: form-data; name="options"; filename="options.json"',
            'Content-Type: application/json',
            '',
            options_payload,
            '--' + boundary + '--',
            ''
        )
        # The GeoTIFF part is already binary; every other part is text.
        body = b'\r\n'.join([part.encode() if isinstance(part, str) else part for part in body])

        headers = {
            'Content-Type': 'multipart/form-data; boundary=' + boundary,
            'x-api-key': self.tracing_tool.plugin.settings.value("buntinglabs-qgis-plugin/api_key", "demo")
        }

        conn = None
        try:
            try:
                conn = http.client.HTTPSConnection("qgis-api.buntinglabs.com")
                conn.request("POST", "/v1", body, headers)
                res = conn.getresponse()
                if res.status != 200:
                    error_payload = res.read().decode('utf-8')

                    try:
                        error_details = json.loads(error_payload)
                        self.errorReceived.emit((
                            error_details.get('message'),
                            error_details.get('link'),
                            error_details.get('link_text')
                        ))
                    except json.JSONDecodeError:
                        # Not a structured error: show the raw response body.
                        self.errorReceived.emit((error_payload, None, None))

                    return False
            except BrokenPipeError:
                self.errorReceived.emit(('Got BrokenPipeError when trying to connect to inference server', None, None))
                return False
            except ssl.SSLCertVerificationError:
                self.errorReceived.emit(('SSL Certificate Verification Failed when connecting to inference server', None, None))
                return False
            except Exception as e:
                self.errorReceived.emit((f'Error when trying to connect to inference server: {str(e)}', None, None))
                return False

            return self._stream_points(res, dx, dy, x_min, y_max)
        finally:
            # Release the socket on every exit path; closing the connection
            # also closes the response, so it must happen after streaming.
            if conn is not None:
                conn.close()

    def _stream_points(self, res, dx, dy, x_min, y_max):
        """Read newline-delimited JSON points from ``res`` and emit each one.

        Each line is expected to be a JSON array whose first two items are the
        row and column of a point in the rendered image.

        Args:
            res: Open ``http.client.HTTPResponse`` with a 200 status.
            dx, dy: Map units per image pixel along x and y.
            x_min, y_max: Map coordinates of the image's top-left corner.

        Returns:
            True when the stream ended, False on cancellation or on a line
            that could not be interpreted.
        """
        # Buffer raw bytes and decode only complete lines: decoding each small
        # chunk on its own could split a multi-byte UTF-8 character.
        buffer = b""
        stream_open = True
        while stream_open:
            if self.isCanceled():
                return False

            try:
                # NOTE(review): the read size is kept small, presumably so that
                # points are forwarded as soon as they arrive.
                chunk = res.read(16)
            except http.client.IncompleteRead as e:
                # The server closed the stream early: use what arrived and stop
                # reading instead of retrying on a truncated response.
                chunk = e.partial
                stream_open = False

            if not chunk:
                break
            buffer += chunk

            while b'\n' in buffer:
                if self.isCanceled():
                    return False

                line, buffer = buffer.split(b'\n', 1)
                if not line.strip():
                    # Tolerate blank lines between points.
                    continue

                try:
                    new_point = json.loads(line.decode('utf-8'))
                    ix, jx = new_point[0], new_point[1]

                    # convert to xy
                    xn = (jx * dx) + x_min
                    yn = y_max - (ix * dy)
                except (ValueError, LookupError, TypeError) as e:
                    self.errorReceived.emit((
                        f'Could not parse response from inference server: {str(e)}',
                        None, None
                    ))
                    return False

                self.pointReceived.emit(((xn, yn), 1.0))

        return True

    def finished(self, result):
        """Do nothing on completion; results are delivered through the signals."""
        pass

    def cancel(self):
        """Cancel the task by delegating to ``QgsTask.cancel``."""
        super().cancel()


def georeference_img_to_tiff(img_np, epsg, x_min, y_min, x_max, y_max):
    """Encode an image array as a georeferenced, JPEG-compressed GeoTIFF in memory.

    The geotransform places the first row of the image at ``y_min`` and the
    last row's far edge at ``y_max``. ``AutocompleteTask.run`` passes the
    window's top edge as ``y_min`` and its bottom edge as ``y_max``, so the
    pixel height comes out negative and row 0 is the top of the window.

    Args:
        img_np: Array of shape (rows, columns, bands); written as byte bands.
        epsg: EPSG code of the image's coordinate reference system.
        x_min: x coordinate of the image's first column edge.
        y_min: y coordinate of the image's first row edge.
        x_max: x coordinate of the image's last column edge.
        y_max: y coordinate of the image's last row edge.

    Returns:
        The encoded GeoTIFF file contents as returned by ``gdal.VSIFReadL``.

    Raises:
        RuntimeError: If GDAL could not create or reopen the in-memory file.
    """
    # Local import keeps this function self-contained.
    import uuid

    (rasterYSize, rasterXSize, rasterCount) = img_np.shape

    # Each call gets its own /vsimem/ path. A shared path let two overlapping
    # requests unlink each other's file before it had been read back.
    vsimem_path = f'/vsimem/bunting_qgis_tracer_{uuid.uuid4().hex}.tif'

    dst = None
    band = None
    try:
        # Create a new GeoTIFF file in memory
        dst = gdal.GetDriverByName('GTiff').Create(vsimem_path, rasterXSize, rasterYSize, rasterCount,
                                                   gdal.GDT_Byte, options=["COMPRESS=JPEG", "JPEG_QUALITY=85"])
        if dst is None:
            raise RuntimeError("Could not create an in-memory GeoTIFF for the rendered image.")

        # Set the geotransform: origin, pixel width, no rotation, pixel height.
        geotransform = [x_min, (x_max-x_min)/rasterXSize, 0, y_min, 0, (y_max-y_min)/rasterYSize]
        dst.SetGeoTransform(geotransform)

        # Set the projection
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(epsg)
        dst.SetProjection(srs.ExportToWkt())

        # Write the array data to the raster bands (GDAL bands are 1-indexed)
        for b in range(rasterCount):
            band = dst.GetRasterBand(b + 1)
            band.WriteArray(img_np[:, :, b])

        # Dropping the references closes the dataset and flushes it to /vsimem/.
        band = None
        dst = None

        # Read the GeoTIFF-encoded memory contents back as bytes
        f = gdal.VSIFOpenL(vsimem_path, 'rb')
        if f is None:
            raise RuntimeError("Could not read back the in-memory GeoTIFF for the rendered image.")

        try:
            # Seek to the end to learn the file size, then read it in one go.
            gdal.VSIFSeekL(f, 0, os.SEEK_END)
            size = gdal.VSIFTellL(f)
            gdal.VSIFSeekL(f, 0, os.SEEK_SET)
            data = gdal.VSIFReadL(1, size, f)
        finally:
            gdal.VSIFCloseL(f)
    finally:
        # Release any dataset still open, then delete the temporary file.
        band = None
        dst = None
        try:
            gdal.Unlink(vsimem_path)
        except RuntimeError:
            # Best-effort cleanup: the file may never have been created, and a
            # failed unlink must not mask the error that brought us here.
            pass

    return data
