# Copyright (c) 2025, UChicago Argonne, LLC
# BSD OPEN SOURCE LICENSE. Full license can be found in LICENSE
from polaris.utils.database.db_utils import commit_and_close
from qgis.PyQt.QtCore import Qt
from qgis.core import QgsProject, QgsSpatialIndex
from qgis.gui import QgsMapToolPan
from qgis.PyQt.QtGui import QCursor

from ..supply_editor.intersection_dialog import IntersectionDialog
from .point_tool import PointTool


class FeaturePicker:
    """Let the user pick a feature of a layer by clicking on the map canvas.

    Construction makes the target layer available: it is created and/or loaded
    into the project when needed. A spatial index over the layer's features is
    then built and a ``PointTool`` is installed as the active map tool.

    When the user clicks, the feature nearest to the click is looked up and the
    value of its ``field_name`` attribute is stored in ``self.feature_id``. For
    the ``node`` layer an ``IntersectionDialog`` is then opened for that node.
    """

    def __init__(self, _PQgis, layer_name, field_name):
        """
        :param _PQgis: plugin context. It is used here for ``iface``, ``layers``
            (indexed by lower-cased layer name and unpacked as
            ``(layer, layer_id)``), ``create_layer_by_name``,
            ``load_layer_by_name`` and ``supply_path``.
        :param layer_name: name of the layer to pick from. It is looked up in
            ``_PQgis.layers`` lower-cased.
        :param field_name: attribute whose value is recorded for the picked
            feature.
        """
        self.tool = PointTool(_PQgis.iface.mapCanvas())
        self.iface = _PQgis.iface
        self.canvas = self.iface.mapCanvas()
        self._PQgis = _PQgis
        self.layer_name = layer_name
        self.field_name = field_name
        self.feature_id = None  # set by found_node() once the user clicks

        layer_key = self.layer_name.lower()

        # IDs of the layers already loaded in the QGIS project
        layer_ids = [layer.id() for layer in QgsProject.instance().mapLayers().values()]

        # Check membership explicitly. Passing create_layer_by_name(...) as the
        # default of dict.get() would evaluate it, and so create the layer,
        # even when the layer is already registered.
        if layer_key not in self._PQgis.layers:
            # expected to register the new layer under layer_key in _PQgis.layers
            self._PQgis.create_layer_by_name(self.layer_name)

        self.layer, layer_id = self._PQgis.layers[layer_key]
        if layer_id not in layer_ids:
            self._PQgis.load_layer_by_name(layer_key)

        # Index built once so each click only needs a nearest-neighbour query.
        self.nindex = QgsSpatialIndex(self.layer.getFeatures())
        self.canvas.setCursor(QCursor(Qt.PointingHandCursor))
        self.canvas.setMapTool(self.tool)
        self.add_tool()

    def add_tool(self):
        """Connect the point tool's ``clicked`` signal to :meth:`found_node`."""
        self.tool.clicked.connect(self.found_node)

    def found_node(self):
        """Record the ``field_name`` value of the feature nearest to the click.

        Restores the arrow cursor first. If the layer has no indexed features,
        ``feature_id`` is left unchanged and nothing else happens. For the
        ``node`` layer, :meth:`show_intersection` is called afterwards.
        """
        self.canvas.setCursor(QCursor(Qt.ArrowCursor))

        # Convert into the coordinates of self.layer: the layer the spatial
        # index (and therefore the returned feature ids) was built from.
        point = self.tool.toLayerCoordinates(self.layer, self.tool.point)
        nearest_ids = self.nindex.nearestNeighbor(point, 1)
        if not nearest_ids:
            # Empty layer: nothing to pick. Indexing [0] would raise IndexError
            # inside the Qt slot.
            return

        feature = self.layer.getFeature(nearest_ids[0])
        self.feature_id = feature[self.field_name]

        if self.layer_name.lower() == "node" and isinstance(self.tool, PointTool):
            self.show_intersection()

    def show_intersection(self):
        """Switch the canvas to panning and open a modal ``IntersectionDialog``.

        The dialog is opened for the picked node (``self.feature_id``). A
        connection to ``_PQgis.supply_path`` is opened with
        ``commit_and_close(..., spatial=True)`` and stays open while the modal
        dialog runs.
        """
        self.canvas.setMapTool(QgsMapToolPan(self.canvas))
        with commit_and_close(self._PQgis.supply_path, spatial=True) as conn:
            dlg2 = IntersectionDialog(self._PQgis, conn, self.feature_id)
            # NOTE(review): setWindowFlags() replaces all existing flags,
            # including the dialog window type. Consider
            # `dlg2.windowFlags() | Qt.WindowStaysOnTopHint` -- confirm the
            # intended look first.
            dlg2.setWindowFlags(Qt.WindowStaysOnTopHint)
            # exec_() shows the dialog modally, so a separate show() is not needed.
            dlg2.exec_()
