# Copyright (c) 2025, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
from os.path import dirname, join

import numpy as np
import pandas as pd
from qgis.PyQt import uic
from qgis.PyQt.QtCore import Qt
from qgis.PyQt.QtGui import QColor, QCursor
from qgis.PyQt.QtWidgets import QDialog
from qgis.core import QgsProject, QgsSymbol, QgsRuleBasedRenderer, QgsSpatialIndex
from ..common_tools.point_tool import PointTool

from polaris.utils.database.db_utils import read_and_close

from ..style_loader.style_loader import load_style_by_name

# Qt Designer form backing LocationDialog; loaded once when this module is imported.
_FORM_PATH = join(dirname(__file__), "forms/locations.ui")
FORM_CLASS, _ = uic.loadUiType(_FORM_PATH)


class LocationDialog(QDialog, FORM_CLASS):
    """Dialog that highlights a single location and the network elements tied to it.

    A location is chosen either by typing its id into ``txt_goto`` or by clicking near it
    on the map canvas. Once chosen, the Parking and EV_Charging_Stations layers are
    filtered to that location, the links listed for it in ``Location_Links`` are drawn in
    red, the location itself is highlighted and the canvas is centred on it. Layers that
    the dialog had to load itself are removed from the layer tree when it closes.
    """

    def __init__(self, _PQgis, location_id=None):
        """Load the required layers and optionally jump straight to ``location_id``.

        :param _PQgis: plugin object providing ``iface``, ``network``, ``supply_path``,
            ``layers`` and the ``create_layer_by_name`` / ``load_layer_by_name`` helpers.
        :param location_id: location to show immediately; when falsy the dialog starts in
            map-picking mode with every highlight cleared.
        """
        QDialog.__init__(self)
        self.tool = PointTool(_PQgis.iface.mapCanvas())
        # Whether self.tool.clicked is currently wired to found_node; prevents
        # activate_selector() from stacking duplicate connections.
        self._selector_connected = False
        self.iface = _PQgis.iface
        self.canvas = self.iface.mapCanvas()
        self.setupUi(self)

        self._PQgis = _PQgis
        self._p = _PQgis.network
        self.__data_tables = self._p.tables
        self.all_locs: np.ndarray
        self.required_layers = ["Location", "Parking", "Link", "EV_Charging_Stations"]
        self.load_all_locations()
        # Layers loaded by this dialog itself; removed from the layer tree in closeEvent.
        self.remove_layers = []

        # ids of the layers already present in the project
        layer_ids = [layer.id() for layer in QgsProject.instance().mapLayers().values()]

        # NOTE(review): _PQgis.layers appears to map lower-case names to (layer, layer_id).
        for lyr in self.required_layers:
            key = lyr.lower()
            if key not in self._PQgis.layers:
                self._PQgis.create_layer_by_name(lyr)
            if self._PQgis.layers[key][1] not in layer_ids:
                self._PQgis.load_layer_by_name(lyr)
                self.remove_layers.append(self._PQgis.layers[key][0])

        self.location_layer = self._PQgis.layers["location"][0]
        # Spatial index used to snap map clicks to the nearest location feature.
        self.nindex = QgsSpatialIndex(self.location_layer.getFeatures())

        self.parking_layer = self._PQgis.layers["parking"][0]
        load_style_by_name(self.parking_layer, "parking_location_display.qml")

        self.link_layer = self._PQgis.layers["link"][0]

        self.ev_layer = self._PQgis.layers["ev_charging_stations"][0]
        load_style_by_name(self.ev_layer, "ev_charging_location_display.qml")

        self.txt_goto.textChanged.connect(self.quick_selection)
        if location_id:
            # Goes through textChanged -> quick_selection -> load_location.
            self.txt_goto.setText(str(location_id))
        else:
            self.txt_goto.setText("pick one")
            # Clear every highlight/filter until the user picks a location.
            self.style_links([-1])
            self.style_locations(-1)
            self.parking_layer.setSubsetString("parking<0")
            self.ev_layer.setSubsetString("location<0")
            self.activate_selector()

    def activate_selector(self):
        """Arm the map tool so the next canvas click selects the nearest location."""
        self.canvas.setCursor(QCursor(Qt.PointingHandCursor))
        self.canvas.setMapTool(self.tool)
        # This method runs after every zoom; connect only once so a single click does
        # not invoke found_node several times.
        if not self._selector_connected:
            self.tool.clicked.connect(self.found_node)
            self._selector_connected = True

    def _deactivate_selector(self):
        """Disconnect the click handler and release the map tool (inverse of activate_selector)."""
        if self._selector_connected:
            try:
                self.tool.clicked.disconnect(self.found_node)
            except TypeError:
                # PyQt raises TypeError when the slot is not connected any more.
                pass
            self._selector_connected = False
        # Release the tool so canvas clicks no longer reach this (closing) dialog.
        self.canvas.unsetMapTool(self.tool)
        self.canvas.setCursor(QCursor(Qt.ArrowCursor))

    def found_node(self):
        """Handle a canvas click: put the id of the nearest location into ``txt_goto``."""
        self.canvas.setCursor(QCursor(Qt.ArrowCursor))
        point = self.tool.toLayerCoordinates(self.location_layer, self.tool.point)
        nearest = self.nindex.nearestNeighbor(point, 1)
        if not nearest:
            # Empty spatial index: nothing to select.
            return
        feature = self.location_layer.getFeature(nearest[0])
        self.txt_goto.setText(str(feature["location"]))

    def quick_selection(self):
        """Load the location typed in ``txt_goto`` if it is a known location id."""
        text = self.txt_goto.text()
        # isdecimal (unlike isdigit) rejects characters such as '²' that int() cannot parse.
        if not text.isdecimal():
            return

        loc_id = int(text)
        if np.isin(self.all_locs, loc_id).any():
            self.load_location(loc_id)

    def load_location(self, location_id):
        """Filter, style and zoom all layers to show ``location_id``.

        :param location_id: integer id of a row in the Location table.
        """
        # Coerce so only an integer can ever be interpolated into the SQL below.
        location_id = int(location_id)

        # Parking facilities and links associated with the location
        with read_and_close(self._PQgis.supply_path) as conn:
            parking = pd.read_sql(f"Select parking from Location_Parking where location={location_id}", conn)
            links = pd.read_sql(f"Select link from Location_Links where location={location_id}", conn)

        # Restrict the parking layer to this location's facilities (or to nothing).
        self.parking_layer.setSubsetString(self._in_filter("parking", parking.parking.tolist()))

        # Restrict EV charging stations to those attached to this location.
        self.ev_layer.setSubsetString(f'"location"={location_id}')

        # An empty link list is mapped to the -1 placeholder by _in_filter.
        self.style_links(links.link.tolist())
        self.style_locations(location_id)
        self.zoom_to_location(location_id)

    def zoom_to_location(self, location_id):
        """Centre the canvas on ``location_id`` and re-arm map picking; no-op if not found."""
        features = list(self.location_layer.getFeatures(f'"location"= {location_id}'))
        if not features:
            return

        centroid = features[0].geometry().centroid().asPoint()
        self.canvas.setCenter(centroid)
        self.canvas.refresh()
        self.activate_selector()

    def load_all_locations(self):
        """Read every location id from the supply database into ``self.all_locs``."""
        with read_and_close(self._PQgis.supply_path) as conn:
            df = pd.read_sql("Select location from Location order by location", conn)
        self.all_locs = df.location.to_numpy()

    @staticmethod
    def _in_filter(field, values):
        """Build a QGIS filter expression matching ``field`` against ``values``.

        Empty ``values`` fall back to ``"field" = -1`` (NOTE(review): assumes -1 is never a
        real id); one value gives an equality test and several give an ``IN (...)`` list.
        Elements are formatted with ``str`` so numpy scalars render as plain numbers.
        """
        ids = list(values) or [-1]
        if len(ids) == 1:
            return f'"{field}" = {ids[0]}'
        return f'"{field}" IN ({", ".join(str(v) for v in ids)})'

    @staticmethod
    def _apply_rule_renderer(layer, rules, size_setter):
        """Replace ``layer``'s renderer with a rule-based renderer built from ``rules``.

        :param layer: vector layer to restyle.
        :param rules: iterable of ``(label, expression, color_name, thickness)``; an empty
            expression makes that rule the "else" rule.
        :param size_setter: name of the symbol method receiving ``thickness``
            (``"setWidth"`` for line symbols, ``"setSize"`` for marker symbols).
        """
        symbol = QgsSymbol.defaultSymbol(layer.geometryType())
        renderer = QgsRuleBasedRenderer(symbol)
        root_rule = renderer.rootRule()

        # Every new rule is cloned from the renderer's initial child rule, which is
        # dropped once all clones have been appended.
        default_rule = root_rule.children()[0]
        for label, expression, color_name, thickness in rules:
            rule = default_rule.clone()
            rule.setLabel(label)
            if expression:
                rule.setFilterExpression(expression)
            else:
                rule.setIsElse(True)
            rule.symbol().setColor(QColor(color_name))
            getattr(rule.symbol(), size_setter)(thickness)
            root_rule.appendChild(rule)
        root_rule.removeChildAt(0)

        layer.setRenderer(renderer)
        layer.triggerRepaint()

    def style_links(self, tuple_of_links):
        """Draw the given links thick red and every other link thin grey.

        :param tuple_of_links: link id(s) to highlight; any sequence (tuple, list) or a
            single integer id. Empty input highlights nothing (placeholder id -1).
        """
        if isinstance(tuple_of_links, (int, np.integer)):
            tuple_of_links = (tuple_of_links,)
        rules = (
            ("Location links", self._in_filter("link", tuple_of_links), "red", 3),
            ("network", "", "grey", 0.26),
        )
        self._apply_rule_renderer(self.link_layer, rules, "setWidth")

    def style_locations(self, location_id):
        """Draw ``location_id`` as a large red marker and every other location small grey."""
        rules = (
            ("Location", f'"location" = {location_id}', "red", 5),
            ("other locations", "", "grey", 1.0),
        )
        self._apply_rule_renderer(self.location_layer, rules, "setSize")

    def exit_procedure(self):
        """Close the dialog."""
        self.close()

    def closeEvent(self, event):
        """Release the map tool and remove the layers this dialog loaded before closing."""
        self._deactivate_selector()
        root = QgsProject.instance().layerTreeRoot()
        for lyr in self.remove_layers:
            root.removeLayer(lyr)
        self.exit_procedure()
