# -*- coding: utf-8 -*-
import os
import tempfile
import numpy as np
from collections import Counter
from osgeo import gdal
from qgis.core import (
    QgsProcessingAlgorithm,
    QgsProcessingParameterRasterLayer,
    QgsProcessingParameterString,
    QgsProcessingParameterFileDestination,
    QgsProcessingParameterVectorLayer,   # 경계는 VectorLayer로
    QgsProcessingContext,
    QgsProcessingFeedback,
    QgsProcessingException,
    QgsVectorFileWriter,
    QgsProcessing,                       # enum: QgsProcessing.TypeVectorPolygon
)
from qgis.PyQt.QtGui import QIcon


def _export_boundary_to_gpkg(boundary_layer, tmp_dir, transform_ctx):
    """
    QgsVectorLayer(폴리곤)를 임시 GPKG로 내보내고 경로 반환.
    - QGIS 버전에 따라 writeAsVectorFormatV2의 리턴 시그니처가 달라서
      (2튜플/4튜플/코드) 안전 처리하고, 실패 시 구버전 API로 폴백.
    """
    if boundary_layer is None:
        return None
    out_path = os.path.join(tmp_dir, "boundary.gpkg")

    try:
        opts = QgsVectorFileWriter.SaveVectorOptions()
        opts.driverName = "GPKG"
        opts.layerName = "boundary"
        res = QgsVectorFileWriter.writeAsVectorFormatV2(
            boundary_layer, out_path, transform_ctx, opts
        )
        if isinstance(res, tuple):
            if len(res) == 2:
                err_code, err_msg = res
                if err_code != QgsVectorFileWriter.NoError:
                    raise QgsProcessingException(f"Failed to export boundary: {err_msg or err_code}")
            elif len(res) == 4:
                err_code, err_msg, _new_path, _new_layer = res
                if err_code != QgsVectorFileWriter.NoError:
                    raise QgsProcessingException(f"Failed to export boundary: {err_msg or err_code}")
        else:
            if res != QgsVectorFileWriter.NoError:
                raise QgsProcessingException(f"Failed to export boundary: {res}")
        return out_path
    except Exception:
        err = QgsVectorFileWriter.writeAsVectorFormat(
            boundary_layer, out_path, "UTF-8", boundary_layer.crs(), "GPKG"
        )
        if err != QgsVectorFileWriter.NoError:
            raise QgsProcessingException(f"Failed to export boundary (fallback): {err}")
        return out_path


def _read_masked_array(raster_path, boundary_path=None):
    """Read band 1 of a raster into a NumPy array, optionally masked.

    When ``boundary_path`` is given, the raster is warped into an in-memory
    dataset with the boundary used as a GDAL cutline (without cropping the
    extent), so pixels outside the boundary are filled with the nodata value.

    Args:
        raster_path: Path / data source string of the raster to open.
        boundary_path: Optional path of a vector dataset used as cutline.

    Returns:
        Tuple ``(arr, nodata)``: the band data and the nodata value that
        applies to ``arr`` (``None`` when the raster has none and no cutline
        was applied).

    Raises:
        QgsProcessingException: If the raster cannot be opened, has no
            band 1, the warp fails, or the pixel data cannot be read.
    """
    ds = gdal.Open(raster_path, gdal.GA_ReadOnly)
    if ds is None:
        raise QgsProcessingException(f"Failed to open raster: {raster_path}")

    try:
        band = ds.GetRasterBand(1)
        if band is None:
            raise QgsProcessingException(f"Raster has no band 1: {raster_path}")
        nodata = band.GetNoDataValue()

        if boundary_path:
            # Reuse the source nodata when present; otherwise request a
            # sentinel that is unlikely to collide with class values.
            dst_nodata = nodata if nodata is not None else -999999
            masked = gdal.Warp(
                "", ds, format="MEM",
                cutlineDSName=boundary_path,
                cropToCutline=False,   # set to True to also crop the raster extent
                dstNodata=dst_nodata
            )
            if masked is None:
                raise QgsProcessingException("GDAL Warp with cutline failed.")
            out_band = masked.GetRasterBand(1)
            arr = out_band.ReadAsArray()
            # GDAL may adjust a requested nodata that does not fit the band
            # data type (e.g. -999999 on a Byte raster), so prefer the value
            # the warped band actually reports over the requested one.
            reported = out_band.GetNoDataValue()
            nodata = reported if reported is not None else dst_nodata
            out_band = None
            masked = None
        else:
            arr = band.ReadAsArray()

        if arr is None:
            raise QgsProcessingException(f"Failed to read raster data: {raster_path}")
        return arr, nodata
    finally:
        # Drop the GDAL handles on every exit path so the file is released.
        band = None
        ds = None


