import re
import sqlparse
from sqlparse.sql import Where, Identifier, Parenthesis, Comparison
from sqlparse.tokens import Keyword, Whitespace, Punctuation, Operator as T_Operator, Name as T_Name
from sqlparse.lexer import Lexer

class SqlWhereGraphQLConverter:
    def __init__(self):
        """Create a converter; no per-instance state is set up here."""
        return None

    def register_custom_keywords(self):
        """Teach sqlparse's default lexer the custom ``IIN`` keyword, once.

        The default lexer instance is shared process-wide and
        ``add_keywords`` appends a new dictionary on every call, so this
        method first asks the lexer whether ``IIN`` is already known.
        Without that check the keyword list would grow by one entry per
        converted WHERE string.
        """
        lex = Lexer.get_default_instance()
        # is_keyword() returns (token_type, value); unknown words come back
        # as Name, so a Keyword type means a previous call already ran.
        if lex.is_keyword('IIN')[0] is Keyword:
            return
        # Add custom keyword token mapping for specialized operators
        lex.add_keywords({'IIN': Keyword})

    def sql_where_to_graphql(self, where_string):
        """Convert the body of a SQL WHERE clause into a nested filter dict.

        Args:
            where_string: The condition text without the ``WHERE`` keyword,
                e.g. ``"age" > 5 AND name LIKE 'A%'``. ``None`` and blank
                strings mean "no condition".

        Returns:
            The filter dict built by ``parse_where_tokens``, or ``{}`` when
            there is no condition or sqlparse finds no WHERE group.
        """
        # Nothing to convert: skip lexer registration and parsing entirely
        # (previously None crashed inside re.sub with a TypeError).
        if where_string is None or not where_string.strip():
            return {}

        self.register_custom_keywords()

        # Collapse whitespace runs to a single space, but only where the
        # remainder of the string holds an even number of quote characters,
        # i.e. outside quoted text. Single and double quotes are counted
        # together, so mixed quoting can defeat this heuristic.
        where_string = re.sub(r'\s+(?=(?:[^"\']*["\'][^"\']*["\'])*[^"\']*$)', ' ', where_string)

        # Wrap the condition in a dummy SELECT so sqlparse produces a Where group.
        parsed = sqlparse.parse(f"SELECT * FROM t WHERE {where_string}")[0]
        where_clause = next((token for token in parsed.tokens if isinstance(token, Where)), None)
        if where_clause is None:
            return {}
        # tokens[0] is the WHERE keyword itself.
        return self.parse_where_tokens(where_clause.tokens[1:])

    # --- helpers for field tokens ---
    def _is_field_token(self, token):
        """Return True for an Identifier group or a bare Name token."""
        if isinstance(token, Identifier):
            return True
        token_type = getattr(token, "ttype", None)
        return token_type == T_Name

    def _field_name(self, token):
        """Return the field name carried by *token* without double quotes.

        For an Identifier the result of ``get_real_name()`` is preferred and
        the raw ``value`` is the fallback; any other token contributes its
        ``value`` (or an empty string when it has none).
        """
        if not isinstance(token, Identifier):
            # Plain Name token (or anything else): use the raw text.
            text = getattr(token, "value", "") or ""
            return text.strip('"')

        resolver = getattr(token, "get_real_name", None)
        real_name = resolver() if resolver is not None else None
        chosen = real_name or token.value
        return chosen.strip('"')

    # --- new: literal detection & operator flip ---
    def _is_literal_token(self, token):
        """Return True when *token* looks like a string or numeric literal.

        The token type is checked first (its string form starting with
        ``Token.Literal``, ``Token.String`` or ``Token.Number``). Otherwise
        the token's ``value`` is inspected: single-quoted text, an integer,
        or anything ``float()`` accepts that contains at least one digit.

        The digit requirement keeps identifiers such as ``nan``, ``inf`` or
        ``infinity`` (which ``float()`` happily parses) from being mistaken
        for numbers.
        """
        # Prefer token type checks from sqlparse
        ttype = getattr(token, 'ttype', None)
        if ttype is not None and str(ttype).startswith(
                ('Token.Literal', 'Token.String', 'Token.Number')):
            return True

        # Fallback by value shape
        val = getattr(token, 'value', '')
        if not isinstance(val, str):
            return False
        if len(val) >= 2 and val[0] == val[-1] == "'":
            return True

        v = val.strip()
        if v.lstrip('-').isdigit():
            return True
        if not any(ch.isdigit() for ch in v):
            # float() accepts the words nan/inf/infinity; those are names here.
            return False
        try:
            float(v)
        except ValueError:
            return False
        return True

    def _flip_operator(self, op):
        """Mirror an ordering operator for use after swapping its operands.

        Non-string input is returned untouched. String input is stripped and
        upper-cased; ``>``/``<``/``>=``/``<=`` are replaced by their mirror
        image and every other operator is returned in that normalized form.
        """
        if not isinstance(op, str):
            return op
        normalized = op.strip().upper()
        mirrored = dict(zip(('>', '<', '>=', '<='), ('<', '>', '<=', '>=')))
        try:
            return mirrored[normalized]
        except KeyError:
            return normalized

    def parse_where_tokens(self, tokens):
        """Convert a list of WHERE-clause tokens into a nested filter dict.

        Whitespace and punctuation tokens are dropped first. The expression
        is then split with SQL precedence in mind: on the first top-level
        ``OR``, otherwise on the first top-level ``AND``; each side is parsed
        recursively. Only a connective-free token run reaches the term loop,
        so ``a AND b OR c`` yields ``OR[AND[a, b], c]`` and a leading ``NOT``
        applies to its own term only. Chains of one connective keep their
        right-nested shape (``a AND b AND c`` -> ``AND[a, AND[b, c]]``).

        Args:
            tokens: sqlparse tokens (already-filtered lists are fine too).

        Returns:
            A filter dict; ``{}`` when nothing could be converted. Several
            terms left side by side without a connective are AND-ed.
        """
        tokens = [t for t in tokens if t.ttype not in [Whitespace, Punctuation]]

        def is_kw(tok, values):
            # True when tok is a Keyword token matching one of `values`.
            return hasattr(tok, "match") and tok.match(Keyword, values)

        # Lowest precedence first: OR, then AND. Parenthesised sub-expressions
        # are single Parenthesis tokens here, so only top-level connectives
        # are seen by this scan.
        for connective in ('OR', 'AND'):
            split_at = next(
                (idx for idx, tok in enumerate(tokens) if is_kw(tok, connective)),
                None,
            )
            if split_at is None:
                continue
            left = self.parse_where_tokens(tokens[:split_at])
            right = self.parse_where_tokens(tokens[split_at + 1:])
            joins = [side for side in (left, right) if side]
            if len(joins) == 1:
                return joins[0]
            return {connective: joins} if joins else {}

        stack = []
        i = 0
        while i < len(tokens):
            token = tokens[i]

            if isinstance(token, Parenthesis):
                # A parenthesis directly after IN/IIN is a value list that the
                # IN handlers consume; anything else is a nested group.
                is_in_list = i > 0 and bool(is_kw(tokens[i-1], ('IN', 'IIN')))
                if not is_in_list:
                    group = self.parse_where_tokens(token.tokens[1:-1])
                    if group:
                        stack.append(group)
                    i += 1
                    continue

            if is_kw(token, 'NOT'):
                # handle NOT IN / NOT IS NULL / NOT IS EMPTY
                has_field_operand = i+3 < len(tokens) and self._is_field_token(tokens[i+1])
                if (has_field_operand and is_kw(tokens[i+2], ('IN', 'IIN'))
                        and isinstance(tokens[i+3], Parenthesis)):
                    field = self._field_name(tokens[i+1])
                    values = self.extract_values_from_parenthesis(tokens[i+3])
                    filt = self.build_graphql_filter(field, 'NOT ' + tokens[i+2].value.upper(), values)  # 'NOT IN' or 'NOT IIN'
                    if filt: stack.append(filt)
                    i += 4
                    continue
                elif has_field_operand and is_kw(tokens[i+2], 'IS'):
                    field = self._field_name(tokens[i+1])
                    if is_kw(tokens[i+3], 'NULL'):
                        filt = self.build_graphql_filter(field, 'IS NOT NULL', None)
                        if filt: stack.append(filt)
                        i += 4
                        continue
                    if is_kw(tokens[i+3], 'EMPTY'):
                        filt = self.build_graphql_filter(field, 'IS NOT EMPTY', None)
                        if filt: stack.append(filt)
                        i += 4
                        continue
                # Generic NOT: negate whatever the remaining tokens of this
                # connective-free term produce.
                right = self.parse_where_tokens(tokens[i+1:])
                return {'NOT': right} if right else {}

            elif isinstance(token, Comparison):
                field, op, value = self.parse_comparison(token)
                filt = self.build_graphql_filter(field, op, value)
                if filt: stack.append(filt)

            # field-first fallbacks
            elif self._is_field_token(token):
                # IS [NOT] NULL/EMPTY and one-token NOT NULL / NOT EMPTY variants
                if i+1 < len(tokens) and is_kw(tokens[i+1], 'IS'):
                    field = self._field_name(token)
                    if i+2 < len(tokens):
                        third = tokens[i+2]
                        third_val = getattr(third, 'value', '').upper()

                        # One-token variant: NOT NULL / NOT EMPTY
                        if third_val in ('NOT NULL', 'NOT EMPTY'):
                            mapped = 'IS NOT NULL' if third_val == 'NOT NULL' else 'IS NOT EMPTY'
                            filt = self.build_graphql_filter(field, mapped, None)
                            if filt: stack.append(filt)
                            i += 3
                            continue

                        # Two-token variant: IS NOT NULL/EMPTY
                        if third_val == 'NOT' and i+3 < len(tokens):
                            fourth_val = getattr(tokens[i+3], 'value', '').upper()
                            if fourth_val in ('NULL', 'EMPTY'):
                                mapped = 'IS NOT NULL' if fourth_val == 'NULL' else 'IS NOT EMPTY'
                                filt = self.build_graphql_filter(field, mapped, None)
                                if filt: stack.append(filt)
                                i += 4
                                continue

                        # Simple: IS NULL / IS EMPTY
                        if third_val in ('NULL', 'EMPTY'):
                            mapped = 'IS NULL' if third_val == 'NULL' else 'IS EMPTY'
                            filt = self.build_graphql_filter(field, mapped, None)
                            if filt: stack.append(filt)
                            i += 3
                            continue

                # field NOT IN (...), field IN (...)
                if (i+3 < len(tokens) and is_kw(tokens[i+1], 'NOT')
                        and is_kw(tokens[i+2], ('IN', 'IIN'))
                        and isinstance(tokens[i+3], Parenthesis)):
                    field = self._field_name(token)
                    values = self.extract_values_from_parenthesis(tokens[i+3])
                    filt = self.build_graphql_filter(field, 'NOT ' + tokens[i+2].value.upper(), values)  # 'NOT IN' or 'NOT IIN'
                    if filt: stack.append(filt)
                    i += 4
                    continue

                if i+1 < len(tokens) and is_kw(tokens[i+1], ('IN', 'IIN')):
                    field = self._field_name(token)
                    if i+2 < len(tokens) and isinstance(tokens[i+2], Parenthesis):
                        values = self.extract_values_from_parenthesis(tokens[i+2])
                        filt = self.build_graphql_filter(field, tokens[i+1].value.upper(), values)  # 'IN' or 'IIN'
                        if filt: stack.append(filt)
                        i += 3
                        continue

                # NOT LIKE / NOT ILIKE via field-based fallback
                if (i+3 < len(tokens) and is_kw(tokens[i+1], 'NOT')
                        and getattr(tokens[i+2], 'value', '').upper() in ('LIKE', 'ILIKE')):
                    field = self._field_name(token)
                    op = f"NOT {tokens[i+2].value.upper()}"
                    val = self._convert_value_token(tokens[i+3])
                    filt = self.build_graphql_filter(field, op, val)
                    if filt: stack.append(filt)
                    i += 4
                    continue

                # Simple binary ops via field-based fallback (=, !=, >, <, >=, <=, LIKE, ILIKE)
                if i+2 < len(tokens):
                    op_tok = tokens[i+1]
                    op_val = getattr(op_tok, 'value', '').upper()
                    is_cmp = getattr(op_tok, 'ttype', None) == T_Operator.Comparison or op_val in ('=', '!=', '>', '<', '>=', '<=', 'LIKE', 'ILIKE')
                    if is_cmp:
                        field = self._field_name(token)
                        val = self._convert_value_token(tokens[i+2])
                        filt = self.build_graphql_filter(field, op_val, val)
                        if filt: stack.append(filt)
                        i += 3
                        continue

            # literal-first fallbacks (swap sides and flip operator if needed)
            if self._is_literal_token(token) and i+2 < len(tokens):
                mid = tokens[i+1]
                rhs = tokens[i+2]

                # NOT LIKE / NOT ILIKE with literal first: 'pat' NOT LIKE "field" -> "field" NOT LIKE 'pat'
                if (is_kw(mid, 'NOT')
                        and getattr(rhs, 'value', '').upper() in ('LIKE', 'ILIKE')
                        and i+3 < len(tokens) and self._is_field_token(tokens[i+3])):
                    field = self._field_name(tokens[i+3])
                    op = f"NOT {rhs.value.upper()}"
                    value = self._convert_value_token(token)
                    filt = self.build_graphql_filter(field, op, value)
                    if filt: stack.append(filt)
                    i += 4
                    continue

                # LIKE/ILIKE with literal first: 'pat' LIKE "field" -> "field" LIKE 'pat'
                if (getattr(mid, 'value', '').upper() in ('LIKE', 'ILIKE')) and self._is_field_token(rhs):
                    field = self._field_name(rhs)
                    op = getattr(mid, 'value', '').upper()
                    value = self._convert_value_token(token)
                    filt = self.build_graphql_filter(field, op, value)
                    if filt: stack.append(filt)
                    i += 3
                    continue

                # Standard comparison with literal first: '5' > "age" -> "age" < 5
                mid_val = getattr(mid, 'value', '').upper()
                is_cmp = getattr(mid, 'ttype', None) == T_Operator.Comparison or mid_val in ('=', '!=', '>', '<', '>=', '<=')
                if is_cmp and self._is_field_token(rhs):
                    field = self._field_name(rhs)
                    op = self._flip_operator(mid_val)
                    value = self._convert_value_token(token)
                    filt = self.build_graphql_filter(field, op, value)
                    if filt: stack.append(filt)
                    i += 3
                    continue

            # last-resort fallback: field <comparison-operator> value
            if self._is_field_token(token) and i+2 < len(tokens) and getattr(tokens[i+1], 'ttype', None) == T_Operator.Comparison:
                field = self._field_name(token)
                op = tokens[i+1].value.upper()
                value = self._convert_value_token(tokens[i+2])
                filt = self.build_graphql_filter(field, op, value)
                if filt: stack.append(filt)
                i += 3
                continue

            i += 1

        # Terms that ended up side by side without a connective are AND-ed.
        filtered_stack = [x for x in stack if x]
        if len(filtered_stack) == 1:
            return filtered_stack[0]
        elif len(filtered_stack) > 1:
            return {'AND': filtered_stack}
        else:
            return {}

    def parse_comparison(self, token):
        """Split a sqlparse Comparison group into ``(field, op, value)``.

        Recognises ``NOT IN``/``NOT IIN``/``NOT LIKE``/``NOT ILIKE`` written
        as two tokens, swaps the sides when a literal stands on the left of a
        field, and rewrites positive ``LIKE``/``ILIKE`` patterns through
        ``_normalize_like_pattern``. Elements that cannot be determined are
        returned as ``None``.
        """
        parts = [t for t in token.tokens if t.ttype not in (Whitespace, Punctuation)]
        count = len(parts)
        lhs = parts[0] if count else None
        # NOTE(review): this path uses the raw token text, whereas the
        # literal-left swap below goes through _field_name - confirm whether
        # qualified names (t.col) are meant to differ between the two.
        field = lhs.value.strip('"') if count else None

        def keyword_at(pos, names):
            candidate = parts[pos]
            return hasattr(candidate, "match") and candidate.match(Keyword, names)

        # "<lhs> NOT <IN|IIN|LIKE|ILIKE> <rhs>" arrives as four tokens.
        two_word_negation = (
            count >= 4
            and keyword_at(1, 'NOT')
            and (keyword_at(2, ('IN', 'IIN'))
                 or getattr(parts[2], 'value', '').upper() in ('LIKE', 'ILIKE'))
        )
        if two_word_negation:
            op = f"NOT {parts[2].value.upper()}"
            rhs = parts[3]
        else:
            # Everything else, IN / IIN included: operator second, value third.
            op = parts[1].value.upper() if count >= 2 else None
            rhs = parts[2] if count >= 3 else None

        literal_on_left = (
            lhs is not None
            and rhs is not None
            and self._is_literal_token(lhs)
            and self._is_field_token(rhs)
        )
        if literal_on_left:
            # literal <op> field  ->  field <flipped op> literal
            field = self._field_name(rhs)
            op = self._flip_operator(op)
            value = self._convert_value_token(lhs)
        elif rhs is None:
            value = None
        elif isinstance(rhs, Parenthesis):
            value = self.extract_values_from_parenthesis(rhs)
        else:
            value = self._convert_value_token(rhs)

        # LIKE / ILIKE normalization for the positive branches stays here;
        # the negated forms are normalized inside build_graphql_filter.
        if isinstance(op, str) and op in ('LIKE', 'ILIKE') and isinstance(value, str):
            op, value = self._normalize_like_pattern(field, op, value)

        return field, op, value

    def _normalize_like_pattern(self, field, op, value):
        """Rewrite a LIKE/ILIKE pattern into a more specific operator.

        Returns an ``(operator, value)`` pair:

        * no ``%`` at all      -> ``=`` (LIKE) or ``IEQUALS`` (ILIKE)
        * ``%text%``           -> ``LIKE`` / ``ILIKE`` with the outer ``%`` removed
        * ``text%``            -> ``STARTS_WITH`` / ``ISTARTS_WITH``
        * ``%text``            -> ``ENDS_WITH`` / ``IENDS_WITH``
        * ``%`` only inside    -> ``COMPOSITE`` / ``ICOMPOSITE`` with a dict
          holding the text before and after the last ``%``

        Non-string values are returned unchanged next to the upper-cased
        operator. ``field`` is accepted but not used.
        """
        operator = "" if op is None else str(op).upper()
        if not isinstance(value, str):
            return operator, value

        insensitive = operator == 'ILIKE'
        leading = value.startswith('%')
        trailing = value.endswith('%')

        if value.count('%') == 0:
            # No wildcard: exact match for LIKE/ILIKE, untouched otherwise.
            if operator == 'LIKE':
                return '=', value
            if insensitive:
                return 'IEQUALS', value
            return operator, value

        if leading and trailing:
            return ('ILIKE' if insensitive else 'LIKE'), value[1:-1]
        if trailing:
            return ('ISTARTS_WITH' if insensitive else 'STARTS_WITH'), value[:-1]
        if leading:
            return ('IENDS_WITH' if insensitive else 'ENDS_WITH'), value[1:]

        # Wildcards only in the interior: split at the last '%'.
        head, _, tail = value.rpartition('%')
        parts = {'starts_with': head, 'ends_with': tail}
        return ('ICOMPOSITE' if insensitive else 'COMPOSITE'), parts

    def build_graphql_filter(self, field, op, value):
        """Build the filter dict for one ``field <op> value`` condition.

        Args:
            field: Field name used as the dict key.
            op: SQL-style operator; case and inner whitespace are normalized
                before lookup. Operators missing from the mapping table are
                lower-cased with spaces turned into underscores.
            value: Already-converted Python value, a list for the IN family,
                or ``None`` for the IS [NOT] NULL/EMPTY flags.

        Returns:
            ``{field: {graphql_op: value}}``, an ``AND``/``OR`` pair for
            composite LIKE patterns, or ``None`` when no filter can be built.
        """
        # Normalize operator: collapse whitespace and uppercase for lookup
        if isinstance(op, str):
            op_norm = ' '.join(op.strip().upper().split())
        else:
            op_norm = op

        # Normalize NOT LIKE / NOT ILIKE using the positive normalizer
        if isinstance(op_norm, str) and op_norm in ('NOT LIKE', 'NOT ILIKE') and isinstance(value, str):
            inner = 'ILIKE' if op_norm == 'NOT ILIKE' else 'LIKE'
            norm_op, norm_val = self._normalize_like_pattern(field, inner, value)

            if norm_op in ('COMPOSITE', 'ICOMPOSITE') and isinstance(norm_val, dict):
                # NOT (starts AND ends) == (NOT starts) OR (NOT ends)
                sw = norm_val.get('starts_with', '')
                ew = norm_val.get('ends_with', '')
                key_sw = 'not_istarts_with' if inner == 'ILIKE' else 'not_starts_with'
                key_ew = 'not_iends_with' if inner == 'ILIKE' else 'not_ends_with'
                return {'OR': [{field: {key_sw: sw}}, {field: {key_ew: ew}}]}

            # Wildcard-free NOT LIKE / NOT ILIKE -> not_equals / not_iequals
            if norm_op == '=':
                return {field: {'not_equals': norm_val}}
            if norm_op == 'IEQUALS':
                return {field: {'not_iequals': norm_val}}

            # STARTS_WITH / ENDS_WITH variants and plain LIKE / ILIKE
            op_norm = f'NOT {norm_op}'
            value = norm_val

        GRAPHQL_MAP = {
            '=':                 'equals',
            '!=':                'not_equals',
            '>':                 'greater_than',
            '<':                 'less_than',
            '>=':                'greater_than_or_equals',
            '<=':                'less_than_or_equals',
            'LIKE':              'like',
            'ILIKE':             'ilike',
            'NOT LIKE':          'not_like',
            'NOT ILIKE':         'not_ilike',
            'IN':                'in',
            'NOT IN':            'not_in',
            'IIN':               'iin',
            'NOT IIN':           'not_iin',
            'IS NULL':           'is_null',
            'IS NOT NULL':       'is_not_null',
            'IS EMPTY':          'is_empty',
            'IS NOT EMPTY':      'is_not_empty',
            'STARTS_WITH':       'starts_with',
            'ENDS_WITH':         'ends_with',
            'ISTARTS_WITH':      'istarts_with',
            'IENDS_WITH':        'iends_with',
            'COMPOSITE':         'composite',
            'ICOMPOSITE':        'icomposite',
            'NOT STARTS_WITH':   'not_starts_with',
            'NOT ENDS_WITH':     'not_ends_with',
            'NOT ISTARTS_WITH':  'not_istarts_with',
            'NOT IENDS_WITH':    'not_iends_with',
            'IEQUALS':           'iequals',
            'NOT IEQUALS':       'not_iequals',
        }

        mapped = GRAPHQL_MAP.get(op_norm, None)
        if mapped is None:
            mapped = op_norm.lower().replace(' ', '_') if isinstance(op_norm, str) else op_norm

        # Special: IS NULL/IS NOT NULL/IS EMPTY/IS NOT EMPTY -> True flag
        if mapped in ('is_null', 'is_not_null', 'is_empty', 'is_not_empty'):
            return {field: {mapped: True}}

        # IN family (IIN variants included): accept a list, a
        # parenthesis-string or a scalar, and always emit a list.
        if mapped in ('in', 'not_in', 'iin', 'not_iin'):
            vals = value
            if isinstance(vals, str):
                vals = self.extract_values_from_parenthesis(vals)
            if not isinstance(vals, list):
                vals = [vals]
            return {field: {mapped: vals}}

        if mapped in ('composite', 'icomposite') and isinstance(value, dict):
            # prefix%suffix becomes starts_with AND ends_with.
            prefix_key = 'istarts_with' if mapped == 'icomposite' else 'starts_with'
            suffix_key = 'iends_with' if mapped == 'icomposite' else 'ends_with'
            return {'AND': [
                {field: {prefix_key: value.get('starts_with', '')}},
                {field: {suffix_key: value.get('ends_with', '')}}
            ]}

        if mapped and field is not None:
            return {field: {mapped: value}}
        return None

    def extract_values_from_parenthesis(self, token):
        """Extract the values of an IN-style list as Python values.

        Accepts a sqlparse Parenthesis (or any object with a ``value``
        attribute) or a plain string such as ``"(1, 'a,b', NULL)"``. The text
        is split on commas that are outside quoted strings, so a quoted value
        may itself contain commas. Each element goes through
        ``_convert_value_token``; empty elements are skipped.

        Returns:
            Always a list (empty for ``()`` or blank input).
        """
        text = token.value if hasattr(token, 'value') else str(token)
        text = text.strip()
        if text.startswith('(') and text.endswith(')'):
            # Remove exactly the enclosing pair so nested parentheses survive.
            text = text[1:-1]
        else:
            # Unbalanced input: tolerate stray parentheses at either end.
            text = text.strip('()')
        text = text.strip()

        if not text:
            return []

        parts = []
        current = []
        open_quote = None
        for ch in text:
            if open_quote is not None:
                # Inside a quoted string: everything is literal. A doubled
                # quote simply closes and immediately reopens the string.
                current.append(ch)
                if ch == open_quote:
                    open_quote = None
            elif ch in ("'", '"'):
                open_quote = ch
                current.append(ch)
            elif ch == ',':
                parts.append(''.join(current))
                current = []
            else:
                current.append(ch)
        parts.append(''.join(current))

        cleaned = (part.strip() for part in parts)
        return [self._convert_value_token(part) for part in cleaned if part]

    def _convert_value_token(self, token):
        """Convert a token or string to a Python value.

        * ``int``/``float``/``bool`` and ``None`` are returned unchanged.
        * Text wrapped in matching single or double quotes is unquoted, and
          a doubled quote character inside it (SQL's escape, as in
          ``'O''Brien'``) becomes a single one.
        * ``null``/``true``/``false`` (any case) map to ``None``/``True``/
          ``False``.
        * Integers and dotted decimals become ``int``/``float``.
        * Anything else is returned as the stripped string.
        """
        if isinstance(token, (int, float, bool)):
            return token
        if token is None:
            return None

        raw = token.value if hasattr(token, "value") else str(token)
        s = raw.strip()

        # Strip quotes. At least two characters are required so that a lone
        # quote character is not mistaken for an empty quoted string.
        if len(s) >= 2 and s[0] == s[-1] and s[0] in ("'", '"'):
            quote = s[0]
            return s[1:-1].replace(quote * 2, quote)

        # Special values
        low = s.lower()
        if low == 'null':
            return None
        if low == 'true':
            return True
        if low == 'false':
            return False

        # Numbers
        try:
            if s.lstrip('-').isdigit():
                return int(s)
            if '.' in s:
                return float(s)
        except (ValueError, AttributeError):
            # Looked numeric but is not (e.g. "--5" or "a.b"): keep the text.
            pass

        # Else string
        return s
