from __future__ import annotations

import json
import logging
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np

try:
    # Idéalement tu as déjà ajouté ces fonctions dans report_utils
    from .report_utils import (  # type: ignore
        markdown_to_html,
        markdown_to_odt,
        markdown_to_pdf,
    )
except Exception:  # pragma: no cover - fallback silencieux
    markdown_to_html = None
    markdown_to_pdf = None
    markdown_to_odt = None


# ---------------------------------------------------------------------------
# Dataclass de résultat
# ---------------------------------------------------------------------------


@dataclass
class TrilaterationResult:
    """
    Result of a trilateration computation for a single unknown point.

    Instances are built by ``_compute_solution``. Residual lists hold one
    plain dict per observation so they can be rendered in the report easily.
    Linear quantities are in the units of the station coordinates
    (presumably metres - the Markdown report labels them "m"; confirm).
    """

    # Identifier of the unknown point (first entry of the JSON "calcul" list).
    point_id: str

    # Adjusted coordinates. ``z`` is None when the Z computation is disabled
    # (config "compute_z" false).
    x: float
    y: float
    z: Optional[float]

    # Standard deviations: sigma0_xy * sqrt(diag(Qxx)) for X/Y,
    # sigma0_z / sqrt(n_obs_z) for Z (None when Z is not computed).
    sigma_x: float
    sigma_y: float
    sigma_z: Optional[float]

    # A-posteriori standard deviation of unit weight of the XY adjustment, and
    # empirical standard deviation of the individual Z estimates (None without Z).
    sigma0_xy: float
    sigma0_z: Optional[float]

    # Covariance between X and Y (sigma0_xy ** 2 * Qxx[0, 1]).
    cov_xy: float

    # Observation counts and degrees of freedom. Computed dof values are
    # clamped to at least 1; n_obs_z and dof_z stay 0 when Z is not computed.
    n_obs_xy: int
    n_obs_z: int
    dof_xy: int
    dof_z: int

    # One dict per observation.
    # Horizontal keys: obs_index, target_id, measured_horizontal,
    #   computed_horizontal, residual, abs_residual, within_tolerance.
    # Vertical keys: obs_index, target_id, measured_vertical_angle, di,
    #   geometric_dZ, z_from_observation, residual, abs_residual
    #   (the vertical list is empty when Z is not computed).
    distance_residuals: List[Dict[str, Any]]
    vertical_residuals: List[Dict[str, Any]]

    # "meta" and "config" blocks taken from the input JSON.
    meta: Dict[str, Any]
    config: Dict[str, Any]


# ---------------------------------------------------------------------------
# Lecture JSON + calcul
# ---------------------------------------------------------------------------


def _load_trilateration_json(json_path: Path) -> Dict[str, Any]:
    with json_path.open("r", encoding="utf-8") as f:
        return json.load(f)


def _adjust_xy(
    coords_xy: np.ndarray,
    H_obs: np.ndarray,
    max_iter: int = 20,
    tol: float = 1e-6,
):
    """
    Non-linear LSQ adjustment for XY using horizontal distances.

    Minimise Σ (d_calc - H_obs)² par Gauss-Newton.
    """
    n = H_obs.size
    if n < 2:
        raise ValueError(
            i18n.tr("XY adjustment needs at least 2 horizontal distances.")
        )

    # Initial guess: centroid of known stations
    x = float(coords_xy[:, 0].mean())
    y = float(coords_xy[:, 1].mean())

    # Equal weights
    P = np.eye(n, dtype=float)
    N = None

    for _ in range(max_iter):
        dx = x - coords_xy[:, 0]
        dy = y - coords_xy[:, 1]
        d_calc = np.hypot(dx, dy)
        d_calc = np.where(d_calc < 1e-9, 1e-9, d_calc)

        # residuals r = d_calc - H_obs
        r = d_calc - H_obs

        # Jacobian J = ∂d_calc/∂(X,Y)
        J = np.column_stack((dx / d_calc, dy / d_calc))  # shape (n,2)

        N = J.T @ P @ J
        u = J.T @ P @ r.reshape(-1, 1)

        try:
            delta = -np.linalg.solve(N, u)
        except np.linalg.LinAlgError:
            delta = -np.linalg.pinv(N) @ u

        dx_corr = float(delta[0, 0])
        dy_corr = float(delta[1, 0])

        x += dx_corr
        y += dy_corr

        if math.hypot(dx_corr, dy_corr) < tol:
            break

    # Final residuals & precision estimates
    dx = x - coords_xy[:, 0]
    dy = y - coords_xy[:, 1]
    d_calc = np.hypot(dx, dy)
    d_calc = np.where(d_calc < 1e-9, 1e-9, d_calc)

    v = H_obs - d_calc  # obs - computed
    dof = max(1, n - 2)
    sigma0_sq = float((v @ v) / dof)
    sigma0 = math.sqrt(sigma0_sq)

    if N is None:
        raise RuntimeError("Normal matrix not computed in XY adjustment.")

    try:
        Qxx = np.linalg.inv(N)
    except np.linalg.LinAlgError:
        Qxx = np.linalg.pinv(N)

    sigma_x = sigma0 * math.sqrt(float(Qxx[0, 0]))
    sigma_y = sigma0 * math.sqrt(float(Qxx[1, 1]))
    cov_xy = sigma0**2 * float(Qxx[0, 1])

    return x, y, sigma0, sigma_x, sigma_y, cov_xy, v, d_calc