class GeoBioToolSimpsonAlgorithm(QgsProcessingAlgorithm):
    """Processing algorithm that reports a Simpson index for a classified raster.

    The index is computed as ``1 - sum(p_i ** 2)`` over the proportions of
    the counted pixel values, optionally restricted to a polygon boundary
    and/or a user-supplied set of classes, and written to a text report.
    """

    def name(self) -> str:
        """Return the algorithm identifier."""
        return "compute_simpson_1949"

    def displayName(self) -> str:
        """Return the human-readable algorithm name."""
        return "Compute Simpson (1949)"

    def group(self) -> str:
        """Return the display name of the group this algorithm belongs to."""
        return "1 Raster"

    def groupId(self) -> str:
        """Return the identifier of the group this algorithm belongs to."""
        return "raster"

    def icon(self) -> QIcon:
        """Return the algorithm icon loaded from the Qt resource system."""
        # Same as the Shannon algorithm: fixed qrc resource path.
        return QIcon(":/icons/icon.png")

    def initAlgorithm(self, config=None):
        """Declare the input raster, class list, boundary and output parameters."""
        self.addParameter(QgsProcessingParameterRasterLayer(
            "INPUT", "Input classified raster"
        ))
        self.addParameter(QgsProcessingParameterString(
            "CLASSES", "Target classes (e.g., 1,4,6 or 0-9)", optional=True
        ))
        # Optional boundary: a polygon vector layer.
        self.addParameter(QgsProcessingParameterVectorLayer(
            "BOUNDARY", "Boundary polygons (optional)",
            types=[QgsProcessing.TypeVectorPolygon],
            optional=True
        ))
        self.addParameter(QgsProcessingParameterFileDestination(
            "OUTPUT_TEXT", "Output text file", fileFilter="Text files (*.txt)"
        ))

    def parse_classes(self, s: str):
        """Parse a class specification such as ``"1,4,6"`` or ``"0-9"``.

        Tokens are comma separated; each token is a single integer or an
        inclusive ``a-b`` range. Empty tokens (stray or trailing commas) are
        ignored, a leading ``-`` is read as a sign, and range bounds may be
        given in either order.

        Args:
            s: Raw specification string; may be empty or ``None``.

        Returns:
            A set of integer class values, or ``None`` when ``s`` is empty
            or only whitespace.

        Raises:
            QgsProcessingException: If a token is not a valid integer/range.
        """
        if not s or not s.strip():
            return None
        classes = set()
        for raw in s.split(','):
            token = raw.strip()
            if not token:
                continue
            try:
                # Search from index 1 so a leading '-' stays a sign.
                sep = token.find('-', 1)
                if sep == -1:
                    classes.add(int(token))
                else:
                    first = int(token[:sep])
                    last = int(token[sep + 1:])
                    low, high = min(first, last), max(first, last)
                    classes.update(range(low, high + 1))
            except ValueError as exc:
                raise QgsProcessingException(
                    f"Invalid class specification: '{token}'"
                ) from exc
        return classes

    def processAlgorithm(self, parameters, context: QgsProcessingContext, feedback: QgsProcessingFeedback) -> dict:
        """Compute the Simpson index and write the text report.

        Returns:
            ``{"OUTPUT_TEXT": <path of the written report>}``.

        Raises:
            QgsProcessingException: If the input raster is missing, the class
                specification is invalid, or no pixel passes the filters.
        """
        import shutil  # local import: only needed to clean up the temp dir

        layer          = self.parameterAsRasterLayer(parameters, "INPUT", context)
        selected       = self.parse_classes(self.parameterAsString(parameters, "CLASSES", context))
        boundary_layer = self.parameterAsVectorLayer(parameters, "BOUNDARY", context)
        out_txt        = self.parameterAsFileOutput(parameters, "OUTPUT_TEXT", context)

        if layer is None:
            raise QgsProcessingException("Input raster is required.")

        # With a boundary, export it to a temporary GPKG and use it as a
        # cutline. The temp dir is created only when needed and is always
        # removed once the pixels are in memory.
        tmp_dir = None
        try:
            boundary_path = None
            if boundary_layer is not None:
                tmp_dir = tempfile.mkdtemp(prefix="gbt_")
                boundary_path = _export_boundary_to_gpkg(
                    boundary_layer, tmp_dir, context.transformContext()
                )
            arr, nodata = _read_masked_array(layer.source(), boundary_path)
        finally:
            if tmp_dir is not None:
                shutil.rmtree(tmp_dir, ignore_errors=True)

        # NaN/Inf are mapped to 0 (existing behaviour kept).
        data = arr.astype(np.float32)
        data[~np.isfinite(data)] = 0
        flat = data.ravel()

        # Either keep the selected classes or apply the default 0 < v < 255.
        if selected:
            mask = np.isin(flat, list(selected))
        else:
            mask = (flat > 0) & (flat < 255)
        # Nodata is excluded in both modes so that pixels masked out by the
        # boundary are never counted, even if their value is a selected class.
        if nodata is not None:
            mask &= (flat != nodata)
        vals = flat[mask]

        # Vectorized frequency count; class_ids come back sorted ascending.
        class_ids, counts = np.unique(vals, return_counts=True)
        total = int(counts.sum())
        if total == 0:
            raise QgsProcessingException("No valid pixels found.")
        rows = [(int(c), int(n), int(n) / total) for c, n in zip(class_ids, counts)]
        simpson = 1 - sum(p * p for _, _, p in rows)

        lines = [
            f"Total pixels: {total}\n",
            f"Simpson Index: {simpson:.4f}\n\n",
            "Class ID order and proportions:\n",
        ]
        lines.extend(
            f"  Class {cls}: {prop:.4f} ({n} pixels)\n" for cls, n, prop in rows
        )
        lines.append("\nTop classes by proportion:\n")
        # Sorting by count is equivalent to sorting by proportion; the stable
        # sort keeps ascending class id among ties.
        lines.extend(
            f"  Class {cls}: {prop:.4f} ({n} pixels)\n"
            for cls, n, prop in sorted(rows, key=lambda r: r[1], reverse=True)
        )

        with open(out_txt, "w", encoding="utf-8") as f:
            f.writelines(lines)

        return {"OUTPUT_TEXT": out_txt}

    def createInstance(self):
        """Return a new instance of this algorithm."""
        return GeoBioToolSimpsonAlgorithm()
