# -*- coding: utf-8 -*-
import os
import re
import shutil
import tempfile

import numpy as np
from osgeo import gdal
from qgis.core import (
    QgsProcessingAlgorithm,
    QgsProcessingParameterRasterLayer,
    QgsProcessingParameterString,
    QgsProcessingParameterFileDestination,
    QgsProcessingParameterVectorLayer,
    QgsProcessingParameterEnum,
    QgsProcessingParameterNumber,
    QgsProcessingParameterRasterDestination,
    QgsProcessingContext,
    QgsProcessingFeedback,
    QgsProcessingException,
    QgsVectorFileWriter,
    QgsProcessing,  # enum: QgsProcessing.TypeVectorPolygon
)
from qgis.PyQt.QtGui import QIcon


def _export_boundary_to_gpkg(boundary_layer, tmp_dir, transform_ctx):
    """
    Boundary 폴리곤을 임시 GPKG로 저장 후 경로 반환 (QGIS 버전별 API 차이 대응)
    """
    if boundary_layer is None:
        return None
    out_path = os.path.join(tmp_dir, "boundary.gpkg")
    try:
        opts = QgsVectorFileWriter.SaveVectorOptions()
        opts.driverName = "GPKG"
        opts.layerName = "boundary"
        res = QgsVectorFileWriter.writeAsVectorFormatV2(
            boundary_layer, out_path, transform_ctx, opts
        )
        if isinstance(res, tuple):
            if len(res) == 2:
                err_code, err_msg = res
                if err_code != QgsVectorFileWriter.NoError:
                    raise QgsProcessingException(f"Failed to export boundary: {err_msg or err_code}")
            elif len(res) == 4:
                err_code, err_msg, _new_path, _new_layer = res
                if err_code != QgsVectorFileWriter.NoError:
                    raise QgsProcessingException(f"Failed to export boundary: {err_msg or err_code}")
        else:
            if res != QgsVectorFileWriter.NoError:
                raise QgsProcessingException(f"Failed to export boundary: {res}")
        return out_path
    except Exception:
        err = QgsVectorFileWriter.writeAsVectorFormat(
            boundary_layer, out_path, "UTF-8", boundary_layer.crs(), "GPKG"
        )
        if err != QgsVectorFileWriter.NoError:
            raise QgsProcessingException(f"Failed to export boundary (fallback): {err}")
        return out_path


def _read_masked_array(raster_path, boundary_path=None):
    """
    Read band 1 of a raster, optionally masking it with a polygon cutline.

    Args:
        raster_path: path/URI that gdal.Open can open.
        boundary_path: optional vector file used as a GDAL cutline. Pixels
            outside its polygons become NoData; the raster extent is kept
            (cropToCutline=False).

    Returns:
        (arr, nodata, gt, proj): the band-1 pixel array; the value marking
        invalid pixels in ``arr`` (None when the raster has no NoData and no
        cutline was applied); and the geotransform and projection of the
        grid ``arr`` was actually read from.

    Raises:
        QgsProcessingException: if the raster cannot be opened or the warp fails.
    """
    ds = gdal.Open(raster_path, gdal.GA_ReadOnly)
    if ds is None:
        raise QgsProcessingException(f"Failed to open raster: {raster_path}")
    nodata = ds.GetRasterBand(1).GetNoDataValue()

    if not boundary_path:
        src = ds
    else:
        warp_opts = {
            "format": "MEM",
            "cutlineDSName": boundary_path,
            "cropToCutline": False,
        }
        if nodata is None:
            # No source NoData: mark outside pixels with a sentinel. Promote
            # the output to Float64 so GDAL does not clamp the sentinel into
            # the band type's range (e.g. 0 for Byte), where it would be
            # indistinguishable from a real class value.
            nodata = -999999
            warp_opts["outputType"] = gdal.GDT_Float64
        warp_opts["dstNodata"] = nodata
        src = gdal.Warp("", ds, **warp_opts)
        if src is None:
            raise QgsProcessingException("GDAL Warp with cutline failed.")

    arr = src.GetRasterBand(1).ReadAsArray()
    # Use the grid of the dataset actually read so gt/proj always match arr.
    gt = src.GetGeoTransform()
    proj = src.GetProjection()
    src = None
    ds = None
    return arr, nodata, gt, proj


