# -*- coding: utf-8 -*-
"""
ddl_parser.py
Analyse un script SQL (DDL) pour en extraire les tables, colonnes, PK et FKs.
Conçu pour PostgreSQL / SQLite/SpatiaLite "classiques".
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import re


@dataclass
class Column:
    """One column of a parsed table."""

    name: str                                     # unquoted column name
    type: str = field(default="")                 # declared SQL type, whitespace-normalised ("" if none)
    not_null: bool = field(default=False)         # NOT NULL, or member of the primary key
    primary_key: bool = field(default=False)      # part of the table's primary key
    unique: bool = field(default=False)           # inline UNIQUE or single-column UNIQUE constraint
    default: Optional[str] = field(default=None)  # raw DEFAULT expression text, if any
    foreign_key: bool = field(default=False)      # source column of at least one foreign key


@dataclass
class ForeignKey:
    """A foreign-key link from ``src_table(src_cols)`` to ``ref_table(ref_cols)``."""

    name: str                                   # constraint name, or a generated "fk_<n>"
    src_table: str                              # referencing table
    src_cols: List[str]                         # referencing columns, in declaration order
    ref_table: str                              # referenced table
    ref_cols: List[str]                         # referenced columns, in declaration order
    src_not_null: bool = field(default=False)   # every source column is NOT NULL
    src_unique: bool = field(default=False)     # source columns are unique (one-to-one link)


@dataclass
class Table:
    """A parsed table: its columns and its primary key."""

    name: str                                                   # normalised table name
    columns: Dict[str, Column] = field(default_factory=dict)    # column name -> Column, in insertion order
    pk: List[str] = field(default_factory=list)                 # primary-key column names, in declaration order


# ----------------------------- Helpers ---------------------------------


def _strip_comments(sql: str) -> str:
    sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.S)
    sql = re.sub(r"--[^\n]*", "", sql)
    return sql


def _split_statements(sql: str) -> List[str]:
    out: List[str] = []
    buf: List[str] = []
    depth = 0
    in_squote = False
    in_dquote = False
    i = 0
    while i < len(sql):
        ch = sql[i]
        nxt = sql[i + 1] if i + 1 < len(sql) else ""

        if ch == "'" and not in_dquote:
            if in_squote and nxt == "'":
                buf.append(ch)
                buf.append(nxt)
                i += 2
                continue
            in_squote = not in_squote
            buf.append(ch)
            i += 1
            continue

        if ch == '"' and not in_squote:
            in_dquote = not in_dquote
            buf.append(ch)
            i += 1
            continue

        if not in_squote and not in_dquote:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth = max(0, depth - 1)
            elif ch == ";" and depth == 0:
                stmt = "".join(buf).strip()
                if stmt:
                    out.append(stmt)
                buf = []
                i += 1
                continue

        buf.append(ch)
        i += 1

    last = "".join(buf).strip()
    if last:
        out.append(last)
    return out


def _split_top_level_commas(text: str) -> List[str]:
    parts: List[str] = []
    buf: List[str] = []
    depth = 0
    in_squote = False
    in_dquote = False
    i = 0
    while i < len(text):
        ch = text[i]
        nxt = text[i + 1] if i + 1 < len(text) else ""

        if ch == "'" and not in_dquote:
            if in_squote and nxt == "'":
                buf.append(ch)
                buf.append(nxt)
                i += 2
                continue
            in_squote = not in_squote
            buf.append(ch)
            i += 1
            continue

        if ch == '"' and not in_squote:
            in_dquote = not in_dquote
            buf.append(ch)
            i += 1
            continue

        if not in_squote and not in_dquote:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth = max(0, depth - 1)
            elif ch == "," and depth == 0:
                part = "".join(buf).strip()
                if part:
                    parts.append(part)
                buf = []
                i += 1
                continue

        buf.append(ch)
        i += 1

    last = "".join(buf).strip()
    if last:
        parts.append(last)
    return parts


def _unquote_ident(ident: str) -> str:
    ident = ident.strip()
    if ident.startswith('"') and ident.endswith('"'):
        return ident[1:-1].replace('""', '"')
    return ident


def _normalize_table_name(raw: str) -> str:
    raw = raw.strip()
    raw = re.sub(r"\s+", " ", raw)
    return raw


def _parse_ident_list(text: str) -> List[str]:
    return [_unquote_ident(x.strip()) for x in text.split(",") if x.strip()]


# ------------------------------ Regex ----------------------------------


# CREATE [GLOBAL|LOCAL] [TEMP|TEMPORARY|UNLOGGED] TABLE [IF NOT EXISTS] name (body) [options]
# "body" spans from the first "(" to the last ")".  Trailing table options
# without parentheses are accepted after it, e.g. SQLite WITHOUT ROWID /
# STRICT or PostgreSQL TABLESPACE x.
CREATE_TABLE_RE = re.compile(
    r'^\s*CREATE\s+((GLOBAL|LOCAL)\s+)?((TEMPORARY|TEMP|UNLOGGED)\s+)?TABLE\s+'
    r'(IF\s+NOT\s+EXISTS\s+)?'
    r'(?P<name>("[^"]+"|\w+)(\s*\.\s*("[^"]+"|\w+))?)\s*'
    r'\((?P<body>.*)\)[^()]*$',
    re.IGNORECASE | re.S,
)

# Table-level [CONSTRAINT name] FOREIGN KEY (cols) | col  REFERENCES table (cols)
TABLE_FK_RE = re.compile(
    r'^(CONSTRAINT\s+(?P<cname>("[^"]+"|\w+))\s+)?FOREIGN\s+KEY\s*'
    r'(\((?P<src>[^)]+)\)|(?P<src_single>("[^"]+"|\w+)))\s*'
    r'REFERENCES\s+(?P<ref>("[^"]+"|\w+)(\s*\.\s*("[^"]+"|\w+))?)\s*'
    r'\((?P<refcols>[^)]+)\)',
    re.IGNORECASE | re.S,
)

# Table-level [CONSTRAINT name] PRIMARY KEY (cols)
TABLE_PK_RE = re.compile(
    r'^(CONSTRAINT\s+(?P<cname>("[^"]+"|\w+))\s+)?PRIMARY\s+KEY\s*\((?P<cols>[^)]+)\)',
    re.IGNORECASE | re.S,
)

# Table-level [CONSTRAINT name] UNIQUE (cols)
TABLE_UNIQUE_RE = re.compile(
    r'^(CONSTRAINT\s+(?P<cname>("[^"]+"|\w+))\s+)?UNIQUE\s*\((?P<cols>[^)]+)\)',
    re.IGNORECASE | re.S,
)

# ALTER TABLE [IF EXISTS] [ONLY] table ADD [CONSTRAINT name] FOREIGN KEY ... REFERENCES ...
ALTER_FK_RE = re.compile(
    r'^\s*ALTER\s+TABLE\s+(IF\s+EXISTS\s+)?(ONLY\s+)?'
    r'(?P<table>("[^"]+"|\w+)(\s*\.\s*("[^"]+"|\w+))?)\s+'
    r'ADD\s+(CONSTRAINT\s+(?P<cname>("[^"]+"|\w+))\s+)?FOREIGN\s+KEY\s*'
    r'(\((?P<src>[^)]+)\)|(?P<src_single>("[^"]+"|\w+)))\s*'
    r'REFERENCES\s+(?P<ref>("[^"]+"|\w+)(\s*\.\s*("[^"]+"|\w+))?)\s*'
    r'\((?P<refcols>[^)]+)\)',
    re.IGNORECASE | re.S,
)

# Column-level "REFERENCES table (cols)" (searched anywhere in a column definition).
INLINE_REF_RE = re.compile(
    r'\bREFERENCES\s+(?P<ref>("[^"]+"|\w+)(\s*\.\s*("[^"]+"|\w+))?)\s*\((?P<refcols>[^)]+)\)',
    re.IGNORECASE,
)

# "<column name> <everything else>" for one column definition.
COLUMN_DEF_RE = re.compile(
    r'^(?P<col>"[^"]+"|\w+)\s+(?P<rest>.+)$',
    re.S,
)


# Top-level words that end the type part of a column definition.
_COL_TYPE_STOP_RE = re.compile(
    r"\b(NOT|NULL|DEFAULT|CONSTRAINT|PRIMARY|REFERENCES|UNIQUE|CHECK|COLLATE|GENERATED|AS)\b",
    re.IGNORECASE,
)

# Top-level constraint words that end a DEFAULT expression.
_COL_DEFAULT_STOP_RE = re.compile(
    r"\b(NOT\s+NULL|NULL|CONSTRAINT|PRIMARY\s+KEY|REFERENCES|UNIQUE|CHECK|COLLATE|GENERATED)\b",
    re.IGNORECASE,
)


def _mask_nested(text: str) -> str:
    """Return *text* with quoted and parenthesised content replaced by '#'.

    Every character of a '...' or "..." token (quotes included) and every
    character nested inside a top-level (...) group becomes '#'.  Top-level
    parentheses are kept.  The result has the same length as *text*, so
    positions found in it index *text* directly, and keyword searches on it
    only see top-level SQL words.
    """
    out: List[str] = []
    depth = 0
    quote = ""
    for ch in text:
        if quote:
            out.append("#")
            if ch == quote:
                quote = ""
        elif ch in ("'", '"'):
            quote = ch
            out.append("#")
        elif ch == "(":
            out.append("(" if depth == 0 else "#")
            depth += 1
        elif ch == ")":
            depth = max(0, depth - 1)
            out.append(")" if depth == 0 else "#")
        else:
            out.append("#" if depth else ch)
    return "".join(out)


def _parse_column_def(
    item: str,
) -> Optional[Tuple[str, str, bool, bool, bool, Optional[str], Optional[Tuple[str, List[str]]]]]:
    """Parse one column definition taken from a CREATE TABLE body.

    *item* is a single comma-separated element of the body, e.g.
    ``email varchar(255) NOT NULL DEFAULT 'x' UNIQUE``.

    Returns ``None`` in these cases:
    - *item* is empty;
    - *item* is a table-level constraint (CONSTRAINT, PRIMARY KEY,
      FOREIGN KEY, UNIQUE, CHECK, or a PostgreSQL EXCLUDE);
    - *item* does not start with a column name.

    Otherwise returns
    ``(col, sql_type, not_null, is_pk, is_unique, default_expr, inline_ref)``,
    where ``inline_ref`` is ``(ref_table, ref_cols)`` for an inline
    ``REFERENCES t (cols)`` clause, else ``None``.

    Constraint keywords are only looked for outside string literals and
    parentheses.  Text inside ``CHECK (...)`` or ``DEFAULT '...'`` therefore
    cannot set NOT NULL / UNIQUE / PRIMARY KEY.  The DEFAULT expression stops
    at the next top-level constraint keyword.
    """
    s = item.strip()
    if not s:
        return None
    if re.match(r"^(CONSTRAINT|PRIMARY\s+KEY|FOREIGN\s+KEY|UNIQUE|CHECK)\b", s, flags=re.I):
        return None
    # PostgreSQL exclusion constraint: EXCLUDE [USING method] (...).  The
    # USING/"(" requirement keeps a column literally named "exclude".
    if re.match(r"^EXCLUDE\s*(USING\b|\()", s, flags=re.I):
        return None

    m = COLUMN_DEF_RE.match(s)
    if m:
        col = _unquote_ident(m.group("col"))
        rest = m.group("rest").strip()
    elif re.fullmatch(r'"[^"]+"|\w+', s):
        # SQLite allows a bare column with no type: CREATE TABLE t (a, b).
        col = _unquote_ident(s)
        rest = ""
    else:
        return None

    masked = _mask_nested(rest)

    # The type is everything before the first top-level constraint keyword,
    # with whitespace runs collapsed (e.g. "numeric(10, 2)").
    type_stop = _COL_TYPE_STOP_RE.search(masked)
    type_text = rest[: type_stop.start()] if type_stop else rest
    sql_type = " ".join(type_text.split())

    not_null = bool(re.search(r"\bNOT\s+NULL\b", masked, flags=re.I))
    is_pk = bool(re.search(r"\bPRIMARY\s+KEY\b", masked, flags=re.I))
    is_unique = bool(re.search(r"\bUNIQUE\b", masked, flags=re.I))

    default_expr = None
    mdef = re.search(r"\bDEFAULT\b", masked, flags=re.I)
    if mdef:
        tail = masked[mdef.end():]
        # The first token always belongs to the expression, so that
        # "DEFAULT NULL" yields "NULL" instead of stopping on NULL.
        lead = re.match(r"\s*\S*", tail)
        stop = _COL_DEFAULT_STOP_RE.search(tail, lead.end())
        end = stop.start() if stop else len(tail)
        expr = rest[mdef.end(): mdef.end() + end].strip()
        default_expr = expr or None

    inline_ref = None
    mref = re.search(r"\bREFERENCES\b", masked, flags=re.I)
    if mref:
        rm = INLINE_REF_RE.match(rest, mref.start())
        if rm:
            ref_table = _normalize_table_name(rm.group("ref"))
            ref_cols = _parse_ident_list(rm.group("refcols"))
            inline_ref = (ref_table, ref_cols)

    return col, sql_type, not_null, is_pk, is_unique, default_expr, inline_ref


# --------------------------- parse_ddl ---------------------------------


# ALTER TABLE [IF EXISTS] [ONLY] t ADD [CONSTRAINT name] PRIMARY KEY (...) | UNIQUE (...).
# pg_dump declares keys this way, after all the CREATE TABLE statements.
_ALTER_KEY_RE = re.compile(
    r'^\s*ALTER\s+TABLE\s+(IF\s+EXISTS\s+)?(ONLY\s+)?'
    r'(?P<table>("[^"]+"|\w+)(\s*\.\s*("[^"]+"|\w+))?)\s+'
    r'ADD\s+(CONSTRAINT\s+(?P<cname>("[^"]+"|\w+))\s+)?'
    r'(?P<kind>PRIMARY\s+KEY|UNIQUE)\s*\((?P<cols>[^)]+)\)',
    re.IGNORECASE | re.S,
)


def _cut_table_body(body: str) -> str:
    """Return the column/constraint list of a CREATE TABLE body.

    CREATE_TABLE_RE captures everything from the first "(" to the last ")".
    That span also swallows trailing clauses such as ``INHERITS (parent)``,
    ``WITH (fillfactor=70)`` or ``PARTITION BY RANGE (x)``.  The real list
    ends at the first ")" (outside quotes) that closes a parenthesis not
    opened inside *body*; everything from there on is dropped.
    """
    depth = 0
    in_squote = False
    in_dquote = False
    for pos, ch in enumerate(body):
        if ch == "'" and not in_dquote:
            in_squote = not in_squote
        elif ch == '"' and not in_squote:
            in_dquote = not in_dquote
        elif in_squote or in_dquote:
            continue
        elif ch == "(":
            depth += 1
        elif ch == ")":
            if depth == 0:
                return body[:pos]
            depth -= 1
    return body


def _fk_from_match(m: "re.Match", src_table: str, fallback_name: str) -> ForeignKey:
    """Build a ForeignKey from a TABLE_FK_RE or ALTER_FK_RE match.

    Both patterns expose the groups cname, src / src_single, ref and refcols.
    *fallback_name* is used when the constraint is unnamed.
    """
    cname = m.group("cname")
    if m.group("src") is not None:
        src_cols = _parse_ident_list(m.group("src"))
    else:
        src_cols = [_unquote_ident(m.group("src_single"))]
    return ForeignKey(
        name=_unquote_ident(cname) if cname else fallback_name,
        src_table=src_table,
        src_cols=src_cols,
        ref_table=_normalize_table_name(m.group("ref")),
        ref_cols=_parse_ident_list(m.group("refcols")),
    )


def parse_ddl(sql_text: str) -> Tuple[Dict[str, Table], List[ForeignKey]]:
    """Parse a DDL script into its tables and foreign keys.

    Recognised statements (all others are ignored):
    - ``CREATE TABLE``: columns plus inline and table-level
      PRIMARY KEY / UNIQUE / FOREIGN KEY / REFERENCES;
    - ``ALTER TABLE ... ADD [CONSTRAINT ...] FOREIGN KEY ... REFERENCES ...``;
    - ``ALTER TABLE ... ADD [CONSTRAINT ...] PRIMARY KEY (...) | UNIQUE (...)``.

    Returns ``(tables, fks)``:
    - ``tables`` is keyed by normalised table name.  Tables named only in an
      ALTER TABLE statement (including FK targets there) get an empty
      placeholder entry.
    - ``fks`` lists the foreign keys from CREATE TABLE statements first, then
      those from ALTER TABLE, in source order.  Unnamed FKs are called
      ``fk_<n>``; the counter advances for every FK, named or not.

    After parsing:
    - primary-key columns are flagged ``primary_key`` and ``not_null``;
    - single-column UNIQUE constraints flag their column ``unique``;
    - FK source columns are flagged ``foreign_key``;
    - ``fk.src_not_null`` is True when every source column is NOT NULL;
    - ``fk.src_unique`` is True when the source columns contain the primary
      key, a UNIQUE column set or a UNIQUE column of their table, i.e. each
      referencing row is unique (a one-to-one link).
    """
    stmts = [st.strip() for st in _split_statements(_strip_comments(sql_text))]

    tables: Dict[str, Table] = {}
    fks: List[ForeignKey] = []
    # Column lists of each table's UNIQUE constraints, both table-level and
    # ALTER TABLE ones.  Inline column UNIQUE is recorded on Column.unique.
    unique_sets: Dict[str, List[List[str]]] = {}
    fk_counter = 0

    def next_fk_name() -> str:
        # Advanced for every FK, named or not, so numbering is stable.
        nonlocal fk_counter
        fk_counter += 1
        return f"fk_{fk_counter}"

    # --- CREATE TABLE ---------------------------------------------------
    for st in stmts:
        m = CREATE_TABLE_RE.match(st)
        if not m:
            continue

        tname = _normalize_table_name(m.group("name"))
        table = tables.setdefault(tname, Table(name=tname))
        table_uniques = unique_sets.setdefault(tname, [])
        body = _cut_table_body(m.group("body")).strip()

        for item in _split_top_level_commas(body):
            pm = TABLE_PK_RE.match(item)
            if pm:
                for c in _parse_ident_list(pm.group("cols")):
                    if c not in table.pk:
                        table.pk.append(c)
                continue

            um = TABLE_UNIQUE_RE.match(item)
            if um:
                table_uniques.append(_parse_ident_list(um.group("cols")))
                continue

            fm = TABLE_FK_RE.match(item)
            if fm:
                fks.append(_fk_from_match(fm, tname, next_fk_name()))
                continue

            cd = _parse_column_def(item)
            if cd is None:
                continue
            col, sql_type, not_null, is_pk, is_unique, default_expr, inline_ref = cd
            # A repeated column definition updates the existing Column.
            col_obj = table.columns.get(col) or Column(name=col)
            col_obj.type = sql_type
            col_obj.not_null = col_obj.not_null or not_null
            col_obj.primary_key = col_obj.primary_key or is_pk
            col_obj.unique = col_obj.unique or is_unique
            if default_expr is not None:
                col_obj.default = default_expr
            table.columns[col] = col_obj

            if is_pk and col not in table.pk:
                table.pk.append(col)

            if inline_ref:
                ref_table, ref_cols = inline_ref
                fks.append(ForeignKey(
                    name=next_fk_name(),
                    src_table=tname,
                    src_cols=[col],
                    ref_table=ref_table,
                    ref_cols=ref_cols,
                    src_not_null=not_null,
                ))

    # --- ALTER TABLE ... ADD CONSTRAINT ----------------------------------
    for st in stmts:
        am = ALTER_FK_RE.match(st)
        if am:
            tname = _normalize_table_name(am.group("table"))
            tables.setdefault(tname, Table(name=tname))
            fk = _fk_from_match(am, tname, next_fk_name())
            tables.setdefault(fk.ref_table, Table(name=fk.ref_table))
            fks.append(fk)
            continue

        km = _ALTER_KEY_RE.match(st)
        if km:
            tname = _normalize_table_name(km.group("table"))
            table = tables.setdefault(tname, Table(name=tname))
            cols = _parse_ident_list(km.group("cols"))
            if km.group("kind").upper().startswith("PRIMARY"):
                for c in cols:
                    if c not in table.pk:
                        table.pk.append(c)
            else:
                unique_sets.setdefault(tname, []).append(cols)

    # --- key flags on columns ----------------------------------------------
    for tname, table in tables.items():
        for c in table.pk:
            col = table.columns.get(c)
            if col is not None:
                col.primary_key = True
                col.not_null = True  # a primary key implies NOT NULL
        for ucols in unique_sets.get(tname, []):
            if len(ucols) == 1 and ucols[0] in table.columns:
                table.columns[ucols[0]].unique = True

    # --- FK flags and cardinalities ----------------------------------------
    for fk in fks:
        table = tables.get(fk.src_table)
        if table is None or not fk.src_cols:
            continue

        src_columns = [table.columns.get(c) for c in fk.src_cols]
        for col in src_columns:
            if col is not None:
                col.foreign_key = True

        fk.src_not_null = all(col is not None and col.not_null for col in src_columns)

        # The source rows are unique when the FK columns contain a key:
        # the PK, a UNIQUE column set, or a single UNIQUE column.
        src = set(fk.src_cols)
        keys: List[List[str]] = [table.pk] if table.pk else []
        keys += [u for u in unique_sets.get(fk.src_table, []) if u]
        keys += [[name] for name, col in table.columns.items() if col.unique]
        fk.src_unique = any(set(key) <= src for key in keys)

    return tables, fks