# -*- coding: utf-8 -*-

import re
import sqlite3
import xml.etree.ElementTree as ET
import os
import shutil

from qgis.PyQt.QtCore import Qt, QTimer, QRectF, QSize, QRect
from qgis.PyQt.QtGui import QImage, QPainter
from qgis.PyQt.QtWidgets import (
    QDockWidget, QWidget, QVBoxLayout, QHBoxLayout, QPlainTextEdit,
    QPushButton, QFileDialog, QMessageBox, QSplitter, QLabel,
    QToolButton, QMenu, QInputDialog, QCheckBox
)
from qgis.PyQt.QtSvg import QSvgGenerator
from qgis.core import QgsProject, QgsSettings

from .ddl_parser import parse_ddl
from .erd_scene import ErdScene, ErdView


# Matches "FOREIGN KEY <identifier>" when the identifier is NOT already wrapped
# in parentheses. Supported identifier forms (all accepted by SQLite):
#   "col"  (a doubled "" inside is an escaped quote), `col`, [col], bare col.
_FK_WITHOUT_PARENS_RE = re.compile(
    r'FOREIGN\s+KEY\s+(?!\()'
    r'(?P<ident>"(?:[^"]|"")+"|`(?:[^`]|``)+`|\[[^\]]+\]|\w+)',
    re.IGNORECASE,
)


def normalize_sql_for_sqlite(sql: str) -> str:
    """
    Rewrite single-column ``FOREIGN KEY`` clauses into the form SQLite requires.

    SQLite only accepts ``FOREIGN KEY (<column list>)``. This function turns
    more lenient spellings such as::

        FOREIGN KEY "col"
        FOREIGN KEY col
        FOREIGN KEY [col]
        FOREIGN KEY `col`

    into ``FOREIGN KEY (<ident>)``. Clauses that already have parentheses are
    left untouched. The ``FOREIGN KEY`` keywords of rewritten clauses are
    emitted in upper case with single spaces.

    The rewrite is purely textual: it does not skip comments or string
    literals.

    :param sql: SQL script to normalize.
    :return: the normalized script.
    """
    return _FK_WITHOUT_PARENS_RE.sub(
        lambda m: f"FOREIGN KEY ({m.group('ident')})", sql
    )