def _json_value(mapping: Dict[str, Any], key: str, default: Any) -> Any:
    """Return ``mapping[key]``, or *default* when the key is missing or JSON ``null``."""
    value = mapping.get(key)
    return default if value is None else value


def _adjust_z(
    obs_info: List[Dict[str, Any]],
) -> tuple[float, float, float, int, int, List[Dict[str, Any]]]:
    """
    Estimate the elevation of the unknown point from the vertical components.

    Each observation gives an individual estimate
        Z_unknown = Z_station + hp - hi - dZ
    (prism at height hp above the station, instrument axis at height hi above
    the unknown point, dZ = vertical component of the slope distance).
    The estimates are averaged with equal weights.

    :param obs_info: per-observation dicts built by ``_compute_solution``.
    :returns: ``(z_est, sigma_z, sigma0_z, n_obs_z, dof_z, vertical_residuals)``.
    :raises ValueError: if a sighted station has no elevation.
    """
    for info in obs_info:
        if info["station_z"] is None:
            raise ValueError(
                i18n.tr(
                    "Station '{pid}' has no elevation (z), which is required "
                    "when compute_z is enabled."
                ).format(pid=info["target_id"])
            )

    Z_arr = np.array(
        [
            info["station_z"] + info["hp"] - info["hi"] - info["dZ_geom"]
            for info in obs_info
        ],
        dtype=float,
    )
    z_est = float(Z_arr.mean())

    v_Z = Z_arr - z_est
    n_obs_z = len(Z_arr)
    dof_z = max(1, n_obs_z - 1)

    sigma0_z = math.sqrt(float(v_Z @ v_Z) / dof_z)
    # Var(mean) = sigma0^2 / n in this simple equal-weight model.
    sigma_z = sigma0_z / math.sqrt(n_obs_z)

    vertical_residuals = [
        {
            "obs_index": info["index"],
            "target_id": info["target_id"],
            "measured_vertical_angle": info["av"],
            "di": info["di"],
            "geometric_dZ": info["dZ_geom"],
            "z_from_observation": float(Zi),
            "residual": float(vZi),
            "abs_residual": float(abs(vZi)),
        }
        for info, Zi, vZi in zip(obs_info, Z_arr, v_Z)
    ]
    return z_est, sigma_z, sigma0_z, n_obs_z, dof_z, vertical_residuals


