from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

from qgis.core import QgsFeature, QgsGeometry, QgsPointXY, QgsVectorLayer, edit

from topaze.toolbelt import PlgLogger, i18n

from .calc.helmert_transformation import HelmertTransformation


class MonumentsPointsManager:
    """Centralized management of Helmert control points (monuments) and residuals.

    This class is responsible for:
      - adding/removing homologous control point pairs in hel_monuments (NoGeometry)
      - computing Helmert transformation residuals with least squares (>= 3 active points)
      - neutralizing worst point while RMSE (EMQ) is above tolerance
      - keeping hel_gaps (Point) in sync with per-point residual magnitudes

    Layer schema expected for hel_monuments:
      fid (int), used (int), x (double), y (double), new_x (double), new_y (double), err (double)

    Layer schema expected for hel_gaps:
      fid (int), gap (double)
    """

    def __init__(self, monuments_layer: QgsVectorLayer, gaps_layer: QgsVectorLayer):
        """Bind the manager to its two layers and check their schemas.

        :param monuments_layer: layer holding the homologous control point pairs.
        :param gaps_layer: layer receiving the per-point residuals.
        :raises ValueError: if a layer is missing or lacks a required field.
        """
        # Outcome of the most recent recompute() run.
        self.last_neutralized_fids: List[int] = []
        self.last_rmse: Optional[float] = None

        self.gaps_layer = gaps_layer
        self.monuments_layer = monuments_layer

        # Must come last: the validation reads the layer attributes set above.
        self._validate_layers()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def add_monument_point(
        self,
        x: float,
        y: float,
        new_x: float,
        new_y: float,
        tolerance_m: Optional[float] = None,
        auto_neutralize: bool = True,
    ) -> Tuple[Optional[float], List[int]]:
        """Store a new active control point pair, then refresh the residuals.

        The pair is written to the monuments layer with ``used = 1`` and an
        empty ``err``; :meth:`recompute` is then run with the given options.

        :param x: value stored in the ``x`` field.
        :param y: value stored in the ``y`` field.
        :param new_x: value stored in the ``new_x`` field.
        :param new_y: value stored in the ``new_y`` field.
        :param tolerance_m: RMSE tolerance forwarded to :meth:`recompute`.
        :param auto_neutralize: neutralization switch forwarded to :meth:`recompute`.
        :returns: (rmse, neutralized_fids)
        :raises RuntimeError: if the layer refuses the new feature.
        """
        layer = self.monuments_layer

        pair = QgsFeature(layer.fields())
        attribute_values = (
            ("used", 1),
            ("x", float(x)),
            ("y", float(y)),
            ("new_x", float(new_x)),
            ("new_y", float(new_y)),
            ("err", None),  # residual is unknown until the next fit
        )
        for field_name, value in attribute_values:
            pair.setAttribute(field_name, value)

        with edit(layer):
            added = layer.addFeature(pair)
            if not added:
                message = i18n.tr("Failed to add feature to {layer_name}")
                raise RuntimeError(message.format(layer_name=layer.name()))

        # Residuals depend on the whole set, so refit after every insertion.
        rmse, neutralized = self.recompute(
            tolerance_m=tolerance_m, auto_neutralize=auto_neutralize
        )
        return rmse, neutralized

    def remove_monument_points(
        self, feature_ids: Sequence[int], tolerance_m=None, auto_neutralize=True
    ):
        """Delete control point pairs by feature id, then refresh the residuals.

        :param feature_ids: ids of the monument features to delete.
        :param tolerance_m: RMSE tolerance forwarded to :meth:`recompute`.
        :param auto_neutralize: neutralization switch forwarded to :meth:`recompute`.
        :returns: (rmse, neutralized_fids) as returned by :meth:`recompute`.
        :raises RuntimeError: if the layer reports a failed deletion.
        """
        ids_to_delete = list(map(int, feature_ids))
        self._delete_features(self.monuments_layer, ids_to_delete)

        outcome = self.recompute(
            tolerance_m=tolerance_m, auto_neutralize=auto_neutralize
        )
        return outcome

    def recompute(
        self, tolerance_m: Optional[float] = None, auto_neutralize: bool = True
    ) -> Tuple[Optional[float], List[int]]:
        """Recompute LS residuals and update layers.

        - Uses only points where used == 1.
        - With >= 3 active points a Helmert transformation is fitted and its
          RMSE is read. While ``auto_neutralize`` is on, a tolerance is given
          and the RMSE exceeds it, the point with the largest residual is
          neutralized (used = 0) and the fit is repeated on the remaining
          active points.
        - When fewer than 3 active points are available (initially or after
          neutralization), or when the fit raises, residuals are cleared
          (err = None for every point), hel_gaps is emptied and the RMSE is
          None.
        - Otherwise err is written for every active point (None for the
          others) and hel_gaps is rebuilt from those values.

        ``last_rmse`` and ``last_neutralized_fids`` are updated to the values
        returned.

        :param tolerance_m: RMSE threshold; None disables neutralization.
        :param auto_neutralize: whether points may be neutralized at all.
        :returns: (rmse, neutralized_fids_in_this_run)
        """
        self.last_neutralized_fids = []
        layer = self.monuments_layer

        # `summary` holds the residuals of the fit matching the current
        # `used` mask, or None when no valid fit exists for that mask.
        summary = None
        active = [f for f in layer.getFeatures() if self._is_used(f)]

        while len(active) >= 3:
            ht = HelmertTransformation()
            for f in active:
                ht.add_control_point(
                    int(f.id()),
                    float(f["x"]),
                    float(f["y"]),
                    float(f["new_x"]),
                    float(f["new_y"]),
                )

            try:
                ht.fit()
                summary = ht.compute_residuals()
            except Exception:
                # Degenerate or unexpected -> residuals are cleared below
                summary = None
                break

            rmse = float(summary.rmse)
            if not auto_neutralize or tolerance_m is None or rmse <= tolerance_m:
                # No tolerance to enforce, or already within tolerance
                break

            # Neutralize the worst point (max residual) of the current fit
            worst_id, _ = max(
                summary.residuals.items(), key=lambda it: float(it[1].residual)
            )
            worst_fid = int(worst_id)
            self._set_attributes(layer, {worst_fid: {"used": 0, "err": None}})
            self.last_neutralized_fids.append(worst_fid)

            # The fit above no longer matches the used mask: drop it and
            # re-read the active set from the layer.
            summary = None
            active = [f for f in layer.getFeatures() if self._is_used(f)]

        if summary is None:
            # < 3 active points or failed fit: residuals are indeterminate
            self.last_rmse = None
            self._set_err_for_all(None)
            self._clear_layer(self.gaps_layer)
            return self.last_rmse, self.last_neutralized_fids

        self.last_rmse = float(summary.rmse)

        updates: Dict[int, Dict[str, Any]] = {}
        for f in layer.getFeatures():
            fid = int(f.id())
            if self._is_used(f) and fid in summary.residuals:
                updates[fid] = {"err": float(summary.residuals[fid].residual)}
            else:
                # Neutralized or not part of LS => residual indeterminate
                updates[fid] = {"err": None}
        self._set_attributes(layer, updates)

        # Sync gaps layer with the err values just written
        self._rebuild_gaps_layer()

        return self.last_rmse, self.last_neutralized_fids

    def build_fitted_transformation(self) -> HelmertTransformation:
        """Fit a Helmert transformation on the active monument pairs.

        Only pairs with ``used == 1`` are fed to the transformation; each one
        is registered under its feature id.

        :returns: A fitted Helmert transformation.
        :rtype: HelmertTransformation
        :raises ValueError: If fewer than 2 active pairs are available.
        """
        active_pairs = []
        for feature in self.monuments_layer.getFeatures():
            if self._is_used(feature):
                active_pairs.append(feature)

        if len(active_pairs) < 2:
            message = i18n.tr(
                "At least 2 active monument pairs are required to fit the Helmert transformation."
            )
            raise ValueError(message)

        transformation = HelmertTransformation()
        for pair in active_pairs:
            transformation.add_control_point(
                int(pair.id()),
                float(pair["x"]),
                float(pair["y"]),
                float(pair["new_x"]),
                float(pair["new_y"]),
            )
        transformation.fit()
        return transformation

    def get_monuments_data(self) -> List[Dict[str, Any]]:
        """Return all monument pairs as plain dictionaries, ordered by feature id.

        Each row has the keys ``fid``, ``used``, ``x``, ``y``, ``new_x``,
        ``new_y`` and ``err``. ``used`` falls back to 0 when it cannot be read
        as an integer, and ``err`` is None when the stored value is null.

        :returns: one dictionary per feature of the monuments layer.
        """

        def order_key(row):
            # Rows without an id go last; the others sort by ascending id.
            fid = row["fid"]
            if fid is None:
                return (True, -1)
            return (False, fid)

        collected = []
        for feature in self.monuments_layer.getFeatures():
            used_flag = _qvariant_to_int(feature["used"]) or 0

            raw_err = feature["err"]
            is_null = raw_err is None or (
                hasattr(raw_err, "isNull") and raw_err.isNull()
            )
            residual = None if is_null else float(raw_err)

            row = {
                "fid": feature.id(),
                "used": int(used_flag),
                "x": float(feature["x"]),
                "y": float(feature["y"]),
                "new_x": float(feature["new_x"]),
                "new_y": float(feature["new_y"]),
                "err": residual,
            }
            collected.append(row)

        return sorted(collected, key=order_key)

    def get_monument_by_id(self, fid: int) -> Optional[QgsFeature]:
        """Return the monument feature whose feature id is ``fid``.

        The lookup uses the layer's feature id (``QgsFeature.id()``), not an
        attribute value, and asks the layer for that single feature instead of
        scanning every feature.

        :param fid: feature id to look up (converted with ``int``).
        :returns: the matching feature, or None when the layer has none.
        """
        feature = self.monuments_layer.getFeature(int(fid))
        # getFeature() hands back an invalid feature when the id is unknown.
        return feature if feature.isValid() else None

    def neutralize_points(self, feature_ids, tolerance_m=None, auto_neutralize=False):
        """Exclude the given pairs from the fit, then refresh the residuals.

        Every listed feature gets ``used = 0`` and an empty ``err`` before
        :meth:`recompute` is run with the given options.

        :param feature_ids: ids of the monument features to neutralize.
        :param tolerance_m: RMSE tolerance forwarded to :meth:`recompute`.
        :param auto_neutralize: neutralization switch forwarded to :meth:`recompute`.
        :returns: (rmse, neutralized_fids) as returned by :meth:`recompute`.
        """
        changes = {}
        for fid in feature_ids:
            changes[int(fid)] = {"used": 0, "err": None}

        if changes:
            self._set_attributes(self.monuments_layer, changes)

        outcome = self.recompute(
            tolerance_m=tolerance_m, auto_neutralize=auto_neutralize
        )
        return outcome

    def reactivate_points(self, feature_ids, tolerance_m=None, auto_neutralize=False):
        """Bring the given pairs back into the fit, then refresh the residuals.

        Every listed feature gets ``used = 1`` and an empty ``err`` before
        :meth:`recompute` is run with the given options.

        :param feature_ids: ids of the monument features to reactivate.
        :param tolerance_m: RMSE tolerance forwarded to :meth:`recompute`.
        :param auto_neutralize: neutralization switch forwarded to :meth:`recompute`.
        :returns: (rmse, neutralized_fids) as returned by :meth:`recompute`.
        """
        changes = {}
        for fid in feature_ids:
            changes[int(fid)] = {"used": 1, "err": None}

        if changes:
            self._set_attributes(self.monuments_layer, changes)

        outcome = self.recompute(
            tolerance_m=tolerance_m, auto_neutralize=auto_neutralize
        )
        return outcome

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _validate_layers(self) -> None:
        """Check that both layers are set and expose the fields this class uses.

        :raises ValueError: if a layer is falsy, or if a required field is
            absent from hel_monuments or hel_gaps.
        """
        if not (self.monuments_layer and self.gaps_layer):
            raise ValueError(i18n.tr("Layers must be valid"))

        # hel_monuments: fields read and written by the residual computation
        monument_fields = {field.name() for field in self.monuments_layer.fields()}
        absent = {"used", "x", "y", "new_x", "new_y", "err"}.difference(
            monument_fields
        )
        if absent:
            template = i18n.tr(
                "hel_monuments layer schema mismatch. Missing fields: {missing}"
            )
            raise ValueError(template.format(missing=sorted(absent)))

        # hel_gaps: fields filled when the gaps layer is rebuilt
        gap_fields = {field.name() for field in self.gaps_layer.fields()}
        absent = {"fid", "gap"}.difference(gap_fields)
        if absent:
            template = i18n.tr(
                "hel_gaps layer schema mismatch. Missing fields: {missing}"
            )
            raise ValueError(template.format(missing=sorted(absent)))

    def _set_err_for_all(self, value: Optional[float]) -> None:
        """Write the same ``err`` value on every monument feature.

        :param value: residual to store; None clears the field.
        """
        same_for_all: Dict[int, Dict[str, Any]] = {
            int(feature.id()): {"err": value}
            for feature in self.monuments_layer.getFeatures()
        }
        self._set_attributes(self.monuments_layer, same_for_all)

    def _rebuild_gaps_layer(self) -> None:
        """Empty hel_gaps and refill it from the active monuments' ``err`` values.

        One point feature is created per monument that is active and has a
        non-null ``err``; it is placed at (new_x, new_y) and carries the
        residual in ``gap``.
        """
        gaps = self.gaps_layer
        self._clear_layer(gaps)

        # The fid attribute is only filled when the gaps layer has that field.
        has_fid_field = gaps.fields().indexOf("fid") >= 0

        pending: List[QgsFeature] = []
        for monument in self.monuments_layer.getFeatures():
            if not self._is_used(monument):
                continue

            raw_err = monument["err"]
            if raw_err is None:
                continue
            if hasattr(raw_err, "isNull") and raw_err.isNull():
                continue
            residual = float(raw_err)

            gap_point = QgsFeature(gaps.fields())
            if has_fid_field:
                gap_point.setAttribute("fid", int(monument.id()))
            gap_point.setAttribute("gap", residual)

            location = QgsPointXY(float(monument["new_x"]), float(monument["new_y"]))
            gap_point.setGeometry(QgsGeometry.fromPointXY(location))
            pending.append(gap_point)

        if pending:
            self._add_features(gaps, pending)

    def _add_feature(self, layer: QgsVectorLayer, feature: QgsFeature) -> None:
        """Add one feature to ``layer`` inside its own edit session.

        :raises RuntimeError: if the layer reports that the feature was not added.
        """
        with edit(layer):
            if layer.addFeature(feature):
                return
            template = i18n.tr("Failed to add feature to layer {layer_name}")
            raise RuntimeError(template.format(layer_name=layer.name()))

    def _add_features(self, layer: QgsVectorLayer, features: List[QgsFeature]) -> None:
        """Add several features to ``layer`` inside a single edit session.

        :raises RuntimeError: if the layer reports that the features were not added.
        """
        with edit(layer):
            if layer.addFeatures(features):
                return
            template = i18n.tr("Failed to add features to layer {layer_name}")
            raise RuntimeError(template.format(layer_name=layer.name()))

    def _delete_features(
        self, layer: QgsVectorLayer, feature_ids: Sequence[int]
    ) -> None:
        """Delete the features with the given ids inside a single edit session.

        :param layer: layer to delete from.
        :param feature_ids: feature ids; each one is converted with ``int``.
        :raises RuntimeError: if the layer reports a failed deletion.
        """
        with edit(layer):
            ids = list(map(int, feature_ids))
            if layer.deleteFeatures(ids):
                return
            template = i18n.tr("Failed to delete features from layer {layer_name}")
            raise RuntimeError(template.format(layer_name=layer.name()))

    def _set_attributes(
        self, layer: QgsVectorLayer, updates: Dict[int, Dict[str, Any]]
    ) -> None:
        """Apply a batch of attribute changes in one edit session.

        :param layer: layer whose features are modified.
        :param updates: mapping ``{feature_id: {field_name: value}}``. Field
            names the layer does not have are skipped.
        """
        if not updates:
            # Nothing to write: do not open an edit session at all.
            return

        field_positions = {}
        for position, field in enumerate(layer.fields()):
            field_positions[field.name()] = position

        with edit(layer):
            for fid, changes in updates.items():
                for field_name, value in changes.items():
                    if field_name not in field_positions:
                        continue
                    layer.changeAttributeValue(
                        int(fid), field_positions[field_name], value
                    )

    def _clear_layer(self, layer: QgsVectorLayer) -> None:
        """Delete every feature of ``layer`` inside a single edit session."""
        with edit(layer):
            stale_ids = []
            for feature in layer.getFeatures():
                stale_ids.append(feature.id())
            layer.deleteFeatures(stale_ids)

    def _is_used(self, f) -> bool:
        """Tell whether feature ``f`` is active, i.e. its ``used`` field reads as 1."""
        # An unreadable value comes back as None, which never equals 1.
        flag = _qvariant_to_int(f["used"])
        return flag == 1

    def reset_all_used(self) -> None:
        """Re-enable all pairs before applying tolerance neutralization.

        Every monument feature gets ``used = 1`` and an empty ``err``.
        """
        reactivated = {
            int(feature.id()): {"used": 1, "err": None}
            for feature in self.monuments_layer.getFeatures()
        }
        if reactivated:
            self._set_attributes(self.monuments_layer, reactivated)


# ------------------------------------------------------------------
def _qvariant_to_int(v):
    """Convert an attribute value to ``int``, or return None when it cannot be.

    Objects exposing ``toInt()`` are expected to return a ``(value, ok)``
    pair from it; anything else goes through ``int()``.
    """
    if v is None:
        return None

    if not hasattr(v, "toInt"):
        # Plain Python value: any conversion failure means "no integer".
        try:
            return int(v)
        except Exception:
            return None

    # QVariant-style object: toInt() reports success through its second item.
    number, converted = v.toInt()
    if not converted:
        return None
    return int(number)
