from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union

from topaze.toolbelt import i18n

try:
    import numpy as np  # type: ignore
except Exception:  # pragma: no cover
    np = None  # type: ignore


XY = Tuple[float, float]
ArrayLikeXY = Union[Sequence[XY], "np.ndarray"]  # noqa: F821


@dataclass(frozen=True)
class ControlPoint:
    """A 2D control point pairing source-frame and target-frame coordinates.

    Control points are the observations from which a Helmert (similarity)
    transform is estimated. Instances are immutable (``frozen=True``): to
    change a coordinate, build a new ``ControlPoint``.

    :param point_id: Identifier of the control point (any hashable object).
    :param old_x: X coordinate in the source reference frame.
    :param old_y: Y coordinate in the source reference frame.
    :param new_x: X coordinate in the target reference frame.
    :param new_y: Y coordinate in the target reference frame.
    """

    point_id: Hashable
    old_x: float
    old_y: float
    new_x: float
    new_y: float


@dataclass(frozen=True)
class HelmertCoefficients:
    """Immutable parameter set of a 2D Helmert (similarity) transformation.

    A source point ``(x, y)`` is mapped to the target point ``(X, Y)`` by::

        X = tx + b  * x + c  * y
        Y = ty + b1 * x + c1 * y

    A Helmert transformation ties the four matrix terms together::

        c  = -b1
        c1 =  b

    The dataclass only stores the values it is given; it does not check
    these constraints itself.

    :param tx: Translation along X (target units).
    :param ty: Translation along Y (target units).
    :param b:  Matrix coefficient (scale * cos(theta)).
    :param c:  Matrix coefficient (-scale * sin(theta)).
    :param b1: Matrix coefficient (scale * sin(theta)).
    :param c1: Matrix coefficient (scale * cos(theta)).
    :param scale: Uniform scale factor.
    :param rotation_rad: Rotation angle in radians.
    :param rotation_gon: Rotation angle in grads (gons), where 200 gons = π rad.
    """

    tx: float
    ty: float
    b: float
    c: float
    b1: float
    c1: float
    scale: float
    rotation_rad: float
    rotation_gon: float

    def as_affine_matrix(self) -> "np.ndarray":
        """Build the homogeneous-coordinates matrix of the transformation.

        :returns: A 3x3 matrix ``A`` such that ``[X, Y, 1]^T = A @ [x, y, 1]^T``.
        :rtype: numpy.ndarray
        :raises RuntimeError: If NumPy is not available.
        """
        if np is None:
            raise RuntimeError(i18n.tr("NumPy is required to build an affine matrix."))

        # One tuple per matrix row: linear part on the left, translation on
        # the right, and the constant homogeneous row at the bottom.
        x_row = (self.b, self.c, self.tx)
        y_row = (self.b1, self.c1, self.ty)
        homogeneous_row = (0.0, 0.0, 1.0)
        return np.array((x_row, y_row, homogeneous_row), dtype=float)


@dataclass(frozen=True)
class ControlPointResidual:
    """Residuals for one control point (immutable record).

    The "prediction" is the control point's source coordinates pushed
    through the transformation; the "observation" is the control point's
    target coordinates.

    :param point_id: Identifier of the control point.
    :param predicted_x: Transformed X coordinate (prediction).
    :param predicted_y: Transformed Y coordinate (prediction).
    :param residual_x: Residual on X (predicted - observed).
    :param residual_y: Residual on Y (predicted - observed).
    :param residual: Planimetric residual magnitude, i.e. the Euclidean
        norm of ``(residual_x, residual_y)``.
    """

    point_id: Hashable
    predicted_x: float
    predicted_y: float
    residual_x: float
    residual_y: float
    residual: float


@dataclass(frozen=True)
class ResidualSummary:
    """Residual summary for all control points.

    ``frozen=True`` only prevents re-binding the two attributes; the
    ``residuals`` mapping is a regular ``dict`` and remains mutable.

    :param rmse: Root-mean-square error of planimetric residual magnitudes,
        ``sqrt(sum(residual_i ** 2) / n)``.
    :param residuals: Per-point residuals (keyed by point_id).
    """

    rmse: float
    residuals: Dict[Hashable, ControlPointResidual]


