from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, Hashable, List, Optional, Sequence, Tuple, Union

import numpy as np


@dataclass
class _ControlPoint:
    """Internal record for one control point pair.

    Holds the point identifier, its coordinates in the source ("old") and
    target ("new") systems, and the residual cached by the last residual
    computation. Both residual fields default to ``None``, meaning "not
    computed for the current data".

    NOTE(review): the dataclass-generated ``__eq__`` compares the ndarray
    fields, which raises ``ValueError`` (ambiguous truth value) for distinct
    2-element arrays. Avoid comparing instances with ``==``.
    """

    point_id: Hashable  # caller-supplied identifier (also the key in the owning dict)
    old_xy: np.ndarray  # shape (2,): (x, y) in the source system
    new_xy: np.ndarray  # shape (2,): (X, Y) in the target system
    residual_xy: Optional[np.ndarray] = None  # shape (2,): predicted minus observed target coords
    residual_norm: Optional[float] = None  # Euclidean length of residual_xy


class HelmertTransformation:
    """
    2D Helmert (similarity) transformation estimated from control point pairs.

    The fitted model is::

        X = a  + b*x + c*y
        Y = a1 + b1*x + c1*y

    with the similarity constraints::

        b  =  s*cos(theta)
        c  = -s*sin(theta)
        b1 =  s*sin(theta)
        c1 =  s*cos(theta)

    where ``s`` is the scale factor and ``theta`` is the rotation angle
    (computed as ``atan2(b1, b)``).

    Residuals are cached on the control points. Every change to the control
    points or to the parameters discards the cached residuals of *all*
    points, so cached values always belong to the current parameters.
    """

    # Message of the RuntimeError raised by every method needing parameters.
    _NOT_FITTED_MSG = (
        "Transformation parameters are not available. "
        "Call 'fit()' first (or 'set_parameters()')."
    )

    def __init__(self) -> None:
        """Create an empty Helmert transformation (no control points, not fitted)."""
        # _points must exist before clear_parameters(), which walks it to
        # discard cached residuals.
        self._points: Dict[Hashable, _ControlPoint] = {}
        self.clear_parameters()

    # ---------------------------------------------------------------------
    # Angle helpers (the original C code stores the rotation in grads/gon)
    # ---------------------------------------------------------------------
    @staticmethod
    def rad_to_gon(angle_rad: float) -> float:
        """Convert an angle from radians to grads (gon).

        :param angle_rad: Angle in radians.
        :type angle_rad: float
        :returns: Angle in grads (gon), where 400 gon = 2π radians.
        :rtype: float
        """
        return angle_rad * 200.0 / math.pi

    @staticmethod
    def gon_to_rad(angle_gon: float) -> float:
        """Convert an angle from grads (gon) to radians.

        :param angle_gon: Angle in grads (gon), where 400 gon = 2π radians.
        :type angle_gon: float
        :returns: Angle in radians.
        :rtype: float
        """
        return angle_gon * math.pi / 200.0

    @staticmethod
    def normalize_gon(angle_gon: float) -> float:
        """Normalize an angle in grads (gon) to the [0, 400) interval.

        :param angle_gon: Angle in grads (gon).
        :type angle_gon: float
        :returns: Normalized angle in [0, 400).
        :rtype: float
        """
        wrapped = angle_gon % 400.0
        # For tiny negative inputs (e.g. -1e-17) the float modulo rounds up to
        # exactly 400.0, which is outside the half-open interval; wrap it to 0.
        return 0.0 if wrapped >= 400.0 else wrapped

    # ---------------------------------------------------------------------
    # Lifecycle
    # ---------------------------------------------------------------------
    def clear(self) -> None:
        """Remove all control points and reset fitted parameters."""
        self._points.clear()
        self.clear_parameters()

    def clear_parameters(self) -> None:
        """Reset transformation parameters (marks the instance as not fitted).

        The cached residuals of every control point are discarded too, since
        they were computed from the parameters being reset.
        """
        self.a: Optional[float] = None
        self.a1: Optional[float] = None
        self.b: Optional[float] = None
        self.c: Optional[float] = None
        self.b1: Optional[float] = None
        self.c1: Optional[float] = None

        self.scale: Optional[float] = None
        self.rotation_gon: Optional[float] = None
        self.rotation_rad: Optional[float] = None

        self.mean_residual_rms: Optional[float] = None
        self._invalidate_residuals()

    def _invalidate_residuals(self) -> None:
        """Internal: drop the cached RMS and the residuals of all control points."""
        self.mean_residual_rms = None
        for cp in self._points.values():
            cp.residual_xy = None
            cp.residual_norm = None

    def _require_fitted(self) -> None:
        """Internal: raise ``RuntimeError`` unless parameters are available."""
        if not self.is_fitted:
            raise RuntimeError(self._NOT_FITTED_MSG)

    @property
    def is_fitted(self) -> bool:
        """Whether transformation parameters have been estimated."""
        return self.a is not None

    # ---------------------------------------------------------------------
    # Control points management (ports Heajpt / He_mod_anc / He_mod_nou / supt)
    # ---------------------------------------------------------------------
    def add_control_point(
        self,
        point_id: Hashable,
        old_x: float,
        old_y: float,
        new_x: float,
        new_y: float,
        *,
        overwrite: bool = False,
    ) -> None:
        """Add a control point pair (old/new coordinates).

        Resets the fitted parameters and all cached residuals.

        :param point_id: Identifier of the control point (must be hashable).
        :type point_id: Hashable
        :param old_x: X coordinate in the source system.
        :type old_x: float
        :param old_y: Y coordinate in the source system.
        :type old_y: float
        :param new_x: X coordinate in the target system.
        :type new_x: float
        :param new_y: Y coordinate in the target system.
        :type new_y: float
        :param overwrite: If ``True``, replace an existing control point with the same id.
        :type overwrite: bool
        :raises KeyError: If the point id already exists and ``overwrite`` is ``False``.
        """
        if (not overwrite) and (point_id in self._points):
            raise KeyError(f"Control point '{point_id}' already exists.")

        self._points[point_id] = _ControlPoint(
            point_id=point_id,
            old_xy=np.array([old_x, old_y], dtype=float),
            new_xy=np.array([new_x, new_y], dtype=float),
        )
        self.clear_parameters()

    def update_old_coordinates(
        self, point_id: Hashable, old_x: float, old_y: float
    ) -> None:
        """Update the source coordinates of an existing control point.

        Resets the fitted parameters and the cached residuals of all points
        (they depend on every point through the fit).

        :param point_id: Identifier of the control point.
        :type point_id: Hashable
        :param old_x: New X coordinate in the source system.
        :type old_x: float
        :param old_y: New Y coordinate in the source system.
        :type old_y: float
        :raises KeyError: If the point id does not exist.
        """
        cp = self._points[point_id]
        cp.old_xy = np.array([old_x, old_y], dtype=float)
        self.clear_parameters()

    def update_new_coordinates(
        self, point_id: Hashable, new_x: float, new_y: float
    ) -> None:
        """Update the target coordinates of an existing control point.

        Resets the fitted parameters and the cached residuals of all points
        (they depend on every point through the fit).

        :param point_id: Identifier of the control point.
        :type point_id: Hashable
        :param new_x: New X coordinate in the target system.
        :type new_x: float
        :param new_y: New Y coordinate in the target system.
        :type new_y: float
        :raises KeyError: If the point id does not exist.
        """
        cp = self._points[point_id]
        cp.new_xy = np.array([new_x, new_y], dtype=float)
        self.clear_parameters()

    def remove_control_point(self, point_id: Hashable) -> None:
        """Remove a control point.

        Resets the fitted parameters and the cached residuals of the
        remaining points.

        :param point_id: Identifier of the control point.
        :type point_id: Hashable
        :raises KeyError: If the point id does not exist.
        """
        del self._points[point_id]
        self.clear_parameters()

    def list_control_point_ids(self) -> List[Hashable]:
        """Return control point identifiers.

        :returns: List of point ids.
        :rtype: list
        """
        return list(self._points.keys())

    def _stack_points(self) -> Tuple[np.ndarray, np.ndarray, List[Hashable]]:
        """Internal: stack control points as arrays.

        :returns: ``(old_xy, new_xy, ids)`` where both arrays have shape
            ``(N, 2)`` and rows follow the order of ``ids``.
        :rtype: tuple
        :raises ValueError: If there are no control points.
        """
        if not self._points:
            raise ValueError("No control points available.")

        ids = list(self._points.keys())
        old_xy = np.stack([self._points[i].old_xy for i in ids], axis=0)
        new_xy = np.stack([self._points[i].new_xy for i in ids], axis=0)
        return old_xy, new_xy, ids

    # ---------------------------------------------------------------------
    # Parameter estimation (ports Helmert_coef)
    # ---------------------------------------------------------------------
    def fit(self) -> "HelmertTransformation":
        """Estimate transformation parameters from control points.

        Uses the closed-form least-squares solution under the similarity
        constraints (as in the original C implementation). Any previously
        cached residuals are discarded.

        :returns: The fitted instance (for chaining).
        :rtype: HelmertTransformation
        :raises ValueError: If fewer than two control points are available,
            if any coordinate is not finite, or if the configuration is
            degenerate (all source points coincide).
        """
        old_xy, new_xy, _ids = self._stack_points()
        n = old_xy.shape[0]
        if n < 2:
            raise ValueError(
                "At least two control points are required to fit a Helmert transformation."
            )
        if not (np.all(np.isfinite(old_xy)) and np.all(np.isfinite(new_xy))):
            # NaN/inf would otherwise slip past the degeneracy check below and
            # silently produce NaN parameters.
            raise ValueError("Control point coordinates must be finite.")

        old_mean = old_xy.mean(axis=0)
        new_mean = new_xy.mean(axis=0)

        # Centre both point sets on their centroids: the translation drops out
        # of the least-squares problem, and small offsets instead of large
        # absolute coordinates limit round-off.
        dx, dy = (old_xy - old_mean).T
        dX, dY = (new_xy - new_mean).T

        denom = float(np.sum(dx * dx + dy * dy))
        if denom == 0.0:
            raise ValueError(
                "Degenerate configuration: all source control points are identical."
            )

        # p = s*cos(theta), q = s*sin(theta). The C code's intermediate
        # values are t1 = p - 1 and u1 = q.
        p = float(np.sum(dx * dX + dy * dY)) / denom
        q = float(np.sum(dx * dY - dy * dX)) / denom

        # Translations: the centroid of the source maps onto the target centroid.
        self.a = float(new_mean[0] - p * old_mean[0] + q * old_mean[1])
        self.a1 = float(new_mean[1] - q * old_mean[0] - p * old_mean[1])
        self.b, self.c, self.b1, self.c1 = p, -q, q, p

        self.scale = float(math.hypot(p, q))
        self.rotation_rad = float(math.atan2(q, p))
        self.rotation_gon = float(self.rad_to_gon(self.rotation_rad))

        # Residuals computed with earlier parameters no longer apply.
        self._invalidate_residuals()
        return self

    # ---------------------------------------------------------------------
    # Transformation (ports Helmert_tran)
    # ---------------------------------------------------------------------
    def transform_xy(self, x: float, y: float) -> Tuple[float, float]:
        """Transform a single 2D point.

        :param x: X coordinate in the source system.
        :type x: float
        :param y: Y coordinate in the source system.
        :type y: float
        :returns: (X, Y) in the target system.
        :rtype: tuple[float, float]
        :raises RuntimeError: If the transformation is not fitted yet.
        """
        self._require_fitted()

        X = self.a + self.b * x + self.c * y
        Y = self.a1 + self.b1 * x + self.c1 * y
        return float(X), float(Y)

    def transform(self, xy: Union[np.ndarray, Sequence[Sequence[float]]]) -> np.ndarray:
        """Transform many 2D points.

        :param xy: Input points as an ``(N, 2)`` array-like structure.
        :type xy: numpy.ndarray or sequence
        :returns: Transformed points as a ``(N, 2)`` NumPy array.
        :rtype: numpy.ndarray
        :raises ValueError: If the input does not have shape ``(N, 2)``.
        :raises RuntimeError: If the transformation is not fitted yet.
        """
        self._require_fitted()

        arr = np.asarray(xy, dtype=float)
        if arr.ndim != 2 or arr.shape[1] != 2:
            raise ValueError(f"Expected an (N,2) array, got shape {arr.shape}.")

        X = self.a + self.b * arr[:, 0] + self.c * arr[:, 1]
        Y = self.a1 + self.b1 * arr[:, 0] + self.c1 * arr[:, 1]
        return np.column_stack([X, Y])

    # ---------------------------------------------------------------------
    # Residuals (ports Helmert_resi / Helmert_result_from_num)
    # ---------------------------------------------------------------------
    @staticmethod
    def _residual_tuple(cp: "_ControlPoint") -> Tuple[float, float, float]:
        """Internal: ``(r_x, r_y, r)`` of a control point with cached residuals."""
        return (
            float(cp.residual_xy[0]),
            float(cp.residual_xy[1]),
            float(cp.residual_norm),
        )

    def compute_residuals(
        self, *, recompute: bool = True
    ) -> Dict[Hashable, Tuple[float, float, float]]:
        """Compute residuals for each control point.

        Residuals are computed as::

            r_x = X_pred - X_obs
            r_y = Y_pred - Y_obs
            r   = sqrt(r_x^2 + r_y^2)

        The method also computes the RMS of planar residual norms::

            mean_residual_rms = sqrt( sum(r^2) / n )

        :param recompute: If ``True``, recompute residuals even if they were computed before.
        :type recompute: bool
        :returns: Mapping ``{point_id: (r_x, r_y, r)}``.
        :rtype: dict
        :raises RuntimeError: If the transformation is not fitted yet.
        :raises ValueError: If there are no control points.
        """
        self._require_fitted()

        cache_complete = self.mean_residual_rms is not None and all(
            cp.residual_xy is not None and cp.residual_norm is not None
            for cp in self._points.values()
        )
        if not recompute and cache_complete:
            return {
                pid: self._residual_tuple(cp) for pid, cp in self._points.items()
            }

        old_xy, new_xy, ids = self._stack_points()
        res_xy = self.transform(old_xy) - new_xy
        res_norm = np.hypot(res_xy[:, 0], res_xy[:, 1])

        for pid, rxy, rn in zip(ids, res_xy, res_norm):
            cp = self._points[pid]
            cp.residual_xy = rxy.astype(float)
            cp.residual_norm = float(rn)

        # _stack_points() guarantees at least one point, so the mean is defined.
        self.mean_residual_rms = float(math.sqrt(float(np.mean(res_norm * res_norm))))

        return {pid: self._residual_tuple(self._points[pid]) for pid in ids}

    def get_control_point_result(
        self, point_id: Hashable
    ) -> Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float, float]]:
        """Return control point data and its residuals (if fitted).

        This mirrors the intent of the C function ``Helmert_result_from_num``.
        Residuals are (re)computed for all points if this point has none
        cached for the current parameters.

        :param point_id: Identifier of the control point.
        :type point_id: Hashable
        :returns: ``((old_x, old_y), (new_x, new_y), (r_x, r_y, r))``.
        :rtype: tuple
        :raises KeyError: If the point id does not exist.
        :raises RuntimeError: If the transformation is not fitted yet.
        """
        cp = self._points[point_id]
        self._require_fitted()

        if cp.residual_xy is None or cp.residual_norm is None:
            self.compute_residuals()

        old_xy = (float(cp.old_xy[0]), float(cp.old_xy[1]))
        new_xy = (float(cp.new_xy[0]), float(cp.new_xy[1]))
        return old_xy, new_xy, self._residual_tuple(cp)

    # ---------------------------------------------------------------------
    # Bearings (ports Helmert_gise)
    # ---------------------------------------------------------------------
    def transform_bearing_gon(
        self, bearing_gon: float, *, normalize: bool = False
    ) -> float:
        """Transform a bearing (gisement) expressed in grads (gon).

        The original C code simply adds the fitted rotation angle::

            bearing_new = bearing_old + rotation_angle

        NOTE(review): ``rotation_gon`` is ``atan2(b1, b)``, the angle of the
        rotation matrix acting on (x, y). If bearings are measured clockwise
        from the Y axis (the usual surveying convention for gisements), the
        correct update may be ``bearing_old - rotation_angle``. Confirm the
        sign convention against the original C code before relying on this.

        :param bearing_gon: Bearing in grads (gon) in the source system.
        :type bearing_gon: float
        :param normalize: If ``True``, normalize the result to the [0, 400) interval.
        :type normalize: bool
        :returns: Bearing in grads (gon) in the target system.
        :rtype: float
        :raises RuntimeError: If the transformation is not fitted yet.
        """
        if not self.is_fitted or self.rotation_gon is None:
            raise RuntimeError(self._NOT_FITTED_MSG)

        out = float(bearing_gon + self.rotation_gon)
        return float(self.normalize_gon(out)) if normalize else out

    # ---------------------------------------------------------------------
    # Params I/O (ports Lec_para / Ecr_para)
    # ---------------------------------------------------------------------
    def get_parameters(self) -> Dict[str, float]:
        """Return the fitted parameters as a dictionary.

        Keys match the original C naming:

        - ``a``  : translation on X
        - ``a1`` : translation on Y
        - ``b``  : matrix coefficient (s*cos(theta))
        - ``c``  : matrix coefficient (-s*sin(theta))
        - ``b1`` : matrix coefficient (s*sin(theta))
        - ``c1`` : matrix coefficient (s*cos(theta))
        - ``scale`` : scale factor ``s``
        - ``rotation_gon`` : rotation angle in grads (gon)
        - ``rotation_rad`` : rotation angle in radians

        :returns: Parameters mapping.
        :rtype: dict
        :raises RuntimeError: If the transformation is not fitted yet.
        """
        self._require_fitted()

        return {
            "a": float(self.a),
            "a1": float(self.a1),
            "b": float(self.b),
            "c": float(self.c),
            "b1": float(self.b1),
            "c1": float(self.c1),
            "scale": float(self.scale),
            "rotation_gon": float(self.rotation_gon),
            "rotation_rad": float(self.rotation_rad),
        }

    def set_parameters(self, *, a: float, a1: float, b: float, c: float) -> None:
        """Set transformation parameters manually.

        This mirrors the intent of the C function ``Ecr_para``. Only the four
        parameters ``(a, a1, b, c)`` are required; the remaining coefficients
        are derived by enforcing similarity constraints::

            b1 = -c
            c1 =  b

        Cached residuals of all control points are discarded.

        :param a: Translation on X.
        :type a: float
        :param a1: Translation on Y.
        :type a1: float
        :param b: Rotation/scale coefficient (s*cos(theta)).
        :type b: float
        :param c: Rotation/scale coefficient (-s*sin(theta)).
        :type c: float
        """
        self.a = float(a)
        self.a1 = float(a1)
        self.b = float(b)
        self.c = float(c)
        self.b1 = float(-c)
        self.c1 = float(b)

        self.scale = float(math.sqrt(self.b * self.b + self.c * self.c))
        self.rotation_rad = float(math.atan2(self.b1, self.b))
        self.rotation_gon = float(self.rad_to_gon(self.rotation_rad))

        self._invalidate_residuals()

    # ---------------------------------------------------------------------
    # Historical helper (ports Helmert_calculEmq)
    # ---------------------------------------------------------------------
    def compute_emq(self, nb: int, nb_false: int) -> float:
        """Compute the RMS residual (``emq``) following the original C logic.

        This is a direct port of the function ``Helmert_calculEmq``::

            if nb - nb_false >= 2:
                emq = RMS(residual_norm)
            elif nb <= 1:
                emq = 0.0
            else:
                emq = -999.99

        :param nb: Total number of points considered by the caller.
        :type nb: int
        :param nb_false: Number of rejected (out-of-tolerance) points.
        :type nb_false: int
        :returns: The RMS residual, 0.0, or -999.99 depending on the inputs.
        :rtype: float
        :raises RuntimeError: If ``nb - nb_false >= 2`` and the transformation
            is not fitted yet.
        :raises ValueError: If ``nb - nb_false >= 2`` and there are no control points.
        """
        if (nb - nb_false) >= 2:
            self.compute_residuals()
            return float(self.mean_residual_rms)
        if nb <= 1:
            return 0.0
        return -999.99
