import bisect
import itertools

import matplotlib.pyplot as plt
import pandas as pd
from qgis.core import QgsCoordinateTransform, QgsProject, QgsRaster
from qgis.PyQt.QtCore import QDir, QObject, pyqtSignal
from qgis.PyQt.QtWidgets import QAction, QDialog, QMessageBox, QFileDialog
from qgis.utils import iface

from .equivalentslopetool_dialog_base import Ui_EquivalentSlopeToolDialogBase


class SignalEmitter(QObject):
    """Qt signal carrier for export notifications.

    ``pyqtSignal`` attributes only work on ``QObject`` subclasses, and the
    plugin class itself is a plain Python object, so the signals live here.
    """

    # Carries the saved file path after a successful export.
    exportSuccess = pyqtSignal(str)
    # Carries an error message when an export fails.
    exportError = pyqtSignal(str)


class EquivalentSlopeTool:
    """QGIS plugin that computes the equivalent slope of a flow path over a DEM.

    The first feature of the chosen vector layer is sampled against band 1 of
    the chosen raster layer. The dense profile is reduced to ``N`` segments
    whose break points concentrate on steep stretches, tabulated, and
    summarised as::

        S = sum(Li * (Di-1 + Di)) / L**2

    where ``Li`` and the total length ``L`` are in km and ``Di`` (elevation of
    point i relative to the first point) is in m, so ``S`` is in m/km.

    NOTE(review): chainages come from ``QgsGeometry.length()`` in the line
    layer's CRS and are treated as metres -- confirm the flow path layer uses
    a projected, metre-based CRS.
    """

    # Spacing (line-CRS units, assumed metres) of the dense DEM sampling that
    # segment break points are chosen from.
    SAMPLE_INTERVAL = 5.0
    # Segment count used when the dialog has no ``spinBox`` widget.
    DEFAULT_SEGMENTS = 15
    # Exponent applied to each interval's absolute slope when weighting;
    # values > 1 pull more break points into steep stretches.
    SLOPE_WEIGHT_EXPONENT = 1.0
    # Tiny weight added to every interval so flat stretches still receive
    # break points (an all-flat profile is therefore split uniformly).
    WEIGHT_EPSILON = 1e-9
    # Upper bound on the number of DEM samples drawn for the exported plot.
    MAX_PLOT_SAMPLES = 5000
    # Gap below which the last regular sample is treated as the line end.
    _LENGTH_TOLERANCE = 1e-6

    def __init__(self, iface):
        """Store the QGIS interface and route export notifications to message boxes.

        :param iface: the interface object QGIS passes to the plugin factory.
        """
        self.iface = iface
        self.signals = SignalEmitter()
        self.signals.exportSuccess.connect(
            lambda path: QMessageBox.information(None, "Export", f"Excel saved to {path}")
        )
        self.signals.exportError.connect(
            lambda msg: QMessageBox.critical(None, "Export Error", msg)
        )
        # Results of the last successful calculation.
        self.data = None
        self.equivalent_slope = None
        # IDs of the layers that produced ``self.data``; exportPlot re-samples
        # these instead of whatever the combo boxes show at export time.
        self._dem_layer_id = None
        self._line_layer_id = None

    def initGui(self):
        """Create the menu action that opens the dialog."""
        self.action = QAction("Equivalent Slope Tool", self.iface.mainWindow())
        self.action.triggered.connect(self.run)
        self.iface.addPluginToMenu("&Equivalent Slope Tool", self.action)

    def unload(self):
        """Remove the menu action added by :meth:`initGui`."""
        self.iface.removePluginMenu("&Equivalent Slope Tool", self.action)

    def run(self):
        """Build the dialog, list the project's layers and wire up the buttons."""
        self.dialog = QDialog()
        self.ui = Ui_EquivalentSlopeToolDialogBase()
        self.ui.setupUi(self.dialog)

        # Populate layer combo boxes. The layer ID is stored as item data so
        # layers sharing a display name can still be told apart.
        self.ui.comboBox.clear()
        self.ui.comboBox_2.clear()
        for layer in QgsProject.instance().mapLayers().values():
            if layer.type() == layer.RasterLayer:
                self.ui.comboBox.addItem(layer.name(), layer.id())
            elif layer.type() == layer.VectorLayer:
                self.ui.comboBox_2.addItem(layer.name(), layer.id())

        # runButton = calculate; pushButton_2 = export excel; pushButton_3 = export plot
        self.ui.runButton.clicked.connect(self.calculateSlope)
        self.ui.pushButton_2.clicked.connect(self.exportExcel)
        self.ui.pushButton_3.clicked.connect(self.exportPlot)

        # Exports stay disabled until a calculation has succeeded.
        self.ui.pushButton_2.setEnabled(False)
        self.ui.pushButton_3.setEnabled(False)

        self.dialog.show()

    def calculateSlope(self):
        """Sample the selected DEM along the flow path and compute the equivalent slope.

        On success, stores the segment table in ``self.data``, the result in
        ``self.equivalent_slope`` and the layer IDs used, shows the result and
        enables the export buttons. Any failure is reported in a message box
        and leaves the previous results untouched.
        """
        try:
            dem_layer = self._layer_from_combo(self.ui.comboBox)
            line_layer = self._layer_from_combo(self.ui.comboBox_2)
            if dem_layer is None or line_layer is None:
                QMessageBox.warning(None, "Input error", "Please select DEM and Flow Path layers.")
                return

            num_segments = (
                int(self.ui.spinBox.value()) if hasattr(self.ui, "spinBox") else self.DEFAULT_SEGMENTS
            )
            if num_segments < 1:
                QMessageBox.warning(None, "Input error", "Number of segments must be at least 1.")
                return

            geom = self._flow_path_geometry(line_layer)

            # Step 1: dense sampling, always including the true end of the line.
            distances = self._sample_distances(geom.length(), self.SAMPLE_INTERVAL)
            if len(distances) < 2:
                QMessageBox.critical(None, "Error", "Not enough sample points from DEM/line.")
                return
            transform = self._dem_transform(line_layer, dem_layer)
            elevations = self._backfill_leading(
                self._sample_elevations(dem_layer, geom, distances, transform)
            )

            # Step 2: choose break points so each segment carries ~equal slope weight.
            weights = self._interval_weights(distances, elevations)
            indices = self._select_breakpoints(weights, num_segments)

            # Step 3: tabulate the segments and apply the equivalent-slope formula.
            data = self._build_table(
                [distances[i] for i in indices],
                [elevations[i] for i in indices],
            )
            _, equivalent_slope = self._equivalent_slope(data)

            # Commit results only once everything above has succeeded.
            self.data = data
            self.equivalent_slope = equivalent_slope
            self._dem_layer_id = dem_layer.id()
            self._line_layer_id = line_layer.id()

            message = f"Equivalent Slope = {self.equivalent_slope}"
            if hasattr(self.ui, "textBrowser"):
                self.ui.textBrowser.setText(message)
            else:
                QMessageBox.information(None, "Result", message)

            for button_name in ("pushButton_2", "pushButton_3"):
                button = getattr(self.ui, button_name, None)
                if button is not None:
                    button.setEnabled(True)

        except Exception as ex:  # UI boundary: report instead of crashing the slot
            QMessageBox.critical(None, "Error", str(ex))

    def exportExcel(self):
        """Ask for a path and write the segment table plus summary rows to ``.xlsx``.

        Success and failure are reported through ``self.signals``.
        """
        try:
            if not self.data:
                QMessageBox.warning(None, "No data", "Run calculation before exporting.")
                return

            path, _ = QFileDialog.getSaveFileName(
                None,
                "Save Excel File",
                QDir.homePath() + "/equivalent_slope.xlsx",
                "Excel Files (*.xlsx)"
            )
            if not path:
                return
            if not path.lower().endswith(".xlsx"):
                path += ".xlsx"

            df = pd.DataFrame(self.data)
            df.fillna("", inplace=True)

            # Summary rows: only the last column is filled in.
            total, equivalent_slope = self._equivalent_slope(self.data)
            blank = {column: '' for column in df.columns}
            summary = pd.DataFrame([
                {**blank, 'Li*(Di-1+Di)': round(total, 5)},
                {**blank, 'Li*(Di-1+Di)': f'Equivalent slope = {equivalent_slope}'},
            ])
            df_out = pd.concat([df, summary], ignore_index=True)
            df_out.to_excel(path, index=False)

            self.signals.exportSuccess.emit(path)

        except Exception as ex:
            self.signals.exportError.emit(str(ex))

    def exportPlot(self):
        """Ask for a path and save a PNG comparing the DEM profile with the segments.

        The DEM profile is re-sampled from the layers used by the last
        calculation, so the chart always matches ``self.data``.
        """
        try:
            if not self.data:
                QMessageBox.warning(None, "No data", "Run calculation before exporting plot.")
                return

            path, _ = QFileDialog.getSaveFileName(
                None,
                "Save Plot Image",
                QDir.homePath() + "/terrain_profile.png",
                "PNG Image (*.png)"
            )
            if not path:
                return
            if not path.lower().endswith(".png"):
                path += ".png"

            dem_layer, line_layer = self._last_run_layers()
            geom = self._flow_path_geometry(line_layer)
            transform = self._dem_transform(line_layer, dem_layer)
            length = geom.length()

            # The pixel size is in DEM CRS units, so it is only a sensible step
            # along the line when both layers share a CRS. The step is also
            # capped so a long line cannot produce an unbounded sample count.
            pixel_size = dem_layer.rasterUnitsPerPixelX()
            interval = pixel_size if transform is None and pixel_size > 0 else self.SAMPLE_INTERVAL
            interval = max(interval, length / self.MAX_PLOT_SAMPLES)

            distances_dem = self._sample_distances(length, interval)
            elevations_dem = self._sample_elevations(dem_layer, geom, distances_dem, transform)
            distances_seg = [row['RD (km)'] * 1000 for row in self.data]
            elevations_seg = [row['RL (m)'] for row in self.data]

            self._save_profile_plot(path, distances_dem, elevations_dem, distances_seg, elevations_seg)
            QMessageBox.information(None, "Export", f"Plot saved to {path}")

        except Exception as ex:
            QMessageBox.critical(None, "Plot Error", str(ex))

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _layer_from_combo(combo):
        """Return the project layer selected in *combo*, or None if there is none."""
        project = QgsProject.instance()
        layer_id = combo.currentData()
        layer = project.mapLayer(layer_id) if layer_id else None
        if layer is None and combo.currentText():
            # Fallback for items that were added without a layer ID.
            matches = project.mapLayersByName(combo.currentText())
            layer = matches[0] if matches else None
        return layer

    def _last_run_layers(self):
        """Return the (DEM, flow path) layers of the last successful calculation."""
        project = QgsProject.instance()
        dem_layer = project.mapLayer(self._dem_layer_id) if self._dem_layer_id else None
        line_layer = project.mapLayer(self._line_layer_id) if self._line_layer_id else None
        if dem_layer is None or line_layer is None:
            raise ValueError("The layers used for the last calculation are no longer in the project.")
        return dem_layer, line_layer

    @staticmethod
    def _flow_path_geometry(line_layer):
        """Return the geometry of the first feature of *line_layer*.

        :raises ValueError: if the layer has no features or the geometry is empty.
        """
        feature = next(line_layer.getFeatures(), None)
        if feature is None:
            raise ValueError("The flow path layer has no features.")
        geom = feature.geometry()
        if geom is None or geom.isEmpty():
            raise ValueError("The first flow path feature has no geometry.")
        return geom

    @staticmethod
    def _dem_transform(line_layer, dem_layer):
        """Return a line-CRS -> DEM-CRS transform, or None when the CRSs match."""
        if line_layer.crs() == dem_layer.crs():
            return None
        return QgsCoordinateTransform(line_layer.crs(), dem_layer.crs(), QgsProject.instance())

    @classmethod
    def _sample_distances(cls, length, interval):
        """Return chainages ``0, interval, 2*interval, ...`` that always end at *length*.

        :raises ValueError: if *interval* is not positive.
        """
        if interval <= 0:
            raise ValueError("Sampling interval must be positive.")
        count = int(length // interval)
        # min() guards against float rounding pushing a sample past the line end.
        distances = [min(k * interval, length) for k in range(count + 1)]
        if length - distances[-1] > cls._LENGTH_TOLERANCE:
            distances.append(length)
        return distances

    @staticmethod
    def _sample_elevations(dem_layer, geom, distances, transform=None):
        """Read DEM band 1 at each chainage along *geom*.

        NoData readings take the last valid value, so only leading samples can
        remain ``None``.
        """
        provider = dem_layer.dataProvider()
        elevations = []
        last_valid = None
        for distance in distances:
            point = geom.interpolate(distance).asPoint()
            if transform is not None:
                point = transform.transform(point)
            ident = provider.identify(point, QgsRaster.IdentifyFormatValue)
            value = ident.results().get(1) if ident.isValid() else None
            if value is None:
                value = last_valid
            else:
                last_valid = value
            elevations.append(value)
        return elevations

    @staticmethod
    def _backfill_leading(elevations):
        """Replace ``None`` samples (leading NoData) with the first valid elevation.

        :raises ValueError: if no sample has a valid elevation.
        """
        first_valid = next((z for z in elevations if z is not None), None)
        if first_valid is None:
            raise ValueError("No valid DEM elevations were found along the flow path.")
        return [first_valid if z is None else z for z in elevations]

    @classmethod
    def _interval_weights(cls, distances, elevations):
        """Return one positive weight per sample interval, ``|slope| ** p + epsilon``."""
        points = list(zip(distances, elevations))
        weights = []
        for (d0, z0), (d1, z1) in zip(points, points[1:]):
            run = d1 - d0
            if run and z0 is not None and z1 is not None:
                slope = abs((z1 - z0) / run)
            else:
                slope = 0.0
            weights.append(slope ** cls.SLOPE_WEIGHT_EXPONENT + cls.WEIGHT_EPSILON)
        return weights

    @staticmethod
    def _select_breakpoints(weights, num_segments):
        """Return sorted sample indices splitting the profile into equal-weight segments.

        Index 0 and the last sample are always included. The result has
        ``min(num_segments + 1, len(weights) + 1)`` entries; when several
        weight targets fall in the same interval, the widest index gaps are
        bisected to make up the count.
        """
        last = len(weights)  # index of the final sample
        cumulative = [0.0, *itertools.accumulate(weights)]
        total = cumulative[-1]

        chosen = {0, last}
        for k in range(1, num_segments):
            # First sample whose cumulative weight reaches the k-th target.
            idx = bisect.bisect_left(cumulative, total * k / num_segments)
            chosen.add(min(max(idx, 1), last))
        chosen = sorted(chosen)

        wanted = min(num_segments + 1, last + 1)
        while len(chosen) < wanted:
            gap, pos = max((b - a, j) for j, (a, b) in enumerate(zip(chosen, chosen[1:])))
            if gap <= 1:
                break
            chosen.insert(pos + 1, chosen[pos] + gap // 2)
        return chosen

    @staticmethod
    def _build_table(distances, elevations):
        """Build the per-point table used for the equivalent-slope computation.

        *distances* are chainages in metres (assumed) and *elevations* must not
        contain ``None``. ``Di`` is measured from the first point's rounded RL,
        and ``Di-1`` reuses the previous row's ``Di`` so the columns stay
        consistent.
        """
        start_rl = round(elevations[0], 2)
        rows = []
        for i, (distance, elevation) in enumerate(zip(distances, elevations)):
            rl = round(elevation, 2)
            di = round(rl - start_rl, 2)
            di_prev = rows[-1]['Di (m)'] if rows else 0
            li = round((distance - distances[i - 1]) / 1000, 3) if i > 0 else 0
            di_sum = round(di_prev + di, 2)
            rows.append({
                'Serial No': i + 1,
                'RD (km)': round(distance / 1000, 3),
                'RL (m)': rl,
                'Li (km)': li,
                'Di (m)': di,
                'Di-1+Di (m)': di_sum,
                'Li*(Di-1+Di)': round(li * di_sum, 5),
            })
        return rows

    @staticmethod
    def _equivalent_slope(rows):
        """Return ``(sum of Li*(Di-1+Di), equivalent slope)`` for table *rows*.

        The slope is ``sum / RD_last**2`` rounded to 5 places, or 0.0 when the
        last chainage is zero.
        """
        total = sum(
            row['Li*(Di-1+Di)'] for row in rows if row['Li*(Di-1+Di)'] is not None
        )
        length_km = rows[-1]['RD (km)'] if rows else 0.0
        slope = round(total / (length_km ** 2), 5) if length_km != 0 else 0.0
        return total, slope

    @staticmethod
    def _save_profile_plot(path, distances_dem, elevations_dem, distances_seg, elevations_seg):
        """Draw the DEM and segmented profiles on one chart and save it to *path*.

        The figure is always closed afterwards so repeated exports do not
        accumulate open matplotlib figures.
        """
        def as_float(values):
            # matplotlib breaks a line at NaN, so NoData shows up as a gap.
            return [v if pd.notnull(v) else float("nan") for v in values]

        dem_y = as_float(elevations_dem)
        seg_y = as_float(elevations_seg)

        valid = [v for v in elevations_dem + elevations_seg if pd.notnull(v)]
        if valid:
            y_low = min(0.0, min(valid))  # keep sub-zero profiles visible
            peak = max(valid)
            y_high = peak + 0.1 * abs(peak)  # 10% headroom, also for negative peaks
            if y_high <= y_low:
                y_high = y_low + 1.0
        else:
            y_low, y_high = 0.0, 100.0

        fig, ax = plt.subplots(figsize=(14, 6))
        try:
            ax.set_title('Combined Terrain Profile: DEM vs Segmented Slope')
            ax.set_xlabel('Distance (m)')
            ax.set_ylabel('Elevation (m)')
            ax.set_ylim(y_low, y_high)

            # One polyline per series so each keeps a single colour.
            ax.plot(distances_dem, dem_y, color='black', linewidth=1.5, label='DEM Profile')
            ax.plot(distances_seg, seg_y, color='green', linewidth=2, label='Segmented Profile')
            ax.scatter(distances_seg, seg_y, color='red', s=30, label='Segment Points', zorder=3)

            # Dashed drop lines from the axis floor to each segment point.
            drops = [(x, y) for x, y in zip(distances_seg, elevations_seg) if pd.notnull(y)]
            if drops:
                xs, ys = zip(*drops)
                ax.vlines(xs, y_low, ys, colors='gray', linestyles='--', linewidth=1)

            ax.grid(True)
            ax.legend()
            fig.tight_layout()
            fig.savefig(path)
        finally:
            plt.close(fig)