class HelmertTransformation:
    """2D Helmert (similarity) transformation estimated from control points.

    This is a direct, Pythonic port of the core logic found in the provided
    ``helmert.c`` file: translation + uniform scale + rotation.

    The model is defined by six coefficients ``(tx, ty, b, c, b1, c1)``::

        X = tx + b  * x + c  * y
        Y = ty + b1 * x + c1 * y

    and the Helmert constraints are enforced::

        c  = -b1
        c1 =  b
    """

    def __init__(self) -> None:
        """Create an empty transformation (no control points, not fitted)."""
        # Control points keyed by their identifier.
        self._points: Dict[Hashable, ControlPoint] = {}
        # ``None`` means "not fitted"; every method that changes the control
        # points resets this attribute back to ``None``.
        self._coeffs: Optional[HelmertCoefficients] = None

    # ---------------------------------------------------------------------
    # Control points management
    # ---------------------------------------------------------------------
    def add_control_point(
        self,
        point_id: Hashable,
        old_x: float,
        old_y: float,
        new_x: float,
        new_y: float,
    ) -> None:
        """Add a control point.

        :param point_id: Identifier of the control point.
        :param old_x: X coordinate in the source frame.
        :param old_y: Y coordinate in the source frame.
        :param new_x: X coordinate in the target frame.
        :param new_y: Y coordinate in the target frame.
        :raises ValueError: If ``point_id`` already exists.
        """
        if point_id in self._points:
            raise ValueError(
                i18n.tr("Control point '{point_id}' already exists.").format(
                    point_id=point_id
                )
            )
        self.set_control_point(point_id, old_x, old_y, new_x, new_y)

    def set_control_point(
        self,
        point_id: Hashable,
        old_x: float,
        old_y: float,
        new_x: float,
        new_y: float,
    ) -> None:
        """Create or update a control point (upsert).

        :param point_id: Identifier of the control point.
        :param old_x: X coordinate in the source frame.
        :param old_y: Y coordinate in the source frame.
        :param new_x: X coordinate in the target frame.
        :param new_y: Y coordinate in the target frame.
        """
        self._points[point_id] = ControlPoint(
            point_id=point_id,
            old_x=float(old_x),
            old_y=float(old_y),
            new_x=float(new_x),
            new_y=float(new_y),
        )
        self._coeffs = None

    def update_old_coordinates(
        self, point_id: Hashable, old_x: float, old_y: float
    ) -> None:
        """Update the source-frame coordinates of an existing control point.

        :param point_id: Identifier of the control point.
        :param old_x: New X coordinate in the source frame.
        :param old_y: New Y coordinate in the source frame.
        :raises KeyError: If the control point does not exist.
        """
        p = self._points[point_id]
        self._points[point_id] = ControlPoint(
            point_id=p.point_id,
            old_x=float(old_x),
            old_y=float(old_y),
            new_x=p.new_x,
            new_y=p.new_y,
        )
        self._coeffs = None

    def update_new_coordinates(
        self, point_id: Hashable, new_x: float, new_y: float
    ) -> None:
        """Update the target-frame coordinates of an existing control point.

        :param point_id: Identifier of the control point.
        :param new_x: New X coordinate in the target frame.
        :param new_y: New Y coordinate in the target frame.
        :raises KeyError: If the control point does not exist.
        """
        p = self._points[point_id]
        self._points[point_id] = ControlPoint(
            point_id=p.point_id,
            old_x=p.old_x,
            old_y=p.old_y,
            new_x=float(new_x),
            new_y=float(new_y),
        )
        self._coeffs = None

    def remove_control_point(self, point_id: Hashable) -> None:
        """Remove a control point.

        :param point_id: Identifier of the control point to remove.
        :raises KeyError: If the control point does not exist.
        """
        del self._points[point_id]
        self._coeffs = None

    def clear(self) -> None:
        """Remove all control points and reset the fitted coefficients."""
        self._points.clear()
        self._coeffs = None

    @property
    def control_points(self) -> Mapping[Hashable, ControlPoint]:
        """Read-only view of current control points.

        :returns: A mapping ``{point_id: ControlPoint}``.
        :rtype: Mapping[Hashable, ControlPoint]
        """
        return dict(self._points)

    # ---------------------------------------------------------------------
    # Estimation (fit) and parameters I/O
    # ---------------------------------------------------------------------
    def fit(self) -> HelmertCoefficients:
        """Estimate the Helmert coefficients from current control points.

        The computation follows the same principle as the C function
        ``Helmert_coef`` (centroids + cross-covariances).

        :returns: Estimated coefficients.
        :rtype: HelmertCoefficients
        :raises ValueError: If fewer than 2 control points are available.
        :raises ValueError: If the control points are degenerate (zero variance).
        :raises RuntimeError: If NumPy is not available.
        """
        if np is None:
            raise RuntimeError(
                i18n.tr("NumPy is required to estimate coefficients (fit).")
            )

        if len(self._points) < 2:
            raise ValueError(
                i18n.tr(
                    "At least 2 control points are required to fit a 2D Helmert transformation."
                )
            )

        old = np.array([(p.old_x, p.old_y) for p in self._points.values()], dtype=float)
        new = np.array([(p.new_x, p.new_y) for p in self._points.values()], dtype=float)

        old_mean = old.mean(axis=0)
        new_mean = new.mean(axis=0)

        old_c = old - old_mean
        new_c = new - new_mean

        den = float(np.sum(old_c[:, 0] ** 2 + old_c[:, 1] ** 2))
        if den == 0.0:
            raise ValueError(
                i18n.tr(
                    "Degenerate configuration: all source points are identical (zero variance)."
                )
            )

        # b  = scale * cos(theta)
        # b1 = scale * sin(theta)
        b = float(np.sum(old_c[:, 0] * new_c[:, 0] + old_c[:, 1] * new_c[:, 1]) / den)
        b1 = float(np.sum(old_c[:, 0] * new_c[:, 1] - old_c[:, 1] * new_c[:, 0]) / den)

        c = -b1
        c1 = b

        tx = float(new_mean[0] - b * old_mean[0] - c * old_mean[1])
        ty = float(new_mean[1] - b1 * old_mean[0] - c1 * old_mean[1])

        scale = float(math.hypot(b, b1))
        rotation_rad = float(math.atan2(b1, b))
        rotation_gon = float(rotation_rad * 200.0 / math.pi)

        self._coeffs = HelmertCoefficients(
            tx=tx,
            ty=ty,
            b=b,
            c=c,
            b1=b1,
            c1=c1,
            scale=scale,
            rotation_rad=rotation_rad,
            rotation_gon=rotation_gon,
        )
        return self._coeffs

    @property
    def coefficients(self) -> HelmertCoefficients:
        """Return fitted coefficients.

        :returns: The fitted coefficients.
        :rtype: HelmertCoefficients
        :raises RuntimeError: If the transformation has not been fitted yet.
        """
        if self._coeffs is None:
            raise RuntimeError(
                i18n.tr(
                    "Transformation not fitted yet. Call fit() first, or set coefficients explicitly."
                )
            )
        return self._coeffs

    def set_coefficients_from_bc(
        self, tx: float, ty: float, b: float, c: float
    ) -> HelmertCoefficients:
        """Set coefficients from the (b, c) form used in the legacy C code.

        The original ``Ecr_para`` stores ``b`` and ``c`` then rebuilds ``b1`` and
        ``c1`` using the Helmert constraints::

            b1 = -c
            c1 =  b

        :param tx: Translation along X.
        :param ty: Translation along Y.
        :param b: Matrix coefficient ``b``.
        :param c: Matrix coefficient ``c`` (note: ``c = -b1``).
        :returns: The resulting coefficient set.
        :rtype: HelmertCoefficients
        """
        return self.set_coefficients(tx=tx, ty=ty, b=b, b1=-c)

    def set_coefficients(
        self, tx: float, ty: float, b: float, b1: float
    ) -> HelmertCoefficients:
        """Set coefficients from ``(tx, ty, b, b1)`` and enforce Helmert constraints.

        :param tx: Translation along X.
        :param ty: Translation along Y.
        :param b: Matrix coefficient ``b`` (scale * cos(theta)).
        :param b1: Matrix coefficient ``b1`` (scale * sin(theta)).
        :returns: The resulting coefficient set.
        :rtype: HelmertCoefficients
        :raises ValueError: If the implied scale is zero.
        """
        b = float(b)
        b1 = float(b1)
        scale = float(math.hypot(b, b1))
        if scale == 0.0:
            raise ValueError(i18n.tr("Invalid coefficients: implied scale is zero."))

        c = -b1
        c1 = b

        rotation_rad = float(math.atan2(b1, b))
        rotation_gon = float(rotation_rad * 200.0 / math.pi)

        self._coeffs = HelmertCoefficients(
            tx=float(tx),
            ty=float(ty),
            b=b,
            c=c,
            b1=b1,
            c1=c1,
            scale=scale,
            rotation_rad=rotation_rad,
            rotation_gon=rotation_gon,
        )
        return self._coeffs

    @classmethod
    def from_translation_rotation_scale(
        cls,
        tx: float,
        ty: float,
        rotation_gon: float,
        scale: float = 1.0,
    ) -> "HelmertTransformation":
        """Create a transformation directly from translation, rotation and scale.

        :param tx: Translation along X.
        :param ty: Translation along Y.
        :param rotation_gon: Rotation angle in grads (gons).
        :param scale: Uniform scale factor.
        :returns: A fitted transformation instance.
        :rtype: HelmertTransformation
        :raises ValueError: If ``scale`` is zero.
        """
        if scale == 0.0:
            raise ValueError(i18n.tr("Scale must be non-zero."))
        rotation_rad = float(rotation_gon) * math.pi / 200.0
        b = float(scale) * math.cos(rotation_rad)
        b1 = float(scale) * math.sin(rotation_rad)

        obj = cls()
        obj.set_coefficients(tx=tx, ty=ty, b=b, b1=b1)
        return obj

    # ---------------------------------------------------------------------
    # Forward / inverse transformation
    # ---------------------------------------------------------------------
    def transform_xy(self, x: float, y: float) -> XY:
        """Transform one point from source to target frame.

        :param x: Source X.
        :param y: Source Y.
        :returns: ``(X, Y)`` in the target frame.
        :rtype: tuple[float, float]
        :raises RuntimeError: If the transformation is not fitted.
        """
        c = self.coefficients
        X = c.tx + c.b * float(x) + c.c * float(y)
        Y = c.ty + c.b1 * float(x) + c.c1 * float(y)
        return X, Y

    def transform_many(self, xy: ArrayLikeXY) -> "np.ndarray":
        """Transform many points from source to target frame.

        :param xy: Array-like of shape ``(n, 2)``.
        :returns: A NumPy array of shape ``(n, 2)``.
        :rtype: numpy.ndarray
        :raises RuntimeError: If NumPy is not available.
        :raises RuntimeError: If the transformation is not fitted.
        """
        if np is None:
            raise RuntimeError(i18n.tr("NumPy is required to transform many points."))
        c = self.coefficients

        arr = np.asarray(xy, dtype=float)
        if arr.ndim != 2 or arr.shape[1] != 2:
            raise ValueError(i18n.tr("Expected an array-like of shape (n, 2)."))

        X = c.tx + c.b * arr[:, 0] + c.c * arr[:, 1]
        Y = c.ty + c.b1 * arr[:, 0] + c.c1 * arr[:, 1]
        return np.column_stack([X, Y])

    def inverse_transform_xy(self, X: float, Y: float) -> XY:
        """Inverse-transform one point from target to source frame.

        :param X: Target X.
        :param Y: Target Y.
        :returns: ``(x, y)`` in the source frame.
        :rtype: tuple[float, float]
        :raises RuntimeError: If the transformation is not fitted.
        """
        c = self.coefficients
        dx = float(X) - c.tx
        dy = float(Y) - c.ty

        # For a Helmert transform, det = b*c1 - c*b1 = b^2 + b1^2 = scale^2
        det = c.scale * c.scale
        x = (c.b * dx + c.b1 * dy) / det
        y = (-c.b1 * dx + c.b * dy) / det
        return x, y

    def inverse_transform_many(self, XY: ArrayLikeXY) -> "np.ndarray":
        """Inverse-transform many points from target to source frame.

        :param XY: Array-like of shape ``(n, 2)``.
        :returns: A NumPy array of shape ``(n, 2)``.
        :rtype: numpy.ndarray
        :raises RuntimeError: If NumPy is not available.
        :raises RuntimeError: If the transformation is not fitted.
        """
        if np is None:
            raise RuntimeError(
                i18n.tr("NumPy is required to inverse-transform many points.")
            )
        c = self.coefficients

        arr = np.asarray(XY, dtype=float)
        if arr.ndim != 2 or arr.shape[1] != 2:
            raise ValueError(i18n.tr("Expected an array-like of shape (n, 2)."))

        dx = arr[:, 0] - c.tx
        dy = arr[:, 1] - c.ty

        det = c.scale * c.scale
        x = (c.b * dx + c.b1 * dy) / det
        y = (-c.b1 * dx + c.b * dy) / det
        return np.column_stack([x, y])

    # ---------------------------------------------------------------------
    # Residuals (calage quality)
    # ---------------------------------------------------------------------
    def compute_residuals(self) -> ResidualSummary:
        """Compute residuals on all control points.

        This corresponds to the C function ``Helmert_resi``:
        each control point is transformed and compared to its observed target
        coordinates.

        The returned ``rmse`` is::

            rmse = sqrt( sum(residual_i^2) / n )

        where ``residual_i`` is the planimetric residual magnitude of point ``i``.

        :returns: Residuals per point and global RMSE.
        :rtype: ResidualSummary
        :raises RuntimeError: If the transformation is not fitted.
        :raises ValueError: If no control points are available.
        """
        c = self.coefficients
        if not self._points:
            raise ValueError(i18n.tr("No control points available."))

        residuals: Dict[Hashable, ControlPointResidual] = {}
        sse = 0.0

        for p in self._points.values():
            pred_x = c.tx + c.b * p.old_x + c.c * p.old_y
            pred_y = c.ty + c.b1 * p.old_x + c.c1 * p.old_y
            rx = pred_x - p.new_x
            ry = pred_y - p.new_y
            r = math.hypot(rx, ry)
            residuals[p.point_id] = ControlPointResidual(
                point_id=p.point_id,
                predicted_x=float(pred_x),
                predicted_y=float(pred_y),
                residual_x=float(rx),
                residual_y=float(ry),
                residual=float(r),
            )
            sse += r * r

        rmse = math.sqrt(sse / len(residuals))
        return ResidualSummary(rmse=float(rmse), residuals=residuals)

    def control_point_result(self, point_id: Hashable) -> ControlPointResidual:
        """Return residual information for one control point.

        Equivalent to the C function ``Helmert_result_from_num``.

        :param point_id: Identifier of the control point.
        :returns: Residual data for that point.
        :rtype: ControlPointResidual
        :raises KeyError: If the control point does not exist.
        :raises RuntimeError: If the transformation is not fitted.
        """
        p = self._points[point_id]
        c = self.coefficients

        pred_x = c.tx + c.b * p.old_x + c.c * p.old_y
        pred_y = c.ty + c.b1 * p.old_x + c.c1 * p.old_y
        rx = pred_x - p.new_x
        ry = pred_y - p.new_y
        r = math.hypot(rx, ry)

        return ControlPointResidual(
            point_id=p.point_id,
            predicted_x=float(pred_x),
            predicted_y=float(pred_y),
            residual_x=float(rx),
            residual_y=float(ry),
            residual=float(r),
        )

    def format_residual_report(self, ndigits: int = 6) -> str:
        """Build a human-readable residual report.

        This is a replacement for the original C function ``Helmert_reli``
        which printed through MicroStation-specific routines.

        :param ndigits: Number of digits after the decimal point.
        :returns: A multi-line string report.
        :rtype: str
        """
        summary = self.compute_residuals()

        fmt = f"{{:.{ndigits}f}}"
        lines = [i18n.tr("Control point residuals"), ""]
        lines.append(i18n.tr("point_id;X_pred;Y_pred;dX;dY;d"))

        for pid, r in summary.residuals.items():
            lines.append(
                f"{pid};"
                f"{fmt.format(r.predicted_x)};"
                f"{fmt.format(r.predicted_y)};"
                f"{fmt.format(r.residual_x)};"
                f"{fmt.format(r.residual_y)};"
                f"{fmt.format(r.residual)}"
            )

        lines.append("")
        lines.append(i18n.tr("RMSE = {rmse}").format(rmse=fmt.format(summary.rmse)))
        return "\n".join(lines)

    # ---------------------------------------------------------------------
    # Bearing / gisement
    # ---------------------------------------------------------------------
    def transform_bearing_gon(self, bearing_gon: float, wrap: bool = True) -> float:
        """Transform a bearing (gisement) expressed in grads (gons).

        Equivalent to the C function ``Helmert_gise``: it simply adds the
        fitted rotation angle.

        :param bearing_gon: Bearing in grads (gons).
        :param wrap: If True, wrap the result into the [0, 400[ range.
        :returns: Bearing in grads (gons) in the target frame.
        :rtype: float
        """
        if wrap:
            return (float(bearing_gon) + self.coefficients.rotation_gon) % 400.0
        return float(bearing_gon) + self.coefficients.rotation_gon