def build_drawio_xml(tables, fks) -> bytes:
    """
    Build a draw.io (diagrams.net) XML document describing an ER diagram.

    The document contains:
      - one ``shape=table`` container cell per table, laid out on a grid;
      - one ``shape=tableRow`` cell per column, holding a small "PK"/"FK"
        flag cell and a text cell "<name><tab><type>";
      - one orthogonal edge per foreign key, connected to the source and
        referenced column rows, labelled with a cardinality
        ("1-1", "0-1", "1-N" or "0-N").

    :param tables: mapping ``table name -> table object``. Each table object
        exposes ``columns``, a mapping whose values have ``name``, ``type``,
        ``primary_key`` and ``foreign_key`` attributes (presumably the objects
        returned by ``parse_ddl`` — confirm against ddl_parser).
    :param fks: iterable of foreign-key objects exposing ``src_table``,
        ``ref_table``, ``src_cols``, ``ref_cols``, ``src_unique`` and
        ``src_not_null``. Only the first column of a composite key anchors the
        edge. Keys whose first columns are not found in ``tables`` are skipped,
        as are keys with an empty column list.
    :return: UTF-8 encoded XML, including an XML declaration.
    """
    import html  # local import: cell values are rendered as HTML (html=1)

    mxfile = ET.Element("mxfile")
    diagram = ET.SubElement(mxfile, "diagram", id="diagram1", name="Page-1")
    model = ET.SubElement(diagram, "mxGraphModel")
    root = ET.SubElement(model, "root")
    ET.SubElement(root, "mxCell", id="0")
    ET.SubElement(root, "mxCell", id="1", parent="0")

    TABLE_STYLE = (
        "shape=table;startSize=30;container=1;collapsible=1;"
        "childLayout=tableLayout;fixedRows=1;rowLines=0;fontStyle=1;"
        "align=center;resizeLast=1;html=1;"
    )
    ROW_STYLE = (
        "shape=tableRow;horizontal=0;startSize=0;swimlaneHead=0;"
        "swimlaneBody=0;fillColor=none;collapsible=0;dropTarget=0;"
        "points=[[0,0.5],[1,0.5]];portConstraint=eastwest;"
        "top=0;left=0;right=0;bottom=0;"
    )
    PK_CELL_STYLE = (
        "shape=partialRectangle;connectable=0;fillColor=none;top=0;"
        "left=0;bottom=0;right=0;fontStyle=1;overflow=hidden;"
        "whiteSpace=wrap;html=1;"
    )
    COL_CELL_STYLE = (
        "shape=partialRectangle;connectable=0;fillColor=none;top=0;"
        "left=0;bottom=0;right=0;align=left;spacingLeft=6;fontStyle=5;"
        "overflow=hidden;whiteSpace=wrap;html=1;"
    )
    EDGE_STYLE = (
        "edgeStyle=orthogonalEdgeStyle;rounded=1;orthogonalLoop=1;"
        "jettySize=auto;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;"
        "endArrow=classicThin;endFill=1;startArrow=none;startFill=0;"
        "curved=0;"
    )
    CARD_STYLE = (
        "edgeLabel;html=1;align=center;verticalAlign=middle;"
        "resizable=0;points=[];"
    )
    # HTML fragment placed between a column name and its type.
    TYPE_SEPARATOR = "<span style=\"white-space: pre;\">\t</span>"

    # Cell dimensions (pixels).
    HEADER_H = 30   # table title bar (matches startSize=30)
    ROW_H = 30      # one column row
    FLAG_W = 30     # "PK"/"FK" flag cell
    TABLE_W = 200

    # Grid layout.
    spacing_x = 260   # distance between the left edges of two grid columns
    spacing_y = 160   # minimum distance between the tops of two grid rows
    min_gap_y = 40    # minimum free space below the tallest table of a row
    base_x = 40
    base_y = 40

    def _add_geometry(cell, **attrs):
        # draw.io expects <mxGeometry ... as="geometry"/>; "as" is a Python
        # keyword, hence the separate set().
        geo = ET.SubElement(cell, "mxGeometry", **attrs)
        geo.set("as", "geometry")
        return geo

    def _table_height(table):
        # A table always reserves at least one row, even with no columns.
        return HEADER_H + ROW_H * max(1, len(table.columns))

    items = list(tables.items())
    cols_per_row = max(1, int(len(items) ** 0.5)) if items else 1

    # Top y of each grid row. A fixed row pitch made tables with 5+ columns
    # overlap the row below; each row now starts at least `spacing_y` below
    # the previous one, and lower if the previous row contains a tall table.
    row_tops = []
    y_cursor = base_y
    for start in range(0, len(items), cols_per_row):
        row_tops.append(y_cursor)
        tallest = max(_table_height(t) for _, t in items[start:start + cols_per_row])
        y_cursor += max(spacing_y, tallest + min_gap_y)

    # (table name, column name) -> id of its row cell, used to attach edges.
    column_row_ids = {}

    # --------------------- tables + rows ---------------------
    for idx, (tname, table) in enumerate(items):
        table_id = f"t{idx + 1}"
        x = base_x + (idx % cols_per_row) * spacing_x
        y = row_tops[idx // cols_per_row]

        table_cell = ET.SubElement(
            root,
            "mxCell",
            id=table_id,
            value=html.escape(str(tname), quote=False),
            style=TABLE_STYLE,
            vertex="1",
            parent="1",
        )
        _add_geometry(
            table_cell,
            x=str(x),
            y=str(y),
            width=str(TABLE_W),
            height=str(_table_height(table)),
        )

        for ridx, col_obj in enumerate(table.columns.values()):
            row_id = f"{table_id}_r{ridx + 1}"
            row_cell = ET.SubElement(
                root,
                "mxCell",
                id=row_id,
                value="",
                style=ROW_STYLE,
                vertex="1",
                parent=table_id,
            )
            _add_geometry(
                row_cell,
                y=str(HEADER_H + ROW_H * ridx),
                width=str(TABLE_W),
                height=str(ROW_H),
            )

            # PK takes precedence over FK when a column is both.
            if col_obj.primary_key:
                flag = "PK"
            elif col_obj.foreign_key:
                flag = "FK"
            else:
                flag = ""
            flag_cell = ET.SubElement(
                root,
                "mxCell",
                id=f"{row_id}_flag",
                value=flag,
                style=PK_CELL_STYLE,
                vertex="1",
                parent=row_id,
            )
            _add_geometry(flag_cell, width=str(FLAG_W), height=str(ROW_H))

            # Text cell: escaped name, then a tab and the escaped type if known.
            text = html.escape(str(col_obj.name), quote=False)
            if col_obj.type:
                text += TYPE_SEPARATOR + html.escape(str(col_obj.type), quote=False)
            text_cell = ET.SubElement(
                root,
                "mxCell",
                id=f"{row_id}_col",
                value=text,
                style=COL_CELL_STYLE,
                vertex="1",
                parent=row_id,
            )
            _add_geometry(
                text_cell,
                x=str(FLAG_W),
                width=str(TABLE_W - FLAG_W),
                height=str(ROW_H),
            )

            column_row_ids[(tname, col_obj.name)] = row_id

    # --------------------- FK edges ---------------------------
    edge_counter = 0
    for fk in fks:
        if not fk.src_cols or not fk.ref_cols:
            continue

        src_row = column_row_ids.get((fk.src_table, fk.src_cols[0]))
        ref_row = column_row_ids.get((fk.ref_table, fk.ref_cols[0]))
        if not src_row or not ref_row:
            continue

        edge_counter += 1
        edge_id = f"e{edge_counter}"

        edge_cell = ET.SubElement(
            root,
            "mxCell",
            id=edge_id,
            value="",
            style=EDGE_STYLE,
            edge="1",
            parent="1",
            source=src_row,
            target=ref_row,
        )
        _add_geometry(edge_cell, relative="1")

        # Cardinality from the source column(s): UNIQUE -> at most one,
        # NOT NULL -> at least one.
        if fk.src_unique:
            card = "1-1" if fk.src_not_null else "0-1"
        else:
            card = "1-N" if fk.src_not_null else "0-N"

        # draw.io stores edge labels as child vertices of the edge
        # (vertex="1" connectable="0"). A cell with vertex="0" is neither a
        # vertex nor an edge and may not be drawn.
        lbl_cell = ET.SubElement(
            root,
            "mxCell",
            id=f"{edge_id}_lbl",
            value=card,
            style=CARD_STYLE,
            vertex="1",
            connectable="0",
            parent=edge_id,
        )
        _add_geometry(lbl_cell, x="0.5", y="-1", relative="1")

    return ET.tostring(mxfile, encoding="utf-8", xml_declaration=True)


class SqlErDock(QDockWidget):
    """
    Dock widget combining a SQL DDL editor and a live ER diagram.

    The editor content is parsed with ``parse_ddl`` (debounced while typing)
    and drawn in an ``ErdView``. The diagram can be exported to PNG, SVG or
    draw.io. The script can also be executed in a SpatiaLite database, chosen
    either from the connections registered in QGIS or as a file on disk.
    """

    # Blank margin, in pixels, added on each side of the scene for PNG/SVG exports.
    EXPORT_MARGIN = 20

    def __init__(self, parent=None):
        super().__init__("sQlER – Diagramme ER SQL", parent)
        self.setObjectName("SqlErDock")

        # Debounce timer: re-render only once typing has paused for 600 ms.
        self._timer = QTimer(self)
        self._timer.setSingleShot(True)
        self._timer.setInterval(600)
        self._timer.timeout.connect(self.render_diagram)

        # Result of the last successful parse (tables, fks).
        self._last_tables = None
        self._last_fks = None

        root = QWidget()
        self.setWidget(root)
        layout = QVBoxLayout(root)
        layout.setContentsMargins(4, 4, 4, 4)

        # ------------------------------------------------------------------ button bar

        bar = QHBoxLayout()

        self.btn_open = QPushButton("Ouvrir .sql…")

        # "Export…" button with a format menu
        self.btn_export = QToolButton()
        self.btn_export.setText("Exporter…")
        self.btn_export.setToolTip("Exporter le diagramme")
        export_menu = QMenu(self.btn_export)
        act_export_png = export_menu.addAction("PNG…")
        act_export_svg = export_menu.addAction("SVG…")
        act_export_drawio = export_menu.addAction("DRAWIO…")
        self.btn_export.setMenu(export_menu)
        self.btn_export.setPopupMode(QToolButton.MenuButtonPopup)

        # "Apply SQL…" button with a target-database menu
        self.btn_exec = QToolButton()
        self.btn_exec.setText("Appliquer dans SpatiaLite…")
        self.btn_exec.setToolTip("Exécuter le script SQL dans une base SpatiaLite")
        exec_menu = QMenu(self.btn_exec)
        act_exec_connected = exec_menu.addAction("Vers une base SpatiaLite connectée…")
        act_exec_file = exec_menu.addAction("Vers un fichier SpatiaLite…")
        self.btn_exec.setMenu(exec_menu)
        self.btn_exec.setPopupMode(QToolButton.MenuButtonPopup)

        bar.addWidget(self.btn_open)
        bar.addStretch(1)
        bar.addWidget(self.btn_export)
        bar.addWidget(self.btn_exec)
        layout.addLayout(bar)

        # ------------------------------------------------------------------ vertical splitter

        splitter = QSplitter(Qt.Vertical)
        layout.addWidget(splitter, 1)

        # SQL editor on top
        self.editor = QPlainTextEdit()
        self.editor.setPlaceholderText("Colle ici ton DDL (CREATE TABLE, FOREIGN KEY, ...)")
        splitter.addWidget(self.editor)

        # diagram at the bottom
        bottom = QWidget()
        bottom_layout = QVBoxLayout(bottom)
        bottom_layout.setContentsMargins(0, 0, 0, 0)

        self.info_label = QLabel("Aucun diagramme généré pour le moment.")
        bottom_layout.addWidget(self.info_label)

        self.view = ErdView()
        bottom_layout.addWidget(self.view, 1)

        splitter.addWidget(bottom)
        splitter.setStretchFactor(0, 1)
        splitter.setStretchFactor(1, 2)

        # ------------------------------------------------------------------ signal connections

        self.editor.textChanged.connect(self._on_text_changed)
        self.btn_open.clicked.connect(self.open_sql_file)

        act_export_png.triggered.connect(self.export_png)
        act_export_svg.triggered.connect(self.export_svg)
        act_export_drawio.triggered.connect(self.export_drawio)

        act_exec_connected.triggered.connect(self.exec_sql_connected)
        act_exec_file.triggered.connect(self.exec_sql_file)

        # ------------------------------------------------------------------ default example

        self.editor.setPlainText(
            'CREATE TABLE customers (\n'
            '  customer_id INTEGER PRIMARY KEY,\n'
            '  name TEXT NOT NULL,\n'
            '  email TEXT\n'
            ');\n\n'
            'CREATE TABLE products (\n'
            '  product_id INTEGER PRIMARY KEY,\n'
            '  name TEXT NOT NULL,\n'
            '  price NUMERIC NOT NULL\n'
            ');\n\n'
            'CREATE TABLE orders (\n'
            '  order_id INTEGER PRIMARY KEY,\n'
            '  customer_id INTEGER NOT NULL,\n'
            '  order_date DATE NOT NULL,\n'
            '  FOREIGN KEY (customer_id) REFERENCES customers(customer_id)\n'
            ');\n\n'
            'CREATE TABLE order_items (\n'
            '  order_item_id INTEGER PRIMARY KEY,\n'
            '  order_id INTEGER NOT NULL,\n'
            '  product_id INTEGER NOT NULL,\n'
            '  quantity INTEGER NOT NULL DEFAULT 1,\n'
            '  FOREIGN KEY (order_id) REFERENCES orders(order_id),\n'
            '  FOREIGN KEY (product_id) REFERENCES products(product_id)\n'
            ');\n'
        )

        # setPlainText() armed the debounce timer; render_diagram() stops it.
        self.render_diagram()

    # ------------------------------------------------------------------ internal helpers

    def _on_text_changed(self):
        # Debounce so the diagram is not recomputed on every keystroke.
        self._timer.start()

    # String annotation: "str | None" evaluated at definition time raises
    # TypeError on Python < 3.10, which some QGIS builds still ship.
    def _current_sql(self) -> "str | None":
        """Return the editor text, or None (after warning the user) if it is blank."""
        sql = self.editor.toPlainText()
        if not sql.strip():
            QMessageBox.warning(self, "sQlER", "Aucun SQL à analyser.")
            return None
        return sql

    def _db_file_exists(self, db_path: str) -> bool:
        """
        Return True if *db_path* is an existing file; otherwise warn and return False.

        ``sqlite3.connect()`` creates a missing database file, so without this
        check a stale connection path would silently produce a new empty base.
        """
        if os.path.isfile(db_path):
            return True
        QMessageBox.warning(
            self,
            "Exécuter SQL",
            f"La base de données est introuvable :\n{db_path}",
        )
        return False

    # ------------------------------------------------------------------ UI actions

    def open_sql_file(self):
        """Load a .sql/.txt file into the editor (UTF-8 first, cp1252 fallback)."""
        path, _ = QFileDialog.getOpenFileName(
            self, "Ouvrir un fichier SQL", "", "SQL (*.sql *.txt);;Tous les fichiers (*.*)"
        )
        if not path:
            return
        try:
            try:
                # utf-8-sig also drops a leading BOM, which would otherwise
                # appear as a stray U+FEFF at the start of the script.
                with open(path, "r", encoding="utf-8-sig") as f:
                    txt = f.read()
            except UnicodeDecodeError:
                # Not UTF-8: assume a Windows ANSI file. Decoding cannot fail
                # here because of errors="replace".
                with open(path, "r", encoding="cp1252", errors="replace") as f:
                    txt = f.read()
        except OSError as ex:
            QMessageBox.warning(
                self,
                "Ouvrir un fichier SQL",
                f"Impossible de lire le fichier :\n{path}\n\n{ex}",
            )
            return
        self.editor.setPlainText(txt)

    def render_diagram(self):
        """
        Parse the editor's DDL and redraw the diagram.

        This runs after every typing pause (debounce timer). Parse errors are
        therefore shown in the info label rather than in a modal dialog, which
        would steal focus from the editor while the script is still being
        written. The previous diagram stays displayed in that case. A blank
        editor leaves the current diagram untouched.
        """
        # A direct call makes any pending debounced render redundant.
        self._timer.stop()

        sql = self.editor.toPlainText()
        if not sql.strip():
            return
        try:
            tables, fks = parse_ddl(sql)
        except Exception as ex:  # parse_ddl's exception types are not known here
            self.info_label.setText(f"Erreur lors du parsing du script : {ex}")
            return

        self._last_tables = tables
        self._last_fks = fks

        scene = ErdScene(tables, fks)
        self.view.setScene(scene)
        self.view.reset_view()

        self.info_label.setText(f"{len(tables)} tables, {len(fks)} contraintes de clés étrangères.")

    # ------------------------------------------------------------------ exports

    def _exportable_scene(self, title: str):
        """Return the current scene, or None (after warning) if it has nothing to export."""
        scene = self.view.scene()
        if scene is None or not scene.items():
            QMessageBox.warning(self, title, "Aucun diagramme à exporter.")
            return None
        return scene

    def _export_geometry(self, scene):
        """Return ``(source rect, canvas width, canvas height)`` for exporting *scene*."""
        rect = scene.itemsBoundingRect()
        # Smallest integer rectangle fully containing `rect` (no truncation).
        aligned = rect.toAlignedRect()
        margins = 2 * self.EXPORT_MARGIN
        return rect, aligned.width() + margins, aligned.height() + margins

    def _paint_scene(self, painter, scene, rect):
        """Render the *rect* area of *scene* with *painter*, inset by the export margin."""
        painter.setRenderHint(QPainter.Antialiasing)
        m = self.EXPORT_MARGIN
        # scene.render() already maps the scene `source` rect onto the painter
        # `target` rect. An extra translate(-rect.topLeft()) offsets the drawing
        # twice and clips it whenever the scene does not start at the origin.
        scene.render(painter, target=QRectF(m, m, rect.width(), rect.height()), source=rect)

    def export_png(self):
        """Export the current diagram to a PNG file with a transparent background."""
        scene = self._exportable_scene("Export PNG")
        if scene is None:
            return
        path, _ = QFileDialog.getSaveFileName(
            self, "Exporter le diagramme en PNG", "diagram.png", "PNG (*.png)"
        )
        if not path:
            return

        rect, width, height = self._export_geometry(scene)
        img = QImage(width, height, QImage.Format_ARGB32)
        img.fill(Qt.transparent)

        painter = QPainter(img)
        try:
            self._paint_scene(painter, scene, rect)
        finally:
            painter.end()

        if not img.save(path):
            QMessageBox.warning(self, "Export PNG", "Échec de l'enregistrement du fichier PNG.")

    def export_svg(self):
        """Export the current diagram to an SVG file."""
        scene = self._exportable_scene("Export SVG")
        if scene is None:
            return
        path, _ = QFileDialog.getSaveFileName(
            self, "Exporter le diagramme en SVG", "diagram.svg", "SVG (*.svg)"
        )
        if not path:
            return

        rect, width, height = self._export_geometry(scene)
        gen = QSvgGenerator()
        gen.setFileName(path)
        gen.setSize(QSize(width, height))
        gen.setViewBox(QRect(0, 0, width, height))
        gen.setTitle("sQlER diagram")
        gen.setDescription("Diagramme ER généré par sQlER.")

        painter = QPainter()
        # begin() returns False when painting cannot start on the generator
        # (presumably e.g. an unwritable output file).
        if not painter.begin(gen):
            QMessageBox.warning(self, "Export SVG", "Échec de l'enregistrement du fichier SVG.")
            return
        try:
            self._paint_scene(painter, scene, rect)
        finally:
            painter.end()

    def export_drawio(self):
        """
        Export the editor's current DDL as a draw.io file.

        The SQL is re-parsed here rather than reusing the last rendered result.
        That result may be stale: a debounced render may still be pending, or
        the latest parse may have failed.
        """
        sql = self._current_sql()
        if sql is None:
            return
        try:
            tables, fks = parse_ddl(sql)
        except Exception as ex:  # parse_ddl's exception types are not known here
            QMessageBox.warning(self, "Export DRAWIO", f"Erreur lors du parsing du script :\n{ex}")
            return
        self._last_tables, self._last_fks = tables, fks

        path, _ = QFileDialog.getSaveFileName(
            self, "Exporter en format draw.io", "diagram.drawio", "Draw.io (*.drawio)"
        )
        if not path:
            return

        xml_bytes = build_drawio_xml(tables, fks)
        try:
            with open(path, "wb") as f:
                f.write(xml_bytes)
        except OSError as ex:
            QMessageBox.warning(self, "Export DRAWIO", f"Erreur lors de l'écriture du fichier :\n{ex}")
            return

    # ------------------------------------------------------------------ SQL execution in SpatiaLite

    def _confirm_and_maybe_backup(self, db_path: str) -> bool:
        """
        Ask the user to confirm running the script on *db_path*, optionally backing it up first.

        The backup is a copy of the file named ``<base>_<suffix><ext>``. If
        that file already exists, the user must agree to overwrite it. If the
        copy fails, the user decides whether to continue anyway.

        :return: True if execution may proceed, False if the user cancelled.
        """
        msg = (
            "Es-tu sûr·e de vouloir exécuter ces instructions dans la base suivante :\n\n"
            f"{db_path}\n\n"
            "Tu peux aussi créer une sauvegarde avant l'exécution."
        )

        box = QMessageBox(self)
        box.setIcon(QMessageBox.Warning)
        box.setWindowTitle("Confirmer l'exécution du script SQL")
        box.setText(msg)

        backup_cb = QCheckBox("Créer une sauvegarde de la base avant l'exécution")
        box.setCheckBox(backup_cb)

        box.setStandardButtons(QMessageBox.Yes | QMessageBox.No)
        box.setDefaultButton(QMessageBox.No)

        if box.exec() != QMessageBox.Yes:
            return False  # user cancelled

        if not backup_cb.isChecked():
            return True

        base, ext = os.path.splitext(db_path)
        suffix, ok = QInputDialog.getText(
            self,
            "Suffixe de sauvegarde",
            "Suffixe à ajouter après un underscore :",
            text="backup",
        )
        if not ok:
            return False  # cancelling the suffix dialog cancels the execution

        suffix = suffix.strip() or "backup"
        backup_path = f"{base}_{suffix}{ext}"

        # Never silently replace an earlier backup.
        if os.path.exists(backup_path):
            overwrite = QMessageBox.question(
                self,
                "Sauvegarde",
                f"Le fichier de sauvegarde existe déjà :\n{backup_path}\n\n"
                "Veux-tu l'écraser ?",
                QMessageBox.Yes | QMessageBox.No,
                QMessageBox.No,
            )
            if overwrite != QMessageBox.Yes:
                return False

        try:
            shutil.copy2(db_path, backup_path)
        except OSError as ex:
            QMessageBox.warning(
                self,
                "Sauvegarde",
                f"Impossible de créer la sauvegarde :\n{backup_path}\n\n{ex}"
            )
            cont = QMessageBox.question(
                self,
                "Sauvegarde",
                "La sauvegarde a échoué.\n"
                "Veux-tu quand même continuer l'exécution du script ?",
                QMessageBox.Yes | QMessageBox.No,
                QMessageBox.No,
            )
            if cont != QMessageBox.Yes:
                return False

        return True

    def _run_sql_on_sqlite(self, db_path: str, sql: str):
        """
        Execute *sql* in the SQLite/SpatiaLite database at *db_path* using sqlite3.

        The script is first normalized with ``normalize_sql_for_sqlite``.
        Loading the ``mod_spatialite`` extension is attempted so that
        SpatiaLite functions such as AddGeometryColumn() are available. If it
        cannot be loaded, plain SQL still runs, and the error message mentions
        the missing extension. Extension loading is disabled again right after
        the attempt.

        Per the sqlite3 documentation, ``executescript()`` commits any pending
        transaction and then runs the statements without implicit transaction
        control. A failure midway may therefore leave earlier statements
        applied unless the script manages its own transaction.
        """
        if not self._db_file_exists(db_path):
            return

        fixed_sql = normalize_sql_for_sqlite(sql)

        try:
            conn = sqlite3.connect(db_path)
        except sqlite3.Error as ex:
            QMessageBox.warning(
                self,
                "Exécuter SQL",
                f"Impossible d'ouvrir la base SQLite :\n{db_path}\n\n{ex}",
            )
            return

        spatialite_loaded = False
        try:
            try:
                conn.enable_load_extension(True)
            except (AttributeError, sqlite3.Error):
                # Python/SQLite built without loadable-extension support.
                pass
            else:
                try:
                    conn.load_extension("mod_spatialite")
                    spatialite_loaded = True
                except sqlite3.Error:
                    # Continue anyway: standard SQL works, SpatiaLite calls may fail.
                    pass
                finally:
                    # Do not leave arbitrary extension loading enabled.
                    conn.enable_load_extension(False)

            with conn:
                conn.executescript(fixed_sql)
        except Exception as ex:  # UI boundary: report any failure to the user
            details = f"{ex}"
            if not spatialite_loaded:
                details += (
                    "\n\n(L'extension mod_spatialite n'a pas pu être chargée : "
                    "les fonctions SpatiaLite sont indisponibles.)"
                )
            QMessageBox.warning(
                self,
                "Exécuter SQL",
                f"Erreur lors de l'exécution du script dans :\n{db_path}\n\n{details}",
            )
            return
        finally:
            conn.close()

        QMessageBox.information(
            self,
            "Exécuter SQL",
            f"Script exécuté avec succès dans :\n{db_path}",
        )

    def exec_sql_connected(self):
        """Run the editor script in one of the SpatiaLite connections registered in QGIS."""
        sql = self._current_sql()
        if sql is None:
            return

        # SpatiaLite connections saved in the QGIS settings (as in DB Manager).
        settings = QgsSettings()
        settings.beginGroup("SpatiaLite/connections")
        conn_names = settings.childGroups()
        settings.endGroup()

        if not conn_names:
            QMessageBox.information(
                self,
                "Exécuter SQL",
                "Aucune connexion SpatiaLite n'est configurée dans QGIS.\n"
                "Ajoute d'abord une connexion dans le gestionnaire de bases de données.",
            )
            return

        items = []
        label_to_path = {}
        base = "SpatiaLite/connections/"

        for name in conn_names:
            path = settings.value(f"{base}{name}/sqlitepath", "", type=str)
            if not path:
                continue  # connection without a file path: not selectable
            label = f"{name} – {path}"
            items.append(label)
            label_to_path[label] = path

        if not items:
            QMessageBox.information(
                self,
                "Exécuter SQL",
                "Impossible de retrouver les chemins des connexions SpatiaLite.",
            )
            return

        if len(items) == 1:
            choice = items[0]
        else:
            choice, ok = QInputDialog.getItem(
                self,
                "Choisir une base SpatiaLite",
                "Base de données cible :",
                items,
                0,
                False,
            )
            if not ok:
                return

        db_path = label_to_path[choice]

        # A registered connection may point to a file that no longer exists;
        # check before asking for confirmation / making a backup.
        if not self._db_file_exists(db_path):
            return

        if not self._confirm_and_maybe_backup(db_path):
            return

        self._run_sql_on_sqlite(db_path, sql)

    def exec_sql_file(self):
        """Run the editor script in a SpatiaLite database file chosen on disk."""
        sql = self._current_sql()
        if sql is None:
            return

        path, _ = QFileDialog.getOpenFileName(
            self,
            "Choisir une base SpatiaLite",
            "",
            "Bases SpatiaLite (*.sqlite *.db);;Tous les fichiers (*.*)"
        )
        if not path:
            return

        if not self._confirm_and_maybe_backup(path):
            return

        self._run_sql_on_sqlite(path, sql)