import os
from qgis.PyQt import uic
from qgis.PyQt.QtCore import pyqtSignal, Qt
from qgis.PyQt.QtWidgets import QDockWidget, QVBoxLayout, QWidget, QCheckBox

import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure

# Build the form class from the Qt Designer file 'spectra_dockwidget_base.ui'
# located in the same directory as this module. The result of loadUiType() is
# unpacked as a pair; only the first item (FORM_CLASS, used as a base class of
# SpectraDockWidget) is kept, the second is discarded.
FORM_CLASS, _ = uic.loadUiType(os.path.join(os.path.dirname(__file__), 'spectra_dockwidget_base.ui'))

class SpectraDockWidget(QDockWidget, FORM_CLASS):
    """Dock widget that draws spectral profiles on an embedded matplotlib canvas.

    Each call to :meth:`add_trace` appends one profile ("trace") and redraws
    the whole plot. The X axis shows either band numbers (1..N) or the x data
    supplied with the traces, depending on the "Use Wavelength (nm)" checkbox.

    Signals:
        closingPlugin: emitted when the widget receives a close event.
        reset_clicked: emitted after the reset button cleared all traces.

    Attributes used by the methods below:
        traces: list of dicts with keys 'x', 'y', 'color', 'label', 'markers'.
        plot_lines: list of ``(Line2D, label)`` pairs for the current plot,
            used for hover hit-testing.
        annot: the hover tooltip annotation, or None while no plot is shown.
        legend: the current legend, or None while no plot is shown.
    """

    closingPlugin = pyqtSignal()
    reset_clicked = pyqtSignal()

    _DEFAULT_TITLE = "Spectral Profile"
    _DEFAULT_LEGEND_LOC = 'upper right'

    # Fixed reference lines drawn in wavelength mode: (x position, colour, tag).
    _REFERENCE_WAVELENGTHS = (
        (485.9, 'blue', 'B'),
        (560.0, 'green', 'G'),
        (657.1, 'red', 'R'),
        (866.0, 'brown', 'N'),
    )

    def __init__(self, parent=None):
        """Build the plot canvas, toolbar and option checkboxes.

        Args:
            parent: optional parent widget passed on to the base class.
        """
        super().__init__(parent)
        self.setupUi(self)

        # Layout for Plot
        self.plot_layout = QVBoxLayout(self.layout_frame)
        self.plot_layout.setContentsMargins(0, 0, 0, 0)

        # Matplotlib Setup
        self.figure = Figure(figsize=(5, 4), dpi=100)
        self.canvas = FigureCanvas(self.figure)
        self.plot_layout.addWidget(self.canvas)

        # Add Interactive Toolbar (Zoom, Pan, Save)
        self.toolbar = NavigationToolbar(self.canvas, self.layout_frame)
        self.plot_layout.addWidget(self.toolbar)

        self.ax = self.figure.add_subplot(111)
        # Adjust margins to fit labels
        self.figure.subplots_adjust(top=0.90, bottom=0.15, left=0.12, right=0.95)

        # Connect the Hover Event
        self.canvas.mpl_connect("motion_notify_event", self.on_hover)

        # Controls
        self.chk_wavelength = QCheckBox("Use Wavelength (nm)")
        self.chk_wavelength.setChecked(False)
        # Stays disabled until a trace with x data has been added.
        self.chk_wavelength.setEnabled(False)
        self.chk_wavelength.stateChanged.connect(self.refresh_plot)
        self.layout_options.addWidget(self.chk_wavelength)

        # Not read anywhere in this class.
        self.chk_multi = QCheckBox("Multi-Plot Mode")
        self.chk_multi.setChecked(False)
        self.layout_options.addWidget(self.chk_multi)

        # Data Storage
        self.traces = []
        self.plot_lines = []
        self.chart_title = self._DEFAULT_TITLE
        self.annot = None
        self.legend = None
        # Location of the current legend, tracked here so on_hover does not
        # have to read matplotlib's private Legend._loc attribute.
        self._legend_loc = self._DEFAULT_LEGEND_LOC

        # The reset button is optional: only wire it up if the .ui provides it.
        if hasattr(self, 'btn_reset'):
            self.btn_reset.clicked.connect(self.handle_reset_click)

        self.reset_plot_area()

    def closeEvent(self, event):
        """Emit ``closingPlugin`` and accept the close event."""
        self.closingPlugin.emit()
        event.accept()

    def handle_reset_click(self):
        """Clear every trace, then emit ``reset_clicked``."""
        self.clear_all_traces()
        self.reset_clicked.emit()

    def reset_plot_area(self):
        """Show the empty placeholder ("Click on map to plot") and redraw."""
        self.ax.clear()
        # ax.clear() removed the tooltip and legend artists; drop the stale
        # references so on_hover does not act on artists that no longer exist.
        self.annot = None
        self.legend = None
        self._legend_loc = self._DEFAULT_LEGEND_LOC

        self.ax.text(0.5, 0.5, "Click on map to plot",
                     horizontalalignment='center', verticalalignment='center',
                     transform=self.ax.transAxes, color='gray')
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        self.canvas.draw()

    def add_trace(self, x_data, y_data, color, label, layer_name, markers=None):
        """Append a profile and redraw the plot.

        Args:
            x_data: x values for the profile, or None when unavailable.
            y_data: y values, one per band.
            color: matplotlib colour for the line.
            label: legend / tooltip label of the profile.
            layer_name: becomes the chart title.
            markers: optional dict with 'named' and/or 'active' entries (see
                ``_draw_markers``); None is stored as an empty dict.
        """
        self.traces.append({
            'x': x_data,
            'y': y_data,
            'color': color,
            'label': label,
            'markers': markers or {}
        })
        self.chart_title = layer_name

        if x_data is not None:
            self.chk_wavelength.setEnabled(True)
            if not self.chk_wavelength.isChecked():
                # Flipping the checkbox emits stateChanged, which is connected
                # to refresh_plot, so the plot has already been redrawn.
                self.chk_wavelength.setChecked(True)
                return

        self.refresh_plot()

    def update_latest_markers(self, markers):
        """Replace the markers of the most recent trace and redraw.

        Does nothing when there are no traces. ``None`` is stored as an empty
        dict so that drawing code can always call ``.get`` on it.
        """
        if self.traces:
            self.traces[-1]['markers'] = markers or {}
            self.refresh_plot()

    def clear_all_traces(self):
        """Drop all traces, restore the default title and show the placeholder."""
        self.traces = []
        self.plot_lines = []
        self.chart_title = self._DEFAULT_TITLE
        self.chk_wavelength.setEnabled(False)
        self.reset_plot_area()

    def on_hover(self, event):
        """Handles mouse movement: Tooltip + Legend Position.

        The legend is moved to the upper corner opposite the mouse, and a
        tooltip is shown for the first plotted line under the cursor. Nothing
        happens while no plot is shown or while the mouse is outside the axes.
        """
        if not self.annot or event.inaxes != self.ax:
            return

        was_visible = self.annot.get_visible()
        x_low, x_high = self.ax.get_xlim()
        mid_x = (x_low + x_high) / 2

        # --- LEGEND RE-POSITIONING LOGIC ---
        if self.legend and event.xdata is not None:
            # Mouse on the right half -> legend goes left, and vice versa.
            wanted_loc = 'upper left' if event.xdata > mid_x else 'upper right'
            if wanted_loc != self._legend_loc:
                # Re-creating the legend replaces the previous one on the axes.
                self.legend = self.ax.legend(loc=wanted_loc, fontsize='small')
                self._legend_loc = wanted_loc
                self.canvas.draw_idle()

        # --- TOOLTIP LOGIC ---
        for line, label in self.plot_lines:
            hit, details = line.contains(event)
            if not hit:
                continue

            x_arr, y_arr = line.get_data()
            idx = details['ind'][0]
            x_val = x_arr[idx]
            y_val = y_arr[idx]

            # Anchor the tooltip on the data point.
            self.annot.xy = (x_val, y_val)

            # Open the tooltip towards the centre of the plot.
            if x_val > mid_x:
                self.annot.set_position((-15, 10))
                self.annot.set_horizontalalignment('right')
            else:
                self.annot.set_position((15, 10))
                self.annot.set_horizontalalignment('left')

            band_num = int(idx + 1)
            text = f"{label}\nBand: {band_num}\nX: {x_val:.1f}\nY: {y_val:.3f}"
            self.annot.set_text(text)

            self.annot.get_bbox_patch().set_alpha(0.9)
            self.annot.set_visible(True)
            self.canvas.draw_idle()
            return

        # No line under the cursor: hide a tooltip that is still showing.
        if was_visible:
            self.annot.set_visible(False)
            self.canvas.draw_idle()

    def refresh_plot(self):
        """Redraw every trace, the marker lines, the legend and the tooltip.

        Falls back to the placeholder when there are no traces. Wavelength
        mode is used only when the checkbox is checked and every stored trace
        has x data; otherwise all traces are plotted against band numbers.
        """
        self.plot_lines = []

        if not self.traces:
            self.reset_plot_area()
            return

        self.ax.clear()

        latest_trace = self.traces[-1]
        # Requiring x data on every trace avoids plotting a trace whose 'x'
        # is None when an earlier trace was added without wavelengths.
        has_wavelengths = all(t['x'] is not None for t in self.traces)
        use_wavelengths = self.chk_wavelength.isChecked() and has_wavelengths

        x_label = "Wavelength (nm)" if use_wavelengths else "Band Number"

        # 1. Plot Traces
        all_x = []
        for trace in self.traces:
            y_vals = trace['y']
            if use_wavelengths:
                x_vals = trace['x']
            else:
                x_vals = list(range(1, len(y_vals) + 1))
            all_x.extend(x_vals)

            line, = self.ax.plot(x_vals, y_vals, color=trace['color'], marker='o',
                                 markersize=4, label=trace['label'], linewidth=1.5)
            line.set_picker(5)
            self.plot_lines.append((line, trace['label']))

        # Tooltip annotation, hidden until on_hover finds a point.
        self.annot = self.ax.annotate("", xy=(0, 0), xytext=(10, 10),
                                      textcoords="offset points",
                                      bbox=dict(boxstyle="round", fc="white",
                                                ec="black", alpha=0.9),
                                      arrowprops=dict(arrowstyle="->"))
        self.annot.set_visible(False)

        # Horizontal offset for text labels: 1.5% of the plotted x range.
        if all_x:
            x_min, x_max = min(all_x), max(all_x)
            x_range = x_max - x_min if x_max != x_min else 1.0
            text_offset = x_range * 0.015
        else:
            text_offset = 0

        # 2. Specific Reference Wavelengths (B, G, R, N)
        if use_wavelengths:
            self._draw_reference_lines(text_offset)

        # 3. Dynamic Indicators
        markers = latest_trace.get('markers') or {}
        self._draw_markers(markers, latest_trace, use_wavelengths, text_offset)

        # 4. Styling & LEGEND CREATION
        self.ax.set_title(f"{self.chart_title}")
        self.ax.set_xlabel(x_label)
        self.ax.set_ylabel("Pixel Value")
        self.ax.grid(True, alpha=0.3)

        # Every redraw starts with the legend at the default corner.
        self.legend = self.ax.legend(loc=self._DEFAULT_LEGEND_LOC, fontsize='small')
        self._legend_loc = self._DEFAULT_LEGEND_LOC

        # NOTE(review): the Y axis is pinned to [0, 1], so values outside that
        # range are not visible - presumably the data are normalised; confirm.
        self.ax.set_ylim(0, 1)

        if all_x:
            pad = (x_max - x_min) * 0.05 if x_max != x_min else 1.0
            self.ax.set_xlim(x_min - pad, x_max + pad)

        self.canvas.draw()

    def _draw_reference_lines(self, text_offset):
        """Draw the fixed dashed reference lines with a tag right of each line."""
        label_transform = self.ax.get_xaxis_transform()
        for wav, col, txt in self._REFERENCE_WAVELENGTHS:
            self.ax.axvline(x=wav, color=col, linestyle='--', linewidth=0.8, alpha=0.5)
            # Shift Text RIGHT
            self.ax.text(wav + text_offset, 0.01, txt, color=col, fontsize=7,
                         horizontalalignment='left', verticalalignment='bottom',
                         transform=label_transform)

    def _draw_markers(self, markers, trace, use_wavelengths, text_offset):
        """Draw the 'named' and 'active' band indicators of ``markers``.

        'named' entries are dicts with 'band', 'color' and 'label' and are
        drawn as faint solid lines labelled above the axes. 'active' entries
        are band numbers drawn as black dotted lines, numbered 1, 2, 3, ... in
        list order.
        """
        label_transform = self.ax.get_xaxis_transform()

        # A. Named Bands -> Solid Lines
        for named in markers.get('named', []):
            x_pos = self._band_to_x(named['band'], trace, use_wavelengths)
            if x_pos is None:
                continue
            self.ax.axvline(x=x_pos, color=named['color'], linestyle='-',
                            linewidth=1, alpha=0.3)
            self.ax.text(x_pos, 1.02, named['label'], color=named['color'],
                         fontsize=8, rotation=90, verticalalignment='bottom',
                         transform=label_transform)

        # B. Active Visualization Bands (1, 2, 3)
        for position, band_idx in enumerate(markers.get('active', []), start=1):
            x_pos = self._band_to_x(band_idx, trace, use_wavelengths)
            if x_pos is None:
                continue
            # Black color, Dotted style, Thin but visible
            self.ax.axvline(x=x_pos, color='black', linestyle=':',
                            linewidth=0.8, alpha=1.0)
            # Shift Text LEFT
            self.ax.text(x_pos - text_offset, 0.96, str(position), color='black',
                         fontweight='bold', fontsize=9,
                         horizontalalignment='right', verticalalignment='top',
                         transform=label_transform)

    @staticmethod
    def _band_to_x(band_idx, trace, use_wavelengths):
        """Map a 1-based band number to its X coordinate on the current axis.

        In band-number mode the band number itself is returned. In wavelength
        mode the value is looked up in ``trace['x']``; None is returned when
        the band lies outside that sequence.
        """
        if not use_wavelengths:
            return band_idx
        x_values = trace['x']
        idx0 = band_idx - 1
        if 0 <= idx0 < len(x_values):
            return x_values[idx0]
        return None