def _compute_solution(data: Dict[str, Any]) -> TrilaterationResult:
    """
    Compute the unknown point described by a trilateration JSON structure.

    Expected structure:

      - meta: project information (copied into the result),
      - config: compute_z, linear_scale_factor, tolerances...,
      - calcul: ["S35", ...] -> the first id is the unknown point,
      - stations: [{matricule, x, y, z}, ...] -> known points,
      - obs: observations; only those with
          * type == "reference",
          * origine == the unknown point,
          * cible == a known station
        are used. Each provides di (slope distance), av (vertical angle in
        gon, 0g = zenith, 100g = horizontal) and optional hi / hp.

    Missing or ``null`` optional values (tolerances, scale factor, hi, hp)
    fall back to their defaults.

    :raises ValueError: on missing 'calcul'/'stations', an observation to an
        unknown station, fewer than 3 usable observations, or a missing
        station elevation when compute_z is enabled.
    """
    meta = data.get("meta") or {}
    config = data.get("config") or {}

    # Only config.compute_z is honoured (meta.compute_z is ignored).
    compute_z = bool(config.get("compute_z", False))
    scale = float(_json_value(config, "linear_scale_factor", 1.0))

    # Unknown point id
    calcul_list = data.get("calcul") or []
    if not calcul_list:
        raise ValueError(i18n.tr("Trilateration JSON has no 'calcul' block."))
    unknown_id = calcul_list[0]

    # Known stations
    stations = data.get("stations") or []
    stations_by_id = {s["matricule"]: s for s in stations}
    if not stations_by_id:
        raise ValueError(i18n.tr("Trilateration JSON has no 'stations' block."))

    # Observations from the unknown point to known stations (type "reference").
    raw_obs: List[tuple[int, Dict[str, Any]]] = []
    for idx, obs in enumerate(data.get("obs") or [], start=1):
        if obs.get("type") != "reference":
            continue
        if obs.get("origine") != unknown_id:
            continue

        target_id = obs.get("cible")
        if target_id not in stations_by_id:
            raise ValueError(
                i18n.tr(
                    "Observation #{idx} references unknown station '{pid}'."
                ).format(idx=idx, pid=target_id)
            )
        raw_obs.append((idx, obs))

    if len(raw_obs) < 3:
        raise ValueError(
            i18n.tr(
                "Trilateration needs at least 3 distance observations from "
                "point '{pid}' to known stations, got {n}."
            ).format(pid=unknown_id, n=len(raw_obs))
        )

    # Build arrays for the XY adjustment.
    coords_xy: List[tuple[float, float]] = []
    H_obs: List[float] = []
    obs_info: List[Dict[str, Any]] = []

    for idx, obs in raw_obs:
        target_id = obs["cible"]
        st = stations_by_id[target_id]

        xk = float(st["x"])
        yk = float(st["y"])
        # Elevation is only needed for Z; keep None when absent so that the
        # Z step can refuse to compute instead of silently assuming 0.
        z_raw = st.get("z")
        zk = None if z_raw is None else float(z_raw)

        di = float(obs["di"]) * scale
        av = float(obs["av"])
        hi = float(_json_value(obs, "hi", 0.0))
        hp = float(_json_value(obs, "hp", 0.0))

        # av in gon (0g zenith, 100g horizontal, 200g nadir); the inclination
        # above the horizontal is i = 100g - av.
        inc_rad = (100.0 - av) * math.pi / 200.0

        dH = di * math.cos(inc_rad)  # horizontal distance
        dZ = di * math.sin(inc_rad)  # instrument axis -> prism height difference

        coords_xy.append((xk, yk))
        H_obs.append(dH)
        obs_info.append(
            {
                "index": idx,
                "target_id": target_id,
                "station_x": xk,
                "station_y": yk,
                "station_z": zk,
                "di": di,
                "av": av,
                "hi": hi,
                "hp": hp,
                "dH_obs": dH,
                "dZ_geom": dZ,
            }
        )

    coords_xy_arr = np.array(coords_xy, dtype=float)
    H_obs_arr = np.array(H_obs, dtype=float)
    n_obs_xy = len(H_obs_arr)

    # XY non-linear LSQ
    X, Y, sigma0_xy, sigma_x, sigma_y, cov_xy, v_H, H_calc = _adjust_xy(
        coords_xy_arr, H_obs_arr
    )

    # Tolerance for horizontal residuals; explicit nulls fall through to the
    # next key, like in the Markdown report.
    dist_res_tol = config.get("distance_residual_tolerance")
    if dist_res_tol is None:
        dist_res_tol = config.get("distance_back_forward_tolerance")
    if dist_res_tol is None:
        dist_res_tol = 0.01
    dist_res_tol = float(dist_res_tol)

    distance_residuals: List[Dict[str, Any]] = []
    for info, v_i, Hc in zip(obs_info, v_H, H_calc):
        abs_v = float(abs(v_i))
        distance_residuals.append(
            {
                "obs_index": info["index"],
                "target_id": info["target_id"],
                "measured_horizontal": info["dH_obs"],
                "computed_horizontal": float(Hc),
                "residual": float(v_i),
                "abs_residual": abs_v,
                # Plain bool (not numpy.bool_) so identity checks such as
                # `is True` and JSON serialisation behave as expected.
                "within_tolerance": bool(abs_v <= dist_res_tol),
            }
        )

    # Z adjustment (optional)
    z_est: Optional[float] = None
    sigma_z: Optional[float] = None
    sigma0_z: Optional[float] = None
    vertical_residuals: List[Dict[str, Any]] = []
    n_obs_z = 0
    dof_z = 0

    if compute_z:
        z_est, sigma_z, sigma0_z, n_obs_z, dof_z, vertical_residuals = _adjust_z(
            obs_info
        )

    dof_xy = max(1, n_obs_xy - 2)

    return TrilaterationResult(
        point_id=unknown_id,
        x=float(X),
        y=float(Y),
        z=z_est,
        sigma_x=float(sigma_x),
        sigma_y=float(sigma_y),
        sigma_z=float(sigma_z) if sigma_z is not None else None,
        sigma0_xy=float(sigma0_xy),
        sigma0_z=float(sigma0_z) if sigma0_z is not None else None,
        cov_xy=float(cov_xy),
        n_obs_xy=n_obs_xy,
        n_obs_z=n_obs_z,
        dof_xy=dof_xy,
        dof_z=dof_z,
        distance_residuals=distance_residuals,
        vertical_residuals=vertical_residuals,
        meta=meta,
        config=config,
    )


