import os.path
import random
import re

from qgis.PyQt.QtCore import QCoreApplication, Qt
from qgis.PyQt.QtGui import QIcon, QColor
from qgis.PyQt.QtWidgets import QAction

# CORE imports
from qgis.core import QgsWkbTypes, QgsGeometry, QgsRectangle, QgsRaster
# GUI imports
from qgis.gui import QgsMapToolEmitPoint, QgsRubberBand

from .spectra_dockwidget import SpectraDockWidget
from .resources import *

class SpectraPlotter:
    """QGIS plugin that plots the per-band values (spectrum) of clicked raster cells.

    Clicking the map with the plugin's tool samples the clicked cell of the
    active raster layer, marks the cell on the canvas and sends the band values
    to the dock widget's plot. When the dock's ``chk_multi`` checkbox is
    unchecked, each click replaces the previous sample.
    """

    # Fixed colours for the first traces; once exhausted, get_next_color()
    # falls back to randomly generated dark colours.
    _PALETTE = (
        '#E6194B', '#3CB44B', '#4363D8', '#F58231', '#911EB4',
        '#469990', '#DCBEFF', '#9A6324', '#800000', '#AAFFC3',
        '#808000', '#000075', '#A9A9A9', '#000000', '#f032e6',
        '#006400', '#4B0082', '#ffd8b1', '#C71585', '#808080',
    )

    # Wavelength embedded in a band name, e.g. "450nm" or "550.5 nm".
    _WAVELENGTH_RE = re.compile(r'(\d+(?:\.\d+)?)\s*nm')

    # (substring searched in the band name, marker colour, marker label);
    # checked in this order, first match wins.
    _NAMED_BAND_TAGS = (
        ('(Red)', 'red', 'Red'),
        ('(Green)', 'green', 'Green'),
        ('(Blue)', 'blue', 'Blue'),
    )

    def __init__(self, iface):
        """Initialise plugin state.

        :param iface: QGIS interface object, used here for menus, toolbar,
            dock area, message bar, active layer and map canvas access.
        """
        self.iface = iface
        self.plugin_dir = os.path.dirname(__file__)
        self.actions = []
        self.menu = self.tr(u'&Spectra Plotter')
        self.dockwidget = None
        self.tool = None

        # One {'box': ..., 'cross': ...} rubber-band pair per sampled point.
        self.active_markers = []
        # Raster whose styleChanged signal is connected to on_style_changed.
        self.monitored_layer = None
        # Index of the next colour handed out by get_next_color().
        self.color_idx = 0
        self.colors = list(self._PALETTE)

    def tr(self, message):
        """Translate *message* in the 'SpectraPlotter' context."""
        return QCoreApplication.translate('SpectraPlotter', message)

    def add_action(self, icon_path, text, callback, enabled_flag=True, add_to_menu=True,
                   add_to_toolbar=True, status_tip=None, whats_this=None, parent=None):
        """Create a QAction, register it in the toolbar/menu and remember it for unload().

        :param icon_path: File or Qt resource path of the action icon.
        :param text: Action text.
        :param callback: Callable connected to the action's ``triggered`` signal.
        :param enabled_flag: Initial enabled state of the action.
        :param add_to_menu: Add the action under this plugin's menu.
        :param add_to_toolbar: Add the action via ``iface.addToolBarIcon``.
        :param status_tip: Optional status-bar text for the action.
        :param whats_this: Optional "What's This?" help text for the action.
        :param parent: Parent widget of the action.
        :returns: The created QAction.
        """
        icon = QIcon(icon_path)
        action = QAction(icon, text, parent)
        action.triggered.connect(callback)
        action.setEnabled(enabled_flag)
        # These were previously accepted but silently ignored.
        if status_tip is not None:
            action.setStatusTip(status_tip)
        if whats_this is not None:
            action.setWhatsThis(whats_this)
        if add_to_toolbar:
            self.iface.addToolBarIcon(action)
        if add_to_menu:
            self.iface.addPluginToMenu(self.menu, action)
        self.actions.append(action)
        return action

    def initGui(self):
        """Create the plugin's menu entry and toolbar button."""
        icon_path = ':/plugins/spectra/icon.png'
        self.action = self.add_action(
            icon_path,
            text=self.tr(u'Spectra Plotter'),
            callback=self.run,
            parent=self.iface.mainWindow())

    def unload(self):
        """Remove the plugin's UI and release canvas markers, signals, dock and tool."""
        for action in self.actions:
            self.iface.removePluginMenu(self.menu, action)
            self.iface.removeToolBarIcon(action)
        self.actions = []

        # Remove rubber bands and the styleChanged connection explicitly rather
        # than relying on the dock's closingPlugin signal to do it.
        self.reset_all()

        if self.dockwidget is not None:
            self.dockwidget.close()
            self.iface.removeDockWidget(self.dockwidget)
            self.dockwidget.deleteLater()
            self.dockwidget = None
        if self.tool is not None:
            self.iface.mapCanvas().unsetMapTool(self.tool)
            self.tool = None

    def run(self):
        """Show the dock widget and activate the point-sampling map tool."""
        if self.dockwidget is None:
            self.dockwidget = SpectraDockWidget()
            self.dockwidget.closingPlugin.connect(self.on_close_dock)
            self.dockwidget.reset_clicked.connect(self.reset_all)

        self.iface.addDockWidget(Qt.RightDockWidgetArea, self.dockwidget)
        self.dockwidget.show()

        # Build the tool once; re-running the action only re-activates it
        # instead of creating another tool with another click connection.
        if self.tool is None:
            self.tool = QgsMapToolEmitPoint(self.iface.mapCanvas())
            self.tool.canvasClicked.connect(self.handle_click)
        self.iface.mapCanvas().setMapTool(self.tool)

        self.check_layer_connection()

    def on_close_dock(self):
        """Clear all samples and deactivate the map tool when the dock closes."""
        self.reset_all()
        if self.tool is not None:
            self.iface.mapCanvas().unsetMapTool(self.tool)

    def _disconnect_monitored_layer(self):
        """Detach on_style_changed from the monitored layer (best effort) and forget it."""
        if self.monitored_layer is None:
            return
        try:
            self.monitored_layer.styleChanged.disconnect(self.on_style_changed)
        except (TypeError, RuntimeError):
            # TypeError: PyQt reports the slot is not connected.
            # RuntimeError: the wrapped C++ layer object has already been deleted.
            pass
        self.monitored_layer = None

    def check_layer_connection(self):
        """Monitor style changes of the active layer if it is a raster.

        Switches the styleChanged connection from the previously monitored
        layer to the active one. Does nothing when the active layer is not a
        raster or is already monitored.
        """
        layer = self.iface.activeLayer()
        if not layer or layer.type() != layer.RasterLayer:
            return
        if self.monitored_layer == layer:
            return
        self._disconnect_monitored_layer()
        self.monitored_layer = layer
        self.monitored_layer.styleChanged.connect(self.on_style_changed)

    def on_style_changed(self):
        """Push the monitored layer's current band info to the dock's latest markers."""
        if self.monitored_layer and self.dockwidget:
            new_info = self.get_band_info(self.monitored_layer)
            self.dockwidget.update_latest_markers(new_info)

    def get_band_info(self, layer):
        """Describe which bands of *layer* are rendered and which carry colour names.

        :param layer: Raster layer to inspect.
        :returns: ``{'active': [band, ...], 'named': [{'band', 'color', 'label'}, ...]}``
            where ``active`` lists the 1-based bands used by the current renderer
            and ``named`` lists bands whose name contains "(Red)", "(Green)" or
            "(Blue)".
        """
        info = {'active': [], 'named': []}

        # 1. Bands used by the current renderer.
        renderer = layer.renderer()
        if renderer is not None:
            kind = renderer.type()
            if kind == 'multibandcolor':
                bands = (renderer.redBand(), renderer.greenBand(), renderer.blueBand())
            elif kind == 'singlebandgray':
                bands = (renderer.grayBand(),)
            elif kind in ('singlebandpseudocolor', 'paletted'):
                bands = (renderer.band(),)
            else:
                bands = ()
            # Only positive numbers are real band indices; others are skipped.
            info['active'] = [band for band in bands if band > 0]

        # 2. Bands whose metadata name tags a display colour.
        for band in range(1, layer.bandCount() + 1):
            name = layer.bandName(band)
            if not name:
                continue
            for tag, color, label in self._NAMED_BAND_TAGS:
                if tag in name:
                    info['named'].append({'band': band, 'color': color, 'label': label})
                    break

        return info

    @staticmethod
    def _pixel_index(layer, layer_point):
        """Return ``(col, row)`` of the raster cell containing *layer_point*, or None.

        *layer_point* must be expressed in the layer's CRS. Returns None when
        the point lies outside the raster.
        """
        extent = layer.extent()
        x_res = layer.rasterUnitsPerPixelX()
        y_res = layer.rasterUnitsPerPixelY()
        # Floor division rather than int(): int() truncates toward zero, which
        # pulled points up to one pixel left of / above the raster into col/row 0.
        col = int((layer_point.x() - extent.xMinimum()) // x_res)
        row = int((extent.yMaximum() - layer_point.y()) // y_res)
        if 0 <= col < layer.width() and 0 <= row < layer.height():
            return col, row
        return None

    @staticmethod
    def _pixel_rect(layer, col, row):
        """Return the footprint (layer CRS) of raster cell (*col*, *row*) as a QgsRectangle."""
        extent = layer.extent()
        x_res = layer.rasterUnitsPerPixelX()
        y_res = layer.rasterUnitsPerPixelY()
        x_min = extent.xMinimum() + col * x_res
        y_max = extent.yMaximum() - row * y_res
        return QgsRectangle(x_min, y_max - y_res, x_min + x_res, y_max)

    def _draw_markers(self, layer, rect, hex_color):
        """Outline *rect* (layer CRS) and cross its centre on the canvas; track both markers."""
        canvas = self.iface.mapCanvas()
        q_color = QColor(hex_color)

        box = QgsRubberBand(canvas, QgsWkbTypes.PolygonGeometry)
        # Alpha 40 for a translucent body; the outline is set opaque below.
        box.setColor(QColor(q_color.red(), q_color.green(), q_color.blue(), 40))
        box.setStrokeColor(q_color)
        box.setWidth(2)
        box.setToGeometry(QgsGeometry.fromRect(rect), layer.crs())

        cross = QgsRubberBand(canvas, QgsWkbTypes.PointGeometry)
        cross.setIcon(QgsRubberBand.ICON_CROSS)
        cross.setColor(q_color)
        cross.setWidth(2)
        cross.setToGeometry(QgsGeometry.fromPointXY(rect.center()), layer.crs())

        self.active_markers.append({'box': box, 'cross': cross})

    def handle_click(self, point):
        """Sample the active raster at a clicked position and plot its spectrum.

        :param point: Clicked position as emitted by the map tool's
            ``canvasClicked`` signal (map-canvas coordinates).
        """
        layer = self.iface.activeLayer()
        if not layer or layer.type() != layer.RasterLayer:
            self.iface.messageBar().pushMessage("Error", "Please select a Raster Layer", level=1)
            return

        # The click is in canvas CRS, while the raster extent and the provider's
        # identify() work in layer CRS; without this conversion the wrong cell
        # is sampled whenever the two CRSs differ.
        layer_point = self.iface.mapCanvas().mapSettings().mapToLayerCoordinates(layer, point)

        # 1. Geometry: reject clicks outside the raster before touching any
        #    state, so a stray click does not wipe the current plot.
        cell = self._pixel_index(layer, layer_point)
        if cell is None:
            return
        cell_rect = self._pixel_rect(layer, *cell)

        # 2. Single-point mode: drop previous samples first.
        is_multi = self.dockwidget.chk_multi.isChecked() if self.dockwidget else False
        if not is_multi:
            self.reset_all()

        # 3. reset_all() disconnects the monitored layer, so (re)connect afterwards.
        self.check_layer_connection()

        # 4. Markers
        hex_color = self.get_next_color()
        self._draw_markers(layer, cell_rect, hex_color)

        # 5. Data & plot
        results = layer.dataProvider().identify(layer_point, QgsRaster.IdentifyFormatValue)
        if not results.isValid():
            return
        values = results.results()
        y_values = [values.get(band, 0) for band in range(1, layer.bandCount() + 1)]
        x_values = self.get_wavelengths(layer)
        band_markers = self.get_band_info(layer)

        if self.dockwidget:
            label = f"Pt {len(self.active_markers)}"
            self.dockwidget.add_trace(x_values, y_values, hex_color, label, layer.name(), band_markers)

    def get_wavelengths(self, layer):
        """Return one x value per band parsed from band names, or None if none parse.

        A band whose name contains "<number>nm" (optionally "<number> nm")
        contributes that number as a float. Other bands fall back to their
        1-based band index. If no band name yields a wavelength, returns None.
        """
        wavelengths = []
        found_any = False
        for band in range(1, layer.bandCount() + 1):
            match = self._WAVELENGTH_RE.search(layer.bandName(band) or '')
            if match:
                wavelengths.append(float(match.group(1)))
                found_any = True
            else:
                wavelengths.append(band)
        return wavelengths if found_any else None

    def get_next_color(self):
        """Return the next trace colour as '#rrggbb' and advance the colour index.

        Uses the fixed palette first, then random colours whose channels are
        capped at 150 so they stay on the dark side.
        """
        if self.color_idx < len(self.colors):
            color = self.colors[self.color_idx]
        else:
            r = random.randint(0, 150)
            g = random.randint(0, 150)
            b = random.randint(0, 150)
            color = '#{:02x}{:02x}{:02x}'.format(r, g, b)
        self.color_idx += 1
        return color

    def reset_all(self):
        """Remove all markers, stop style monitoring, restart colours and clear the plot."""
        self._disconnect_monitored_layer()

        scene = self.iface.mapCanvas().scene()
        for markers in self.active_markers:
            for item in (markers['box'], markers['cross']):
                try:
                    scene.removeItem(item)
                except RuntimeError:
                    # Best effort: the rubber band's C++ object may already be gone.
                    pass
        self.active_markers = []
        self.color_idx = 0
        if self.dockwidget:
            self.dockwidget.clear_all_traces()