class GeoBioToolBetaLocalAlgorithm(QgsProcessingAlgorithm):
    """
    Moving-window local beta diversity (Jaccard / Sorensen).

    For every pixel, the set of classes present in the square window centred
    on it is compared with the class sets of 8 neighbouring windows. The
    per-neighbour dissimilarities (0 = identical sets, 1 = disjoint sets) are
    then aggregated with mean or max.

    - "8": neighbour windows are centred 1 pixel away (they overlap heavily).
    - "ring": neighbour windows are centred one full window away
      (step = window), so they do not overlap the focal window.
    - PRESENCE_MIN_PROP: a class counts as present in a window only if it
      covers at least that fraction of the window's valid pixels. This
      suppresses rare/noisy classes.
    - Outputs: a TXT summary and a GeoTIFF turnover raster.
    """

    # Parameter / output keys
    INPUT = "INPUT"; CLASSES = "CLASSES"; BOUNDARY = "BOUNDARY"
    NEIGHBOR = "NEIGHBOR"; AGGREGATE = "AGGREGATE"; METRIC = "METRIC"
    WINDOW = "WINDOW"; MIN_VALID = "MIN_VALID"; PRESENCE_MIN_PROP = "PRESENCE_MIN_PROP"
    OUTPUT_TEXT = "OUTPUT_TEXT"; OUTPUT_RASTER = "OUTPUT_RASTER"; NODATA = "NODATA"

    NEI_CHOICES = ["8", "ring"]                 # default: ring
    AGG_CHOICES = ["mean", "max"]               # default: max
    METRIC_CHOICES = ["jaccard", "sorensen"]    # default: sorensen

    # One CLASSES token: an integer, or an inclusive range "a-b"; values may
    # be negative (e.g. "-3", "-5--1").
    _CLASS_TOKEN_RE = re.compile(r"(-?\d+)(?:\s*-\s*(-?\d+))?")

    def name(self) -> str:
        return "compute_beta_local"

    def displayName(self) -> str:
        return "Compute Local Beta (Jaccard / Sorensen)"

    def group(self) -> str:
        return "1 Raster"

    def groupId(self) -> str:
        return "raster"

    def icon(self) -> QIcon:
        return QIcon(":/icons/icon.png")

    def initAlgorithm(self, config=None):
        """Declare the algorithm's inputs, tuning parameters and outputs."""
        self.addParameter(QgsProcessingParameterRasterLayer(
            self.INPUT, "Input classified raster"
        ))
        self.addParameter(QgsProcessingParameterString(
            self.CLASSES, "Target classes (e.g., 1,4,6 or 0-9)", optional=True
        ))
        self.addParameter(QgsProcessingParameterVectorLayer(
            self.BOUNDARY, "Boundary polygons (optional)",
            types=[QgsProcessing.TypeVectorPolygon], optional=True
        ))

        # ---- Enum (combo box) & numeric parameters (stable preset defaults) ----
        self.addParameter(QgsProcessingParameterEnum(
            self.NEIGHBOR, "Neighbor scheme", self.NEI_CHOICES, defaultValue=1  # ring
        ))
        self.addParameter(QgsProcessingParameterEnum(
            self.AGGREGATE, "Aggregate neighbors", self.AGG_CHOICES, defaultValue=1  # max
        ))
        self.addParameter(QgsProcessingParameterEnum(
            self.METRIC, "Metric", self.METRIC_CHOICES, defaultValue=1  # sorensen
        ))
        self.addParameter(QgsProcessingParameterNumber(
            self.WINDOW, "Window size (odd, pixels)",
            QgsProcessingParameterNumber.Integer, defaultValue=31, minValue=1
        ))
        self.addParameter(QgsProcessingParameterNumber(
            self.MIN_VALID, "Min valid ratio within window (0–1)",
            QgsProcessingParameterNumber.Double, defaultValue=0.6, minValue=0.0, maxValue=1.0
        ))
        self.addParameter(QgsProcessingParameterNumber(
            self.PRESENCE_MIN_PROP, "Presence threshold within window (0–1, by proportion)",
            QgsProcessingParameterNumber.Double, defaultValue=0.02, minValue=0.0, maxValue=1.0
        ))

        # Outputs
        self.addParameter(QgsProcessingParameterFileDestination(
            self.OUTPUT_TEXT, "Output text file", fileFilter="Text files (*.txt)"
        ))
        self.addParameter(QgsProcessingParameterRasterDestination(
            self.OUTPUT_RASTER, "Output turnover raster (GeoTIFF)"
        ))
        self.addParameter(QgsProcessingParameterNumber(
            self.NODATA, "Output NoData value (raster)",
            QgsProcessingParameterNumber.Double, defaultValue=-9999
        ))

    # -------- helpers --------
    @staticmethod
    def _box_sum_2d(arr, win):
        """
        Sum of ``arr`` over a win x win window centred on every cell.

        Uses zero padding and cumulative sums (integral image) in float64, so
        the result has exactly the input's (H, W) shape. It is returned as
        float32.
        """
        r = win // 2
        ap = np.pad(arr, ((r, r), (r, r)), mode='constant')
        cs0 = np.concatenate([np.zeros((1, ap.shape[1]), dtype=np.float64),
                              np.cumsum(ap, axis=0, dtype=np.float64)], axis=0)
        v = cs0[win:, :] - cs0[:-win, :]
        cs1 = np.concatenate([np.zeros((v.shape[0], 1), dtype=np.float64),
                              np.cumsum(v, axis=1, dtype=np.float64)], axis=1)
        h = cs1[:, win:] - cs1[:, :-win]
        return h.astype(np.float32, copy=False)

    @staticmethod
    def _shift_bool(arr_bool, dy, dx):
        """
        Shift a 2D boolean array by (dy, dx), filling vacated cells with False.

        out[y, x] == arr_bool[y - dy, x - dx] wherever that index is in range.
        If a shift's magnitude reaches the array size, the result is all False.
        Slicing is not used in that case, because it would wrap on negative
        bounds and fail to broadcast.
        """
        h, w = arr_bool.shape
        out = np.zeros((h, w), dtype=bool)
        if abs(dy) >= h or abs(dx) >= w:
            return out
        y0, y1 = max(0, dy), h + min(0, dy)
        x0, x1 = max(0, dx), w + min(0, dx)
        out[y0:y1, x0:x1] = arr_bool[y0 - dy:y1 - dy, x0 - dx:x1 - dx]
        return out

    def parse_classes(self, s: str):
        """
        Parse a class specification such as "1,4,6" or "0-9" into a set of ints.

        Tokens are comma-separated. Each token is a single integer or an
        inclusive range "a-b". Negative values are accepted, and a reversed
        range ("9-0") is normalised. Returns None for an empty string; blank
        tokens are skipped.

        Raises:
            QgsProcessingException: if a token is not an integer or a range.
        """
        if not s:
            return None
        classes = set()
        for raw in s.split(','):
            token = raw.strip()
            if not token:
                continue
            match = self._CLASS_TOKEN_RE.fullmatch(token)
            if match is None:
                raise QgsProcessingException(f"Invalid class token: {token!r}")
            lo = int(match.group(1))
            hi = int(match.group(2)) if match.group(2) is not None else lo
            if lo > hi:
                lo, hi = hi, lo
            classes.update(range(lo, hi + 1))
        return classes

    @staticmethod
    def _valid_class_pixels(arr, nodata, selected):
        """
        Build the integer class array and the per-pixel validity mask.

        A pixel is valid when it is finite, differs from ``nodata`` and, if
        ``selected`` is non-empty, belongs to one of the selected classes.
        NoData is compared on the raw array. Casting to int first would fail
        for NaN NoData and could merge float NoData with nearby values.

        Returns:
            (data, valid_px): int64 class codes (0 at invalid pixels) and a
            boolean mask of the same shape.
        """
        valid_px = np.ones(arr.shape, dtype=bool)
        if np.issubdtype(arr.dtype, np.floating):
            valid_px &= np.isfinite(arr)
        if nodata is not None and not np.isnan(nodata):
            valid_px &= (arr != nodata)
        # Zero-fill invalid cells so NaN/out-of-range floats never reach the cast.
        data = np.where(valid_px, arr, 0).astype(np.int64, copy=False)
        if selected:
            valid_px &= np.isin(data, list(selected))
        return data, valid_px

    def _compute_turnover(self, data, valid_px, win, min_valid, pres_prop,
                          neighbor, aggregate, metric, feedback):
        """
        Compute the aggregated per-pixel turnover.

        Returns:
            (out, keep_center): a float32 turnover array (NaN where it is
            undefined) and the boolean mask of window centres that pass
            MIN_VALID. Returns None if the user cancelled.

        Raises:
            QgsProcessingException: if no window passes MIN_VALID or no valid
            class pixel remains.
        """
        # Share of valid pixels per window (edge windows include zero padding).
        valid_counts = self._box_sum_2d(valid_px.astype(np.float32), win)
        window_area = float(win * win)
        keep_center = (valid_counts / window_area) >= min_valid
        if not np.any(keep_center):
            raise QgsProcessingException("No valid windows after applying MIN_VALID. Try reducing MIN_VALID or enlarging WINDOW.")

        classes = np.unique(data[valid_px])
        if classes.size == 0:
            raise QgsProcessingException("No valid class pixels found (after class filter / boundary).")

        h, w = data.shape
        # "8" compares windows 1 px apart; "ring" compares non-overlapping
        # windows one full window apart.
        step = 1 if neighbor == "8" else win
        offsets = [(sy * step, sx * step)
                   for sy in (-1, 0, 1) for sx in (-1, 0, 1)
                   if (sy, sx) != (0, 0)]

        focal_S = np.zeros((h, w), dtype=np.float32)
        neigh_S = [np.zeros((h, w), dtype=np.float32) for _ in offsets]
        inter = [np.zeros((h, w), dtype=np.float32) for _ in offsets]

        # Presence threshold in pixels = PRESENCE_MIN_PROP * valid pixels in window.
        pres_min_count = (valid_counts * pres_prop).astype(np.float32)

        # Per class: presence map -> species richness and intersections.
        n_classes = classes.size
        report_every = max(1, n_classes // 20)
        for i, cls in enumerate(classes):
            if feedback.isCanceled():
                return None  # partial sums would yield misleading turnover
            if i % report_every == 0:
                feedback.setProgress(int(100 * i / n_classes))

            cnt = self._box_sum_2d(((data == cls) & valid_px).astype(np.float32), win)
            # Require at least one pixel. With PRESENCE_MIN_PROP == 0 the
            # threshold is 0, and `cnt >= 0` alone would mark every class as
            # present in every window, forcing turnover to 0.
            pres = (cnt > 0) & (cnt >= pres_min_count)
            focal_S += pres.astype(np.float32)

            for k, (dy, dx) in enumerate(offsets):
                pres_n = self._shift_bool(pres, dy, dx)
                neigh_S[k] += pres_n.astype(np.float32)
                inter[k] += (pres & pres_n).astype(np.float32)
        feedback.setProgress(100)

        # Dissimilarity (0-1) with each neighbour. It is NaN where either
        # window centre fails MIN_VALID or the denominator is zero.
        eps = 1e-12
        dist_stack = np.full((len(offsets), h, w), np.nan, dtype=np.float32)
        for k, (dy, dx) in enumerate(offsets):
            both_valid = keep_center & self._shift_bool(keep_center, dy, dx)
            if metric == "jaccard":
                similar = inter[k]
                denom = focal_S + neigh_S[k] - inter[k]   # size of the union
            else:
                similar = 2.0 * inter[k]
                denom = focal_S + neigh_S[k]              # |A| + |B|
            mask = both_valid & (denom > eps)
            dist_stack[k][mask] = 1.0 - similar[mask] / denom[mask]

        # Aggregate over neighbours.
        if aggregate == "mean":
            num = np.nansum(dist_stack, axis=0)
            den = np.sum(np.isfinite(dist_stack), axis=0)
            out = np.full((h, w), np.nan, dtype=np.float32)
            has_any = den > 0
            out[has_any] = (num[has_any] / den[has_any]).astype(np.float32)
        else:
            # fmax ignores NaN (result is NaN only where every neighbour is
            # NaN) and, unlike nanmax, does not warn on all-NaN pixels.
            out = np.fmax.reduce(dist_stack, axis=0)
        return out, keep_center

    @staticmethod
    def _turnover_stats(v):
        """Summary statistics of the valid turnover values ``v`` (1D array)."""
        if v.size == 0:
            raise QgsProcessingException("No valid turnover values; check parameters.")
        q05, q25, q50, q75, q95 = (
            float(x) for x in np.percentile(v, [5, 25, 50, 75, 95])
        )
        return dict(
            n=int(v.size),
            mean=float(np.mean(v)),
            std=float(np.std(v)),
            min=float(np.min(v)),
            q05=q05, q25=q25, q50=q50,
            q75=q75, q95=q95,
            max=float(np.max(v))
        )

    @staticmethod
    def _write_report(path, settings, stats):
        """
        Write the text report: header, run settings, stats, optional hint.

        A hint line is added when the distribution sits near 0.
        """
        lines = ["[GeoBioTool] Local beta-diversity (window-based) report"]
        lines.extend(settings)
        lines.extend(["", f"[turnover_stats] {stats}"])
        if stats["q75"] == 0.0 or stats["mean"] < 0.05:
            lines.extend(["", "[hint] Values are concentrated near 0. "
                              "Try Neighbor=ring, Aggregate=max, and a larger Window (21–31 px)."])
        with open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines) + "\n")

    @staticmethod
    def _write_raster(path, out, keep_center, gt, proj, out_nodata):
        """
        Write the turnover array as a single-band Float32 GeoTIFF.

        Non-finite values and centres failing MIN_VALID are written as
        ``out_nodata``.
        """
        h, w = out.shape
        driver = gdal.GetDriverByName('GTiff')
        dst = driver.Create(path, w, h, 1, gdal.GDT_Float32,
                            options=['COMPRESS=LZW', 'TILED=YES'])
        if dst is None:
            raise QgsProcessingException(f"Failed to create output raster: {path}")
        dst.SetGeoTransform(gt)
        dst.SetProjection(proj)
        band = dst.GetRasterBand(1)
        band.SetNoDataValue(out_nodata)
        img = np.where(np.isfinite(out) & keep_center, out, out_nodata).astype(np.float32)
        band.WriteArray(img)
        band.SetDescription("Local_beta_turnover")
        dst.FlushCache()
        dst = None

    def processAlgorithm(self, parameters, context: QgsProcessingContext, feedback: QgsProcessingFeedback) -> dict:
        """
        Run the analysis and write the text report and the turnover raster.

        Returns a dict mapping OUTPUT_TEXT (and OUTPUT_RASTER when a raster
        destination is set) to the written paths. Returns an empty dict if
        the user cancels during the per-class accumulation; nothing is
        written in that case.
        """
        # --- input parameters ---
        layer          = self.parameterAsRasterLayer(parameters, self.INPUT, context)
        selected       = self.parse_classes(self.parameterAsString(parameters, self.CLASSES, context))
        boundary_layer = self.parameterAsVectorLayer(parameters, self.BOUNDARY, context)

        nei_idx   = self.parameterAsEnum(parameters, self.NEIGHBOR, context)
        agg_idx   = self.parameterAsEnum(parameters, self.AGGREGATE, context)
        met_idx   = self.parameterAsEnum(parameters, self.METRIC, context)
        win       = int(self.parameterAsInt(parameters, self.WINDOW, context))
        min_valid = float(self.parameterAsDouble(parameters, self.MIN_VALID, context))
        pres_prop = float(self.parameterAsDouble(parameters, self.PRESENCE_MIN_PROP, context))

        out_txt    = self.parameterAsFileOutput(parameters, self.OUTPUT_TEXT, context)
        out_raster = self.parameterAsOutputLayer(parameters, self.OUTPUT_RASTER, context)
        out_nodata = float(self.parameterAsDouble(parameters, self.NODATA, context))

        if layer is None:
            raise QgsProcessingException("Input raster is required.")
        if win % 2 == 0:
            win += 1  # the window needs a centre pixel
            feedback.pushInfo(f"Window size must be odd; using {win} px.")

        neighbor  = self.NEI_CHOICES[nei_idx]
        aggregate = self.AGG_CHOICES[agg_idx]
        metric    = self.METRIC_CHOICES[met_idx]

        # --- read the raster, optionally clipped by the boundary cutline ---
        # The exported boundary is only needed while reading, so the temp
        # directory is removed right afterwards instead of being leaked.
        tmp_dir = tempfile.mkdtemp(prefix="gbt_")
        try:
            boundary_path = None
            if boundary_layer is not None:
                boundary_path = _export_boundary_to_gpkg(
                    boundary_layer, tmp_dir, context.transformContext()
                )
            arr, nodata, gt, proj = _read_masked_array(layer.source(), boundary_path)
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)

        data, valid_px = self._valid_class_pixels(arr, nodata, selected)

        result = self._compute_turnover(data, valid_px, win, min_valid, pres_prop,
                                        neighbor, aggregate, metric, feedback)
        if result is None:
            return {}
        out, keep_center = result

        # Statistics over valid window centres only.
        stats = self._turnover_stats(out[keep_center & np.isfinite(out)])

        settings = [f"Input: {layer.source()}"]
        if selected:
            settings.append(f"Classes: {sorted(selected)}")
        if boundary_layer is not None:
            settings.append("Boundary: applied (cutline)")
        settings.extend([
            f"Metric: {metric}",
            f"Neighbor: {neighbor}",
            f"Aggregate: {aggregate}",
            f"Window(px): {win}",
            f"Min valid ratio: {min_valid}",
            f"Presence min prop: {pres_prop}",
        ])
        self._write_report(out_txt, settings, stats)

        outputs = {self.OUTPUT_TEXT: out_txt}
        if out_raster:
            self._write_raster(out_raster, out, keep_center, gt, proj, out_nodata)
            outputs[self.OUTPUT_RASTER] = out_raster
        return outputs

    def createInstance(self):
        return GeoBioToolBetaLocalAlgorithm()