# ---------------------------------------------------------------------------
# Génération du rapport Markdown
# ---------------------------------------------------------------------------


def _format_float(value: Optional[float], ndigits: int = 4) -> str:
    if value is None:
        return ""
    return f"{value:.{ndigits}f}"


def _build_markdown_report(data: Dict[str, Any], result: TrilaterationResult) -> str:
    """
    Build the Markdown report, in the spirit of ``_build_markdown_report()``
    from closed_traverse.py, adapted to a single trilaterated point.

    :param data: raw trilateration JSON (only "meta" and "config" are read).
    :param result: computed solution to report.
    :returns: the Markdown document as one string.
    """
    meta = data.get("meta") or {}
    config = data.get("config") or {}

    project_name = meta.get("project_name", "")
    observer = meta.get("observer", "")
    instrument = meta.get("instrument", "")
    date = meta.get("date", "")
    angle_unit = meta.get("angle_unit", "gon")
    distance_unit = meta.get("distance_unit", "m")

    # Tolerance used to interpret the horizontal distance residuals; the
    # computation looks up the same two keys in the same order.
    dist_tol = config.get("distance_residual_tolerance")
    if dist_tol is None:
        dist_tol = config.get("distance_back_forward_tolerance")

    md_parts: List[str] = []

    md_parts.append(f"# {i18n.tr('Trilateration report')}")
    md_parts.append("")
    md_parts.append(f"- **{i18n.tr('Project')}**: {project_name}")
    md_parts.append(f"- **{i18n.tr('Observer')}**: {observer}")
    md_parts.append(f"- **{i18n.tr('Instrument')}**: {instrument}")
    md_parts.append(f"- **{i18n.tr('Date')}**: {date}")
    md_parts.append(f"- **{i18n.tr('Unknown point')}**: {result.point_id}")
    md_parts.append(
        f"- **{i18n.tr('Number of horizontal distance observations')}**: {result.n_obs_xy}"
    )
    md_parts.append("")

    # Tolerance block: only the distance residual tolerance is reported.
    if dist_tol is not None:
        md_parts.append(f"### {i18n.tr('Check tolerances')}")
        md_parts.append("")
        md_parts.append(
            f"- **{i18n.tr('Horizontal distance residual tolerance')}**: "
            f"{dist_tol} {distance_unit}"
        )
        md_parts.append("")

    # Coordinates
    md_parts.append(f"## {i18n.tr('Coordinates of unknown point')}")
    md_parts.append("")
    md_parts.append("| Component | Value (m) | Std. dev. (m) |")
    md_parts.append("|-----------|-----------|---------------|")
    md_parts.append(
        f"| X | {_format_float(result.x, 4)} | {_format_float(result.sigma_x, 4)} |"
    )
    md_parts.append(
        f"| Y | {_format_float(result.y, 4)} | {_format_float(result.sigma_y, 4)} |"
    )
    if result.z is not None and result.sigma_z is not None:
        md_parts.append(
            f"| Z | {_format_float(result.z, 4)} | {_format_float(result.sigma_z, 4)} |"
        )
    md_parts.append("")

    # Horizontal residuals.
    # NOTE: table header cells must not contain raw "|" characters, otherwise
    # the header has more cells than the delimiter row and GFM-style renderers
    # do not recognise the table at all.
    md_parts.append(f"## {i18n.tr('Horizontal distance residuals')}")
    md_parts.append("")
    md_parts.append(
        "| Obs index | Target point | Measured H (m) | Computed H (m) | "
        "Residual (m) | Abs. residual (m) | OK? |"
    )
    md_parts.append(
        "|-----------|--------------|---------------:|---------------:|"
        "-------------:|---------------:|:---:|"
    )

    for r in result.distance_residuals:
        within = r.get("within_tolerance")
        # Truthiness rather than identity: the flag may be a numpy.bool_,
        # which is never the `True`/`False` singleton. None = not checked.
        if within is None:
            ok_str = ""
        else:
            ok_str = "OK" if within else "KO"

        md_parts.append(
            "| {idx} | {pid} | {mH} | {cH} | {res} | {ares} | {ok} |".format(
                idx=r["obs_index"],
                pid=r["target_id"],
                mH=_format_float(r["measured_horizontal"], 4),
                cH=_format_float(r["computed_horizontal"], 4),
                res=_format_float(r["residual"], 4),
                ares=_format_float(r["abs_residual"], 4),
                ok=ok_str,
            )
        )

    md_parts.append("")
    md_parts.append(
        "_Horizontal distances are derived from raw slope distances and vertical angles._"
    )
    md_parts.append("")

    # Vertical part (only when Z was computed)
    if result.n_obs_z > 0:
        md_parts.append(f"## {i18n.tr('Vertical information')}")
        md_parts.append("")
        md_parts.append(
            f"- **{i18n.tr('Number of vertical observations')}**: {result.n_obs_z}"
        )
        md_parts.append(f"- **{i18n.tr('Degrees of freedom (Z)')}**: {result.dof_z}")
        md_parts.append("")
        md_parts.append(
            "| Obs index | Target point | av ({ang}) | di ({dist}) | "
            "dZ geom. (m) | Z from obs (m) | Residual (m) | Abs. residual (m) |".format(
                ang=angle_unit, dist=distance_unit
            )
        )
        md_parts.append(
            "|-----------|--------------|-----------:|-----------:|"
            "-------------:|--------------:|-------------:|---------------:|"
        )

        for r in result.vertical_residuals:
            md_parts.append(
                "| {idx} | {pid} | {av} | {di} | {dz} | {zobs} | {res} | {ares} |".format(
                    idx=r["obs_index"],
                    pid=r["target_id"],
                    av=_format_float(r["measured_vertical_angle"], 4),
                    di=_format_float(r["di"], 4),
                    dz=_format_float(r["geometric_dZ"], 4),
                    zobs=_format_float(r["z_from_observation"], 4),
                    res=_format_float(r["residual"], 4),
                    ares=_format_float(r["abs_residual"], 4),
                )
            )

        md_parts.append("")
        md_parts.append(
            "_Each vertical observation provides an individual estimate of the unknown point elevation._"
        )
        md_parts.append(
            "_The final Z is obtained by a least-squares average of these individual estimates._"
        )
        md_parts.append("")

    return "\n".join(md_parts)


