# -*- coding: utf-8 -*-

"""
/***************************************************************************
 AMERTA
                                 A QGIS plugin
 Analisis Multi-kriteria Embung dan Rencana Tata Air
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2025-09-18
        copyright            : (C) 2025 by Badan Riset dan Inovasi Nasional
        email                : sitaranisafitri@gmail.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Sitarani Safitri, Orbita Roswintiarti, Okta Fajar Saputra, Galdita Aruba Chulafak, Gatot Nugroho, Wismu Sunarmodo, Kusumaning Ayu Dyah Sukowati, Hana Listi Fitriana'
__date__ = '2025-09-18'
__copyright__ = '(C) 2025 by Badan Riset dan Inovasi Nasional'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

import math, os
from contextlib import contextmanager
from qgis.PyQt.QtCore import QCoreApplication, QVariant
from qgis.core import (
    QgsProcessing, QgsProcessingAlgorithm, QgsProcessingException,
    QgsProcessingParameterRasterLayer, QgsProcessingParameterVectorLayer,
    QgsProcessingParameterEnum, QgsProcessingParameterNumber,
    QgsProcessingParameterVectorDestination, QgsProcessingParameterBoolean,
    QgsCoordinateReferenceSystem, QgsField, QgsVectorLayer, QgsProcessingUtils,
    QgsRasterLayer, QgsPointXY
)
import processing
import os
from qgis.PyQt.QtGui import QIcon
from qgis.PyQt.QtCore import QUrl

@contextmanager
def edit(layer):
    """Run the ``with`` body inside an edit session on *layer*.

    The layer is put into editing mode, the body runs, and the pending
    changes are committed afterwards. If the body raises, the edits are
    rolled back and the exception propagates unchanged.

    Args:
        layer: Object exposing ``startEditing()``, ``commitChanges()`` and
            ``rollBack()``. ``commitErrors()`` is used for the error
            message when the object provides it.

    Raises:
        RuntimeError: If ``commitChanges()`` reports failure. The pending
            edits are rolled back first so the layer does not stay in
            editing mode with uncommitted changes.
    """
    layer.startEditing()
    try:
        yield
    except Exception:
        layer.rollBack()
        raise

    # Only an explicit False counts as failure, so objects whose
    # commitChanges() returns nothing keep working as before.
    if layer.commitChanges() is False:
        get_errors = getattr(layer, 'commitErrors', None)
        errors = list(get_errors()) if callable(get_errors) else []
        layer.rollBack()
        detail = '; '.join(str(e) for e in errors) if errors else 'unknown error'
        raise RuntimeError('Failed to commit layer edits: {}'.format(detail))

def score_twi(val):
    """Map a TWI value to its S_TWI class (1-5).

    Class boundaries:
        TWI < 0        -> 1
        0  <= TWI <= 5  -> 2
        5  <  TWI <= 10 -> 3
        10 <  TWI <= 15 -> 4
        TWI > 15       -> 5

    Args:
        val: Anything ``float()`` accepts, or None.

    Returns:
        The integer score, or None when *val* is None, cannot be converted
        to a float, or is NaN.
    """
    if val is None:
        return None
    try:
        v = float(val)
    except (TypeError, ValueError, OverflowError):
        return None
    # NaN fails every comparison below and would otherwise fall through to
    # the top score of 5; treat it as "no data" instead.
    if math.isnan(v):
        return None
    if v < 0:
        return 1
    if v <= 5:
        return 2
    if v <= 10:
        return 3
    if v <= 15:
        return 4
    return 5

class GridTwiAlgorithm(QgsProcessingAlgorithm):
    """Processing algorithm that grids an AOI and scores TWI per grid cell.

    The AOI is reprojected to EPSG:4326, a rectangular grid is created over
    its extent, and each cell intersecting the AOI receives a ``TWI`` value
    from the input raster plus its ``S_TWI`` score (see ``score_twi``).
    Values are resolved in stages: zonal Majority, zonal Maximum, centroid
    sampling, and (when "Fill NoData" is enabled) a default value of 0.
    """

    # Parameter keys.
    RASTER_TWI  = 'RASTER_TWI'
    AOI         = 'AOI'
    GRID_SIZE   = 'GRID_SIZE'
    CUSTOM_M    = 'CUSTOM_M'
    BUILD_INDEX = 'BUILD_INDEX'
    FILL_GAPS   = 'FILL_GAPS'
    LAND_MASK   = 'LAND_MASK'   # optional land mask
    OUTPUT      = 'OUTPUT'

    GRID_CHOICES = ['30″ (≈ 925 m)', '5″ (≈ 150 m)', 'Custom (meter)']

    # Indices into the STATISTICS enum of native:zonalstatisticsfb
    # (…, 6 = Maximum, 7 = Range, 8 = Minority, 9 = Majority, …).
    _STAT_MAXIMUM = 6
    _STAT_MAJORITY = 9

    def tr(self, s):
        """Translate *s* in the 'Processing' context."""
        return QCoreApplication.translate('Processing', s)

    def name(self):
        """Return the algorithm's internal name."""
        return 'a_grid_twi_uppercase'

    def displayName(self):
        """Return the translated, user-visible algorithm name."""
        return self.tr('Grid TWI')

    def groupId(self):
        """Return the id of the group this algorithm belongs to."""
        return 'C. MCDA Factors for Retention Ponds'

    def group(self):
        """Return the translated, user-visible group name."""
        return self.tr('C. MCDA Factors for Retention Ponds')

    def icon(self):
        """Return the icon loaded from 'preanalysis.png' next to this file."""
        return QIcon(os.path.join(os.path.dirname(__file__), 'preanalysis.png'))

    def shortHelpString(self):
        """Return the bilingual (Indonesian / English) help text."""
        return self.tr("""\
🇮🇩 ID  Modul ini membuat grid dan menetapkan nilai Topographic Wetness Index (TWI) per grid menggunakan Zonal Statistics (Majority; cadangan Maximum), lalu menghitung skor S_TWI.

Alur kerja:
1) Siapkan raster TWI dan data spasial AOI (mis. batas administrasi / DAS).
2) Pilih resolusi grid dari dropdown atau isikan nilai custom (meter) — 30″ / 5″ / Custom.
3) Sistem akan menyesuaikan proyeksi dan extent ke AOI secara otomatis.
4) Hitung TWI per sel grid dengan Majority; jika NULL, gunakan Maximum; jika masih NULL, sampling centroid (opsional isi NoData cepat).
5) (Opsional) Sisa sel NULL di daratan (berdasarkan land mask, bila diisi) diisi nilai default TWI=0 (skor S_TWI=1).
6) Mengisi skor S_TWI berdasarkan rentang nilai TWI.

Keluaran:
• Grid TWI berisi kolom: id, TWI, S_TWI.

──────────────

🌍 EN  This module builds a grid and assigns Topographic Wetness Index (TWI) per grid via Zonal Statistics (Majority; Maximum as fallback), then computes S_TWI scores.

Workflow:
1) Prepare a TWI raster and an AOI layer (e.g., administrative boundary / watershed).
2) Choose grid resolution from the dropdown or enter a custom value in meters — 30″ / 5″ / Custom.
3) The tool auto-aligns projection and extent to the AOI.
4) Compute per-cell TWI using Majority; if NULL, fall back to Maximum; if still NULL, use centroid sampling (optional quick NoData fill).
5) (Optional) Remaining NULL cells on land (based on land mask, if provided) are filled with default TWI=0 (S_TWI=1).
6) Assign S_TWI scores based on TWI range values.

Output:
• TWI grid with fields: id, TWI, S_TWI.""")

    def createInstance(self):
        """Return a fresh instance of this algorithm."""
        return GridTwiAlgorithm()

    def initAlgorithm(self, config=None):
        """Declare the algorithm's input and output parameters."""
        self.addParameter(QgsProcessingParameterRasterLayer(
            self.RASTER_TWI, self.tr('Topographic Wetness Index (TWI)')))
        self.addParameter(QgsProcessingParameterVectorLayer(
            self.AOI, self.tr('Area of Interest (AOI)'),
            types=[QgsProcessing.TypeVectorPolygon]))
        self.addParameter(QgsProcessingParameterEnum(
            self.GRID_SIZE, self.tr('Grid Resolution'),
            self.GRID_CHOICES, defaultValue=0))
        # minValue stays 0.0 for interface compatibility; a non-positive
        # custom size is rejected in processAlgorithm when it is actually used.
        self.addParameter(QgsProcessingParameterNumber(
            self.CUSTOM_M, self.tr('Grid Resolution (meter) for "Custom (meter)"'),
            type=QgsProcessingParameterNumber.Double, defaultValue=1.0, minValue=0.0))
        self.addParameter(QgsProcessingParameterBoolean(
            self.BUILD_INDEX, self.tr('Create spatial index on output'),
            defaultValue=True))
        self.addParameter(QgsProcessingParameterBoolean(
            self.FILL_GAPS, self.tr('Fill NoData (quick, 40 px)'),
            defaultValue=False))
        # Optional land mask polygon layer.
        self.addParameter(QgsProcessingParameterVectorLayer(
            self.LAND_MASK, self.tr('Land mask (polygons)'),
            types=[QgsProcessing.TypeVectorPolygon],
            optional=True))
        self.addParameter(QgsProcessingParameterVectorDestination(
            self.OUTPUT, self.tr('Grid TWI')))

    # ----- helpers -----
    @staticmethod
    def _is_null(value):
        """Return True when an attribute value represents NULL.

        Attribute values read through PyQGIS may come back either as Python
        ``None`` or as a null ``QVariant`` object; an ``is None`` test alone
        misses the latter.
        """
        if value is None:
            return True
        is_null = getattr(value, 'isNull', None)
        return bool(is_null()) if callable(is_null) else False

    @staticmethod
    def _find_stat_field(layer, prefix, tokens):
        """Return the first field name starting with *prefix* that contains
        any of *tokens* (case-insensitive), or None when nothing matches."""
        for field in layer.fields():
            lowered = field.name().lower()
            if lowered.startswith(prefix) and any(tok in lowered for tok in tokens):
                return field.name()
        return None

    def _warn(self, feedback, message):
        """Report a non-fatal problem through *feedback*, if one is given."""
        if feedback is not None:
            feedback.reportError(message, False)

    def _as_layer(self, ref, context):
        """Resolve *ref* to a layer object.

        Args:
            ref: A layer object (anything with an ``extent`` attribute) or a
                layer id / source string.
            context: Processing context used to look up layer strings.

        Returns:
            The layer object.

        Raises:
            QgsProcessingException: If *ref* cannot be resolved.
        """
        if hasattr(ref, 'extent'):
            return ref
        try:
            lyr = QgsProcessingUtils.mapLayerFromString(ref, context)
            if lyr is not None:
                return lyr
            # Last resort: treat the reference as an OGR data source.
            lyr = QgsVectorLayer(ref, 'layer', 'ogr')
            if lyr is not None and lyr.isValid():
                return lyr
        except Exception:
            # Any lookup failure is reported uniformly below.
            pass
        raise QgsProcessingException(self.tr('Failed to load layer from reference: {}').format(ref))

    def _mk_spatial_index(self, vlayer, context, feedback):
        """Best-effort creation of a spatial index on *vlayer*.

        Failures are reported as warnings and otherwise ignored because the
        index is only an optimisation. Returns *vlayer* unchanged.
        """
        try:
            processing.run('native:createspatialindex', {'INPUT': vlayer},
                           context=context, feedback=feedback, is_child_algorithm=True)
        except Exception as exc:
            self._warn(feedback, self.tr('Spatial index creation skipped: {}').format(exc))
        return vlayer

    def _fix_geoms(self, vlayer, context, feedback):
        """Run native:fixgeometries on *vlayer*.

        Returns:
            The child algorithm's OUTPUT reference, or *vlayer* itself when
            the algorithm fails (best effort).
        """
        try:
            return processing.run('native:fixgeometries',
                                  {'INPUT': vlayer, 'OUTPUT': 'TEMPORARY_OUTPUT'},
                                  context=context, feedback=feedback, is_child_algorithm=True)['OUTPUT']
        except Exception as exc:
            self._warn(feedback, self.tr('Fix geometries skipped: {}').format(exc))
            return vlayer

    def _dissolve_all(self, vlayer, context, feedback):
        """Dissolve all features of *vlayer* with native:dissolve.

        Returns:
            The child algorithm's OUTPUT reference, or *vlayer* itself when
            the algorithm fails (best effort).
        """
        try:
            return processing.run('native:dissolve',
                                  {'INPUT': vlayer, 'FIELD': [], 'SEPARATE_DISJOINT': False,
                                   'OUTPUT': 'TEMPORARY_OUTPUT'},
                                  context=context, feedback=feedback, is_child_algorithm=True)['OUTPUT']
        except Exception as exc:
            self._warn(feedback, self.tr('Dissolve skipped: {}').format(exc))
            return vlayer

    def _to_wgs84(self, layer, context, feedback):
        """Reproject a vector layer to EPSG:4326 and return the OUTPUT reference."""
        return processing.run('native:reprojectlayer',
            {'INPUT': layer, 'TARGET_CRS': QgsCoordinateReferenceSystem('EPSG:4326'),
             'OPERATION': '', 'OUTPUT': 'TEMPORARY_OUTPUT'},
            context=context, feedback=feedback, is_child_algorithm=True)['OUTPUT']

    def _warp_to_wgs84_res(self, raster, xres_deg, yres_deg, context, feedback, target_extent_wkt=None):
        """Warp *raster* to EPSG:4326 at the given resolution (degrees).

        Args:
            raster: Input raster layer or reference.
            xres_deg: Output pixel width in degrees.
            yres_deg: Output pixel height in degrees.
            context: Processing context.
            feedback: Processing feedback.
            target_extent_wkt: Optional extent string; when given, the output
                is limited to it and it is interpreted in EPSG:4326.

        Returns:
            The OUTPUT reference of gdal:warpreproject.
        """
        wgs84 = QgsCoordinateReferenceSystem('EPSG:4326')
        params = {
            'INPUT': raster, 'SOURCE_CRS': None,
            'TARGET_CRS': wgs84,
            # RESAMPLING 0 / DATA_TYPE 6 are enum indices of the GDAL warp
            # algorithm (presumably nearest neighbour / Float32 — confirm
            # against the installed QGIS version).
            'RESAMPLING': 0, 'NODATA': -9999,
            'X_RES': float(xres_deg), 'Y_RES': float(yres_deg),
            'MULTITHREADING': True, 'DATA_TYPE': 6,
            'TARGET_EXTENT': target_extent_wkt if target_extent_wkt else None,
            'TARGET_EXTENT_CRS': wgs84 if target_extent_wkt else None,
            'OUTPUT': 'TEMPORARY_OUTPUT'
        }
        return processing.run('gdal:warpreproject', params,
                              context=context, feedback=feedback, is_child_algorithm=True)['OUTPUT']

    def _fill_nodata(self, in_raster, distance, context, feedback):
        """Fill NoData gaps of band 1 with gdal:fillnodata.

        Returns *in_raster* unchanged when *distance* is not positive,
        otherwise the OUTPUT reference of the child algorithm.
        """
        if distance <= 0:
            return in_raster
        return processing.run('gdal:fillnodata',
            {'INPUT': in_raster, 'BAND': 1, 'DISTANCE': int(distance), 'ITERATIONS': 0,
             'NO_MASK': True, 'MASK_LAYER': None, 'MASK_VALID_DATA': False,
             'USE_MULTITHREADING': True, 'OUTPUT': 'TEMPORARY_OUTPUT'},
            context=context, feedback=feedback, is_child_algorithm=True)['OUTPUT']

    def _extent_str(self, layer_obj):
        """Format the layer extent as 'xmin,xmax,ymin,ymax [EPSG:4326]'.

        The CRS suffix is hard-coded, so the layer must already be in
        EPSG:4326.
        """
        e = layer_obj.extent()
        return f'{e.xMinimum()},{e.xMaximum()},{e.yMinimum()},{e.yMaximum()} [EPSG:4326]'

    def _deg_spacing_from_option(self, option_idx, custom_m, aoi_wgs_layer_obj):
        """Return the grid spacing ``(deg_lon, deg_lat)`` for a grid option.

        Options 0 and 1 map to 30 and 5 arc-seconds. Any other option
        converts *custom_m* metres to degrees, using 111320 m per degree and
        scaling longitude by the cosine of the AOI extent's centre latitude.
        """
        if option_idx == 0:
            return 30.0 / 3600.0, 30.0 / 3600.0
        if option_idx == 1:
            return 5.0 / 3600.0, 5.0 / 3600.0
        c = aoi_wgs_layer_obj.extent().center()
        mean_lat_rad = math.radians(c.y())
        deg_lat = float(custom_m) / 111320.0
        # Clamp the cosine so the division stays finite near the poles.
        coslat = max(1e-6, math.cos(mean_lat_rad))
        deg_lon = float(custom_m) / (111320.0 * coslat)
        return deg_lon, deg_lat

    def _assign_stable_id_global(self, vlayer, dx_deg, dy_deg, lon0=-180.0, lat0=-90.0, numeric=True):
        """Write an id derived from each feature's global grid position.

        The centroid of every feature is converted to column/row indices
        relative to (*lon0*, *lat0*). Numeric ids are
        ``row * 1_000_000 + column``; string ids look like
        ``G<arcsec>s_<row>_<column>``.

        Args:
            vlayer: Layer to update in place; an 'id' field is added when
                missing.
            dx_deg: Cell width in degrees (must be > 0).
            dy_deg: Cell height in degrees (must be > 0).
            lon0: Longitude origin of the global grid.
            lat0: Latitude origin of the global grid.
            numeric: Write integer ids when True, string ids otherwise.

        Returns:
            *vlayer*.

        Raises:
            QgsProcessingException: If a cell size is not positive.
        """
        if dx_deg <= 0 or dy_deg <= 0:
            raise QgsProcessingException(
                self.tr('Grid cell size must be greater than zero.'))
        prov = vlayer.dataProvider()
        if vlayer.fields().indexOf('id') == -1:
            # row * 1_000_000 + column exceeds the 32-bit integer range for
            # most latitudes, so the numeric id needs a 64-bit field.
            prov.addAttributes([QgsField('id', QVariant.LongLong if numeric else QVariant.String)])
            vlayer.updateFields()
        idx_id = vlayer.fields().indexOf('id')
        arcsec_x = int(round(dx_deg * 3600.0 + 1e-9))
        with edit(vlayer):
            for ft in vlayer.getFeatures():
                c = ft.geometry().centroid().asPoint()
                ix = int((c.x() - lon0) // dx_deg)
                iy = int((c.y() - lat0) // dy_deg)
                if numeric:
                    gid = iy * 1_000_000 + ix
                else:
                    gid = f"G{arcsec_x}s_{iy:06d}_{ix:06d}"
                vlayer.changeAttributeValue(ft.id(), idx_id, gid)
        return vlayer

    def _keep_only_fields(self, vlayer, keep_names):
        """Delete every field of *vlayer* whose name is not in *keep_names*."""
        fields = vlayer.fields()
        drop = [i for i, f in enumerate(fields) if f.name() not in keep_names]
        if drop:
            with edit(vlayer):
                vlayer.dataProvider().deleteAttributes(sorted(drop, reverse=True))
                vlayer.updateFields()

    # ----- core -----
    def processAlgorithm(self, parameters, context, feedback):
        """Build the TWI grid and write it to the OUTPUT destination.

        Returns:
            ``{OUTPUT: <saved layer reference>}``.

        Raises:
            QgsProcessingException: If the raster or AOI is invalid, the
                custom grid size is not positive, or the raster used for
                centroid sampling cannot be opened.
        """
        twi_ras  = self.parameterAsRasterLayer(parameters, self.RASTER_TWI, context)
        aoi_vec  = self.parameterAsVectorLayer(parameters, self.AOI, context)
        grid_idx = int(self.parameterAsEnum(parameters, self.GRID_SIZE, context))
        custom_m = float(self.parameterAsDouble(parameters, self.CUSTOM_M, context))
        make_idx = bool(self.parameterAsBoolean(parameters, self.BUILD_INDEX, context))
        do_fill  = bool(self.parameterAsBoolean(parameters, self.FILL_GAPS, context))
        land_vec = self.parameterAsVectorLayer(parameters, self.LAND_MASK, context)
        out_path = self.parameterAsOutputLayer(parameters, self.OUTPUT, context)

        if twi_ras is None:
            raise QgsProcessingException(self.tr('Raster TWI tidak valid.'))
        if aoi_vec is None:
            raise QgsProcessingException(self.tr('AOI layer tidak valid.'))
        # A zero custom size would give zero grid spacing and a division by
        # zero when ids are computed.
        if grid_idx not in (0, 1) and custom_m <= 0:
            raise QgsProcessingException(
                self.tr('Custom grid resolution must be greater than 0 meters.'))

        # AOI -> WGS84 -> fix -> dissolve -> index
        aoi_wgs = self._to_wgs84(aoi_vec, context, feedback)
        aoi_wgs = self._fix_geoms(aoi_wgs, context, feedback)
        aoi_wgs = self._dissolve_all(aoi_wgs, context, feedback)
        aoi_wgs = self._as_layer(aoi_wgs, context)
        self._mk_spatial_index(aoi_wgs, context, feedback)

        # Prepare land mask geometries (if provided). The mask is reprojected
        # to WGS84 because it is intersected with WGS84 grid cells, and the
        # child-algorithm output (a layer reference) is resolved to a layer
        # object before its features are read.
        land_geoms = None
        if land_vec is not None:
            land_ref = self._to_wgs84(land_vec, context, feedback)
            land_ref = self._fix_geoms(land_ref, context, feedback)
            land_layer = self._as_layer(land_ref, context)
            self._mk_spatial_index(land_layer, context, feedback)
            land_geoms = []
            for land_ft in land_layer.getFeatures():
                land_geom = land_ft.geometry()
                if land_geom is not None and not land_geom.isEmpty():
                    land_geoms.append(land_geom)
            feedback.pushInfo(self.tr(f'Land mask loaded with {len(land_geoms)} polygon(s).'))

        # Grid resolution
        xdeg, ydeg = self._deg_spacing_from_option(grid_idx, custom_m, aoi_wgs)
        feedback.pushInfo(self.tr(f'Grid resolution (deg): dX={xdeg}, dY={ydeg}'))

        # Warp the raster onto the AOI extent
        aoi_extent_wkt = self._extent_str(aoi_wgs)
        twi_wgs0 = self._warp_to_wgs84_res(twi_ras, xdeg, ydeg, context, feedback,
                                           target_extent_wkt=aoi_extent_wkt)
        twi_proc = self._fill_nodata(twi_wgs0, 40, context, feedback) if do_fill else twi_wgs0

        # Build the grid (TYPE 2 — presumably rectangle polygons, confirm)
        # and keep only the cells that intersect the AOI (PREDICATE 0).
        grid = processing.run('native:creategrid',
            {'TYPE': 2, 'EXTENT': aoi_extent_wkt,
             'HSPACING': xdeg, 'VSPACING': ydeg, 'HOVERLAY': 0.0, 'VOVERLAY': 0.0,
             'CRS': QgsCoordinateReferenceSystem('EPSG:4326'), 'OUTPUT': 'TEMPORARY_OUTPUT'},
            context=context, feedback=feedback, is_child_algorithm=True)['OUTPUT']
        self._mk_spatial_index(grid, context, feedback)

        grid_aoi = processing.run('qgis:extractbylocation',
            {'INPUT': grid, 'PREDICATE': [0], 'INTERSECT': aoi_wgs, 'OUTPUT': 'TEMPORARY_OUTPUT'},
            context=context, feedback=feedback, is_child_algorithm=True)['OUTPUT']

        grid_aoi_layer = self._as_layer(grid_aoi, context)
        grid_aoi_layer = self._assign_stable_id_global(grid_aoi_layer, dx_deg=xdeg, dy_deg=ydeg,
                                                       lon0=-180.0, lat0=-90.0, numeric=True)
        self._mk_spatial_index(grid_aoi_layer, context, feedback)
        grid_aoi = processing.run('native:savefeatures',
            {'INPUT': grid_aoi_layer, 'OUTPUT': 'TEMPORARY_OUTPUT'},
            context=context, feedback=feedback, is_child_algorithm=True)['OUTPUT']

        # Stage 1: zonal Majority
        z1 = processing.run('native:zonalstatisticsfb',
            {'INPUT': grid_aoi, 'INPUT_RASTER': twi_proc, 'RASTER_BAND': 1,
             'COLUMN_PREFIX': 'twi_', 'STATISTICS': [self._STAT_MAJORITY],
             'OUTPUT': 'TEMPORARY_OUTPUT'},
            context=context, feedback=feedback, is_child_algorithm=True)['OUTPUT']
        v = self._as_layer(z1, context)

        # Make sure the upper-case target fields exist.
        adds = []
        if v.fields().indexOf('TWI') == -1:
            adds.append(QgsField('TWI', QVariant.Double))
        if v.fields().indexOf('S_TWI') == -1:
            adds.append(QgsField('S_TWI', QVariant.Int))
        if adds:
            v.dataProvider().addAttributes(adds)
            v.updateFields()
        idx_TWI = v.fields().indexOf('TWI')
        idx_STW = v.fields().indexOf('S_TWI')

        # Locate the majority column produced with the twi_ prefix.
        maj_field = self._find_stat_field(v, 'twi_', ('maj', 'mode', 'frequ'))
        idx_maj = v.fields().indexOf(maj_field) if maj_field else -1
        if idx_maj == -1:
            self._warn(feedback, self.tr(
                'Majority column not found in zonal statistics output; using fallbacks.'))

        null_after_maj = 0
        with edit(v):
            for ft in v.getFeatures():
                val = ft[idx_maj] if idx_maj != -1 else None
                if self._is_null(val):
                    val = None
                    null_after_maj += 1
                v.changeAttributeValue(ft.id(), idx_TWI, val)
                v.changeAttributeValue(ft.id(), idx_STW, score_twi(val))
        feedback.pushInfo(self.tr(f'After Majority: {null_after_maj} NULL'))

        # Stage 2: zonal Maximum for cells still without a value
        if null_after_maj > 0:
            z2 = processing.run('native:zonalstatisticsfb',
                {'INPUT': v, 'INPUT_RASTER': twi_proc, 'RASTER_BAND': 1,
                 'COLUMN_PREFIX': 'twi2_', 'STATISTICS': [self._STAT_MAXIMUM],
                 'OUTPUT': 'TEMPORARY_OUTPUT'},
                context=context, feedback=feedback, is_child_algorithm=True)['OUTPUT']
            v2 = self._as_layer(z2, context)
            # Resolve indices on the new layer instead of assuming the
            # field order is unchanged.
            idx_TWI = v2.fields().indexOf('TWI')
            idx_STW = v2.fields().indexOf('S_TWI')
            max_field = self._find_stat_field(v2, 'twi2_', ('max',))
            idx_max = v2.fields().indexOf(max_field) if max_field else -1

            null_after_max = 0
            with edit(v2):
                for ft in v2.getFeatures():
                    if not self._is_null(ft[idx_TWI]):
                        continue
                    val = ft[idx_max] if idx_max != -1 else None
                    if self._is_null(val):
                        val = None
                        null_after_max += 1
                    v2.changeAttributeValue(ft.id(), idx_TWI, val)
                    v2.changeAttributeValue(ft.id(), idx_STW, score_twi(val))
            feedback.pushInfo(self.tr(f'After Maximum fallback: {null_after_max} NULL'))
            v = v2

        # Stage 3: sample the raster at the cell centroid
        if any(self._is_null(ft[idx_TWI]) for ft in v.getFeatures()):
            r = QgsRasterLayer(twi_proc, 'twi_res')
            if not r.isValid():
                raise QgsProcessingException('Raster sampling tidak valid.')
            prov = r.dataProvider()
            filled = 0
            with edit(v):
                for ft in v.getFeatures():
                    if not self._is_null(ft[idx_TWI]):
                        continue
                    c = ft.geometry().centroid().asPoint()
                    rv, ok = prov.sample(QgsPointXY(c), 1)
                    if not ok or rv is None:
                        continue
                    try:
                        fv = float(rv)
                    except (TypeError, ValueError):
                        continue
                    if math.isnan(fv):
                        continue
                    v.changeAttributeValue(ft.id(), idx_TWI, fv)
                    v.changeAttributeValue(ft.id(), idx_STW, score_twi(fv))
                    filled += 1
            feedback.pushInfo(self.tr(f'Centroid sampling filled {filled}'))

        # Stage 4: default fill for cells that are still NULL
        #   - only active when "Fill NoData" is checked
        #   - with a land mask, only cells intersecting land are filled
        # NOTE(review): the help text promises S_TWI=1 for these cells, but
        # score_twi(0.0) returns 2 — confirm the intended class boundary.
        if do_fill:
            final_filled = 0
            with edit(v):
                for ft in v.getFeatures():
                    if not self._is_null(ft[idx_TWI]):
                        continue
                    on_land = True
                    if land_geoms is not None:
                        g = ft.geometry()
                        on_land = g is not None and any(g.intersects(lg) for lg in land_geoms)
                    if on_land:
                        v.changeAttributeValue(ft.id(), idx_TWI, 0.0)
                        v.changeAttributeValue(ft.id(), idx_STW, score_twi(0.0))
                        final_filled += 1
            feedback.pushInfo(self.tr(f'Final default fill applied to {final_filled} cells'))

        # Keep only the final columns; this also removes the temporary
        # twi_* / twi2_* statistics columns.
        self._keep_only_fields(v, keep_names=['id', 'TWI', 'S_TWI'])

        # Remove an existing GeoPackage at the destination before saving
        # (best effort: a failure here is reported, not fatal).
        if isinstance(out_path, str) and out_path and out_path != 'TEMPORARY_OUTPUT':
            base = out_path.split('|')[0]
            if base.lower().endswith('.gpkg') and os.path.exists(base):
                try:
                    os.remove(base)
                except OSError as exc:
                    self._warn(feedback, self.tr(
                        'Could not remove existing output file: {}').format(exc))

        saved = processing.run('native:savefeatures',
                               {'INPUT': v, 'OUTPUT': out_path},
                               context=context, feedback=feedback, is_child_algorithm=True)['OUTPUT']
        if make_idx:
            self._mk_spatial_index(saved, context, feedback)

        return {self.OUTPUT: saved}
