# -*- coding: utf-8 -*-
"""
/***************************************************************************
 RasterStretch
                                 A QGIS plugin
 This plugin visualizes raster layers using a 2–98 percentile (or custom) stretch.
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2025-09-13
        git sha              : $Format:%H$
        copyright            : (C) 2025 by Mikhail Sokolov
        email                : sokolov.usmk@gmail.com
 ***************************************************************************/

/***************************************************************************
 *   This program is free software; you can redistribute it and/or modify   *
 *   it under the terms of the GNU General Public License as published by   *
 *   the Free Software Foundation; either version 2 of the License, or      *
 *   (at your option) any later version.                                    *
 ***************************************************************************/
"""
import math
import os

from qgis.PyQt.QtCore import QSettings, QTranslator, QCoreApplication, Qt, QLocale
from qgis.PyQt.QtGui import QIcon
from qgis.PyQt.QtWidgets import QAction

from qgis.core import (
    Qgis,
    QgsProject,
    QgsRasterLayer,
    QgsRectangle,
    QgsRasterBandStats,
    QgsRasterHistogram,
    QgsContrastEnhancement,
    QgsSingleBandGrayRenderer,
    QgsMultiBandColorRenderer,
    QgsRasterRenderer,
    QgsSingleBandPseudoColorRenderer,
    QgsRasterRange,
)

# Resources & dock widget
try:
    from .resources import *
    from .raster_stretch_dockwidget import RasterStretchDockWidget
except ImportError:  # allow running outside the package (dev convenience)
    import sys
    sys.path.append(os.path.dirname(__file__))
    from resources import *
    from raster_stretch_dockwidget import RasterStretchDockWidget