# ---------------------------------------------------------------------------
# Classe Trilateration avec méthode de classe (même style que ClosedTraverse)
# ---------------------------------------------------------------------------


class Trilateration:
    """
    Helper class for computing an isolated point by trilateration.

    Main entry point:
        Trilateration.compute_trilateration_from_json(json_path, output_dir)
    """

    @staticmethod
    def _export_report(converter: Any, md_path: Path, out_path: Path) -> bool:
        """
        Best-effort conversion of the Markdown report with *converter*.

        The converter is first called as ``converter(md_path, out_path)``; if
        that call raises TypeError it is retried as ``converter(md_path)``.
        Any failure is logged (not raised) so that a missing or broken
        exporter never prevents the computation result from being returned.

        :returns: True if the converter ran without raising, False otherwise
            (including when *converter* is None).
        """
        if converter is None:
            return False
        try:
            try:
                converter(md_path, out_path)  # type: ignore[arg-type]
            except TypeError:
                converter(md_path)  # type: ignore[arg-type]
        except Exception:  # deliberate best-effort export, see docstring
            logging.getLogger(__name__).warning(
                "Conversion of %s to %s failed", md_path, out_path.suffix, exc_info=True
            )
            return False
        return True

    @classmethod
    def compute_trilateration_from_json(
        cls,
        json_path: Path,
        output_dir: Path,
    ) -> Dict[str, Any]:
        """
        Compute an isolated point by trilateration from a JSON file
        (e.g. 'trilateration.json').

        Steps:
          - read the JSON,
          - convert slope distances + vertical angles to horizontal distances,
          - non-linear XY adjustment on the horizontal distances,
          - optional Z computation from the vertical components,
          - horizontal (and vertical) residuals,
          - write a Markdown report 'trilateration.md' into *output_dir*,
          - best-effort conversion to HTML, PDF and ODT when
            markdown_to_html / markdown_to_pdf / markdown_to_odt are available.

        :param json_path: input JSON path (a str is accepted too).
        :param output_dir: report directory (str accepted); created if needed.
        :returns: ``{"result": TrilaterationResult, "report_paths": {...}}``
            where report_paths maps "md", "html", "pdf" and "odt" to the file
            produced by this call, or None when that format was not produced
            (converter unavailable or failed).
        """
        json_path = Path(json_path)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        data = _load_trilateration_json(json_path)
        result = _compute_solution(data)

        # Markdown report
        md_path = output_dir / "trilateration.md"
        md_path.write_text(_build_markdown_report(data, result), encoding="utf-8")

        # Optional conversions (same pattern as closed_traverse).
        report_paths: Dict[str, Optional[Path]] = {"md": md_path}
        converters = (
            ("html", markdown_to_html),
            ("pdf", markdown_to_pdf),
            ("odt", markdown_to_odt),
        )
        for ext, converter in converters:
            out_path = output_dir / f"trilateration.{ext}"
            exported = cls._export_report(converter, md_path, out_path)
            # Checking exists() alone would return a leftover file from a
            # previous run when this conversion failed or was unavailable.
            report_paths[ext] = out_path if exported and out_path.exists() else None

        return {
            "result": result,
            "report_paths": report_paths,
        }


# Stub i18n pour tests hors QGIS (même idée que dans closed_traverse.py)
class i18n:
    """
    Pass-through translation helper for running this module outside QGIS.

    ``i18n.tr(text)`` mirrors the translation call used throughout the module
    but performs no lookup: the message comes back unchanged.
    """

    @staticmethod
    def tr(s: str) -> str:
        """Return the message *s* as-is (no translation catalogue is loaded)."""
        untranslated = s
        return untranslated


if __name__ == "__main__":
    # Manual test entry point (outside QGIS):
    #     python trilateration.py [json_path [output_dir]]
    # Without arguments the historical defaults are used; when only the JSON
    # path is given, the reports are written next to it.
    import sys

    _cli_args = sys.argv[1:]
    _json_path = Path(_cli_args[0]) if _cli_args else Path("c:/temp/trilateration.json")
    _output_dir = Path(_cli_args[1]) if len(_cli_args) > 1 else _json_path.parent

    _outcome = Trilateration.compute_trilateration_from_json(_json_path, _output_dir)
    for _fmt, _report in _outcome["report_paths"].items():
        if _report is not None:
            print(f"{_fmt}: {_report}")