class RasterStretch:
    """QGIS plugin that stretches raster layers to a percentile range.

    The dock widget lets the user pick a raster layer, lower/upper percentiles
    and an optional temporary NoData value. "Apply" derives the percentile
    bounds from a sampled histogram and installs contrast enhancements on the
    layer's renderer.
    """

    # Number of pixels QGIS is asked to sample for statistics and histograms.
    _SAMPLE_SIZE = 200_000
    # Histogram resolution used when the Apply button is pressed.
    _APPLY_BINS = 512

    def __init__(self, iface):
        """
        :param iface: QGIS interface (QgsInterface)
        """
        self.iface = iface
        self.plugin_dir = os.path.dirname(__file__)

        # i18n (robust to missing settings in tests/clean profiles)
        val = QSettings().value("locale/userLocale", QLocale().name())
        # Ensure it's a str and safely slice to 2 letters
        locale = (str(val) if val is not None else "en")[:2]
        locale_path = os.path.join(self.plugin_dir, "i18n", f"RasterStretch_{locale}.qm")
        if os.path.exists(locale_path):
            self.translator = QTranslator()
            self.translator.load(locale_path)
            QCoreApplication.installTranslator(self.translator)

        # UI state
        self.actions = []
        self.menu = self.tr("&RasterStretch")
        self.toolbar = self.iface.addToolBar("RasterStretch")
        self.toolbar.setObjectName("RasterStretch")
        self.pluginIsActive = False
        self.dockwidget = None
        # Tracks whether QgsProject signals point at this instance (see unload()).
        self._project_signals_connected = False

    # ------------------------------------------------------------------ Utils

    @staticmethod
    def tr(message: str) -> str:
        """Translate *message* in the "RasterStretch" context."""
        return QCoreApplication.translate("RasterStretch", message)

    def add_action(
        self,
        icon_path,
        text,
        callback,
        enabled_flag=True,
        add_to_menu=True,
        add_to_toolbar=True,
        status_tip=None,
        whats_this=None,
        parent=None,
    ) -> "QAction":
        """Create a QAction, wire it to *callback* and register it in the menu/toolbar.

        :returns: the created action (also stored in ``self.actions``)
        """
        icon = QIcon(icon_path)
        action = QAction(icon, text, parent)
        action.triggered.connect(callback)
        action.setEnabled(enabled_flag)
        if status_tip:
            action.setStatusTip(status_tip)
        if whats_this:
            action.setWhatsThis(whats_this)
        if add_to_toolbar:
            self.toolbar.addAction(action)
        if add_to_menu:
            self.iface.addPluginToMenu(self.menu, action)
        self.actions.append(action)
        return action

    def _warn(self, message: str) -> None:
        """Show a short warning in the QGIS message bar."""
        self.iface.messageBar().pushMessage(
            "RasterStretch", message, level=Qgis.Warning, duration=4
        )

    # ----------------------------------------------------------------- QGIS UI

    def initGui(self):
        """Register the plugin action (menu entry and toolbar button)."""
        icon_path = ":/plugins/raster_stretch/icon.png"
        self.add_action(icon_path, text=self.tr("Raster Stretch"), callback=self.run, parent=self.iface.mainWindow())

    def onClosePlugin(self):
        """Handle the dock's closingPlugin signal: mark the plugin inactive."""
        self.dockwidget.closingPlugin.disconnect(self.onClosePlugin)
        self.pluginIsActive = False

    def unload(self):
        """Remove everything the plugin added to the QGIS GUI and project signals."""
        for action in self.actions:
            self.iface.removePluginMenu(self.menu, action)
            self.iface.removeToolBarIcon(action)
        self.actions = []

        # Without this, the QgsProject singleton keeps calling into a dead plugin
        # instance after unload/reload.
        self._disconnect_project_signals()

        if self.dockwidget is not None:
            self.iface.removeDockWidget(self.dockwidget)
            self.dockwidget.deleteLater()
            self.dockwidget = None
        self.pluginIsActive = False

        # The toolbar was created by iface.addToolBar() and lives in the main
        # window; dropping the Python reference alone would leave it behind.
        self.iface.mainWindow().removeToolBar(self.toolbar)
        self.toolbar.deleteLater()
        del self.toolbar

    # ----------------------------------------------------------- Core routine

    def _ui_nodata_value(self):
        """Read the dock's temporary NoData setting.

        :returns: ``(active, value)``. ``active`` is True when the NoData
            checkbox is ticked and its line edit is enabled. ``value`` is the
            parsed float, or None when the text is not a number; in that case
            the caller clears the band's user NoData ranges for the computation.
        """
        dw = self.dockwidget
        if not (dw.checkBoxNoData.isChecked() and dw.lineEditNoData.isEnabled()):
            return False, None
        try:
            return True, float(dw.lineEditNoData.text())
        except ValueError:
            return True, None

    @staticmethod
    def _reload_provider(provider) -> None:
        """Ask *provider* to refresh its data after a NoData change, if it can.

        NOTE(review): the original code called ``provider.reload()``. QGIS data
        providers are believed to expose ``reloadData()`` instead. Both are tried,
        so neither spelling raises AttributeError; confirm against the target QGIS.
        """
        for name in ("reloadData", "reload"):
            method = getattr(provider, name, None)
            if callable(method):
                method()
                return

    @staticmethod
    def _percentile_bounds(counts, minval, maxval, lower_pct, upper_pct):
        """Convert histogram bin counts over ``[minval, maxval]`` into percentile bounds.

        The lower bound is the lower edge of the first bin whose cumulative count
        reaches *lower_pct* percent of the total. The upper bound is the upper
        edge of the bin where the count accumulated from the top reaches
        ``100 - upper_pct`` percent.

        :param counts: per-bin pixel counts, equal-width bins
        :param minval: value at the lower edge of the first bin
        :param maxval: value at the upper edge of the last bin
        :param lower_pct: lower percentile in percent
        :param upper_pct: upper percentile in percent
        :returns: (lower_value, upper_value), or (None, None) when the counts
            are empty/zero or the bounds do not form a finite, increasing range
        """
        total = sum(counts)
        if total <= 0:
            return None, None

        n_bins = len(counts)
        lower_threshold = total * (lower_pct / 100.0)
        upper_tail = total - total * (upper_pct / 100.0)

        low_bin = 0
        acc = 0
        for i, count in enumerate(counts):
            acc += count
            if acc >= lower_threshold:
                low_bin = i
                break

        high_bin = n_bins - 1
        acc = 0
        for i in reversed(range(n_bins)):
            acc += counts[i]
            if acc >= upper_tail:
                high_bin = i
                break

        bin_width = (maxval - minval) / float(n_bins)
        lower_value = minval + bin_width * low_bin
        upper_value = minval + bin_width * (high_bin + 1)

        if not (math.isfinite(lower_value) and math.isfinite(upper_value)):
            return None, None
        if lower_value >= upper_value:
            return None, None
        return lower_value, upper_value

    def _compute_percentiles(self, provider, band, lower_pct, upper_pct, extent, bins):
        """Sample min/max and a histogram for *band*, then derive percentile bounds.

        :returns: (lower_value, upper_value) or (None, None) on failure
        """
        stats = provider.bandStatistics(
            band,
            QgsRasterBandStats.Min | QgsRasterBandStats.Max,
            extent,
            sampleSize=self._SAMPLE_SIZE,
        )
        if not stats:
            return None, None

        minval, maxval = stats.minimumValue, stats.maximumValue
        if minval is None or maxval is None:
            return None, None
        # Reject NaN/inf as well as an empty or inverted range (e.g. when the
        # sample contained no valid pixels).
        if not (math.isfinite(minval) and math.isfinite(maxval)) or minval >= maxval:
            return None, None

        h = provider.histogram(
            band, bins, minval, maxval, extent,
            sampleSize=self._SAMPLE_SIZE, includeOutOfRange=False
        )
        if not isinstance(h, QgsRasterHistogram) or not getattr(h, "histogramVector", None):
            return None, None

        counts = list(h.histogramVector)
        # The bin width below assumes exactly `bins` equal bins over [minval, maxval].
        if len(counts) != bins:
            return None, None

        return self._percentile_bounds(counts, minval, maxval, lower_pct, upper_pct)

    def _percentile_from_hist(self, provider, band: int, lower_pct: float, upper_pct: float,
                              extent=None, bins: int = 256):
        """
        Calculate lower and upper percentile values from a sampled histogram.

        If the dock's NoData option is active, a temporary user NoData range is
        set on *band* for the computation. The provider's original user NoData
        ranges are restored afterwards, including when a QGIS call raises.

        :param provider: raster data provider of the layer
        :param band: 1-based band number
        :param lower_pct: lower percentile in percent
        :param upper_pct: upper percentile in percent
        :param extent: extent to sample; None means an empty QgsRectangle
        :param bins: number of histogram bins
        :returns: (lower_value, upper_value) or (None, None) on failure
        """
        if extent is None:
            extent = QgsRectangle()

        # Preserve existing user NoData ranges
        original_user_nd = provider.userNoDataValues(band)
        override_active, nodata_val = self._ui_nodata_value()
        if override_active:
            # An unparsable value clears user NoData for the computation.
            ranges = [QgsRasterRange(nodata_val, nodata_val)] if nodata_val is not None else []
            provider.setUserNoDataValue(band, ranges)
            self._reload_provider(provider)

        try:
            return self._compute_percentiles(provider, band, lower_pct, upper_pct, extent, bins)
        finally:
            if override_active:
                provider.setUserNoDataValue(band, original_user_nd)

    # --------------------------------------------------------------- Commands

    def run(self):
        """Instantiate and show the dock widget."""
        if self.pluginIsActive:
            return

        self.pluginIsActive = True
        if self.dockwidget is None:
            self.dockwidget = RasterStretchDockWidget()
            self._connect_project_signals()
            self._refresh_raster_combo()
            self.dockwidget.comboBox.currentIndexChanged.connect(self._on_raster_selected)
            self.dockwidget.Apply.clicked.connect(self._on_apply_clicked)
            self.dockwidget.checkBoxNoData.stateChanged.connect(self._on_nodata_checked)

        self.dockwidget.closingPlugin.connect(self.onClosePlugin)
        self.iface.addDockWidget(Qt.RightDockWidgetArea, self.dockwidget)
        self.dockwidget.show()

    def _on_nodata_checked(self, state):
        """Enable the NoData line edit only while the checkbox is checked."""
        self.dockwidget.lineEditNoData.setEnabled(state == Qt.Checked)

    @staticmethod
    def _make_contrast_enhancement(provider, band, min_val, max_val):
        """Build a min/max stretch contrast enhancement for *band*."""
        ce = QgsContrastEnhancement(provider.dataType(band))
        ce.setMinimumValue(min_val)
        ce.setMaximumValue(max_val)
        ce.setContrastEnhancementAlgorithm(QgsContrastEnhancement.StretchToMinimumMaximum, True)
        return ce

    @staticmethod
    def _first_used_band(renderer) -> int:
        """Return the first positive band *renderer* reports via usesBands(), else 1."""
        bands = [b for b in (renderer.usesBands() or []) if b > 0]
        return bands[0] if bands else 1

    @staticmethod
    def _install_renderer(lyr, renderer) -> None:
        """Give *renderer* to *lyr* and schedule a repaint."""
        lyr.setRenderer(renderer)
        lyr.triggerRepaint()

    def _on_apply_clicked(self):
        """Apply the percentile stretch chosen in the dock to the selected raster layer.

        Gray and multiband-color renderers keep their type and only get new
        contrast enhancements. Any other renderer is replaced by a single-band
        gray renderer on the first band it was using.
        """
        current_min = float(self.dockwidget.minPercVal.value())
        current_max = float(self.dockwidget.maxPercVal.value())
        # NOTE(review): the canvas extent is in the canvas CRS, which may differ
        # from the layer CRS; transform it before enabling this flag.
        use_canvas_extent = False

        lyr = self._selected_layer()
        if lyr is None:
            return

        renderer = lyr.renderer()
        provider = lyr.dataProvider()
        extent = self.iface.mapCanvas().extent() if use_canvas_extent else lyr.extent()

        def stretch_for(band):
            """Return a contrast enhancement for *band*, or None (after warning)."""
            min_val, max_val = self._percentile_from_hist(
                provider=provider,
                band=band,
                lower_pct=current_min,
                upper_pct=current_max,
                extent=extent,
                bins=self._APPLY_BINS,
            )
            if min_val is None or max_val is None:
                self._warn("Failed to compute percentiles")
                return None
            return self._make_contrast_enhancement(provider, band, min_val, max_val)

        if isinstance(renderer, QgsSingleBandGrayRenderer):
            gray_band = renderer.grayBand()
            # Non-positive band numbers are not valid 1-based bands (QGIS is
            # believed to use -1 for "not set" — NOTE(review): confirm).
            ce = stretch_for(gray_band if gray_band > 0 else 1)
            if ce is not None:
                renderer.setContrastEnhancement(ce)
                self._install_renderer(lyr, renderer.clone())
            return

        if isinstance(renderer, QgsMultiBandColorRenderer):
            channels = (
                (renderer.redBand(), renderer.setRedContrastEnhancement),
                (renderer.greenBand(), renderer.setGreenContrastEnhancement),
                (renderer.blueBand(), renderer.setBlueContrastEnhancement),
            )
            changed = False
            for band, setter in channels:
                # Skip channels without a valid 1-based band (unset channels).
                if band <= 0:
                    continue
                ce = stretch_for(band)
                if ce is not None:
                    setter(ce)
                    changed = True
            if changed:
                self._install_renderer(lyr, renderer.clone())
            return

        if isinstance(renderer, (QgsSingleBandPseudoColorRenderer, QgsRasterRenderer)):
            # QgsRasterRenderer is the common base class, so this is the catch-all
            # for every other renderer: it is replaced by a stretched gray renderer
            # on the band it was displaying (instead of always band 1).
            band = self._first_used_band(renderer)
            ce = stretch_for(band)
            if ce is None:
                return
            new_renderer = QgsSingleBandGrayRenderer(provider, band)
            new_renderer.setContrastEnhancement(ce)
            self._install_renderer(lyr, new_renderer)
            return

        # Reached only when the layer has no renderer at all.
        self._warn("Unsupported raster renderer type")

    # ----------------------------------------------------------- Layer helpers

    def _selected_layer(self) -> "QgsRasterLayer | None":
        """Return the raster layer chosen in the dock's combo box, or None."""
        cb = self.dockwidget.comboBox
        layer_id = cb.itemData(cb.currentIndex())
        if not layer_id:
            return None
        lyr = QgsProject.instance().mapLayer(layer_id)
        return lyr if isinstance(lyr, QgsRasterLayer) else None

    def _on_raster_selected(self, _):
        """Show the source NoData value(s) of the selected layer in the dock.

        The label is cleared when no layer is selected.
        """
        lyr = self._selected_layer()
        if lyr is None:
            self.dockwidget.nodataText.setText("")
            return

        provider = lyr.dataProvider()
        # sourceNoDataValue() returns a float even for bands without NoData, so
        # ask sourceHasNoDataValue() first and use None for "no NoData".
        nodata_values = [
            provider.sourceNoDataValue(b) if provider.sourceHasNoDataValue(b) else None
            for b in range(1, provider.bandCount() + 1)
        ]

        if all(v is None for v in nodata_values):
            self.dockwidget.nodataText.setText("no nodata value found")
        elif len(set(nodata_values)) == 1:
            self.dockwidget.nodataText.setText(f"nodata value found: {nodata_values[0]}")
        else:
            self.dockwidget.nodataText.setText(f"nodata values found (per band): {nodata_values}")

    def _refresh_raster_combo(self):
        """Rebuild combo items from all valid raster layers in the current project.

        The previously selected layer stays selected if it still exists, and the
        NoData label is refreshed to match the resulting selection.
        """
        cb = self.dockwidget.comboBox
        previous_id = cb.itemData(cb.currentIndex())
        cb.blockSignals(True)
        try:
            cb.clear()
            for lyr in QgsProject.instance().mapLayers().values():
                if isinstance(lyr, QgsRasterLayer) and lyr.isValid():
                    cb.addItem(lyr.name(), lyr.id())
            if previous_id:
                index = cb.findData(previous_id)
                if index >= 0:
                    cb.setCurrentIndex(index)
        finally:
            cb.blockSignals(False)
        # Signals were blocked during the rebuild, so update the label explicitly.
        self._on_raster_selected(cb.currentIndex())

    def _on_layers_changed(self, *args):
        """Slot for project signals; their arguments are ignored."""
        self._refresh_raster_combo()

    def _project_signal_pairs(self):
        """Return the (signal, slot) pairs this plugin attaches to the project."""
        prj = QgsProject.instance()
        return (
            (prj.layersAdded, self._on_layers_changed),
            (prj.layersRemoved, self._on_layers_changed),
            (prj.cleared, self._on_layers_changed),
            (prj.readProject, self._on_layers_changed),
        )

    def _connect_project_signals(self):
        """Keep the layer combo in sync with the project's layer list."""
        for signal, slot in self._project_signal_pairs():
            signal.connect(slot)
        self._project_signals_connected = True

    def _disconnect_project_signals(self):
        """Undo _connect_project_signals(); safe to call when never connected."""
        if not getattr(self, "_project_signals_connected", False):
            return
        for signal, slot in self._project_signal_pairs():
            try:
                signal.disconnect(slot)
            except (TypeError, RuntimeError):
                # Deliberate best effort: the connection is already gone.
                pass
        self._project_signals_connected = False
