# -*- coding: utf-8 -*-
"""
symbology_extractor.py
Extract QGIS layer symbology and convert to KML-compatible styles.
"""

import os
from qgis.PyQt.QtGui import QColor, QImage, QPainter
from qgis.PyQt.QtCore import QSize, Qt

from qgis.core import (
    QgsVectorLayer, QgsSymbol, QgsSymbolLayer,
    QgsSingleSymbolRenderer, QgsCategorizedSymbolRenderer,
    QgsGraduatedSymbolRenderer, QgsRuleBasedRenderer,
    QgsMarkerSymbol, QgsLineSymbol, QgsFillSymbol,
    QgsRenderContext, QgsMapSettings
)


def rgba_to_kml_abgr(color: QColor) -> str:
    """
    Encode a QColor as a KML colour string.

    KML orders colour channels as alpha, blue, green, red (AABBGGRR), each as
    two lowercase hex digits, which is the reverse of the usual RGBA order.

    Args:
        color: QColor to encode, or None

    Returns:
        8-character lowercase hex string in AABBGGRR order
        (e.g. 'ff0000ff' for opaque red); 'ffffffff' (opaque white) for None.
    """
    if color is None:
        # No colour supplied: fall back to opaque white.
        return 'ffffffff'

    kml_channel_order = (color.alpha(), color.blue(), color.green(), color.red())
    return ''.join(format(channel, '02x') for channel in kml_channel_order)


def _symbol_layer_float(layer_props: dict, keys, default: float) -> float:
    """Return the first value under *keys* in *layer_props* that parses as a float, else *default*."""
    for key in keys:
        value = layer_props.get(key)
        if value is None or value == '':
            continue
        try:
            return float(value)
        except (TypeError, ValueError):
            # Unparseable value (e.g. non-numeric text): try the next key.
            continue
    return default


def extract_symbol_properties(symbol: QgsSymbol, layer_type: str) -> dict:
    """
    Extract styling properties from a QGIS symbol.

    The overall colour and opacity come from the symbol itself; everything
    else comes from its first symbol layer only (further layers are ignored).
    For 'Point', 'Line' and 'Polygon' the geometry-specific keys are always
    present. Defaults are used when the symbol has no layers or a property is
    missing or unparseable.

    Args:
        symbol: QgsSymbol (marker, line, or fill), or None
        layer_type: 'Point', 'Line', or 'Polygon'

    Returns:
        Dictionary with style properties:
        - always: 'type', 'color' (KML AABBGGRR), 'opacity'
        - Point: 'size', 'size_unit'
        - Line: 'width', 'width_unit' ('color' may be replaced by the line colour)
        - Polygon: 'fill_color', 'stroke_color', 'stroke_width',
          'stroke_width_unit', 'fill'
        For a None symbol, the result of get_default_style(layer_type).
    """
    if symbol is None:
        return get_default_style(layer_type)

    # NOTE: 'color' carries only the QColor alpha; the symbol-level opacity
    # multiplier is reported separately under 'opacity'.
    props = {
        'type': layer_type,
        'color': rgba_to_kml_abgr(symbol.color()),
        'opacity': symbol.opacity()
    }

    # Detailed properties come from the first symbol layer. A symbol without
    # layers still gets the geometry-specific defaults below.
    layer_props = {}
    if symbol.symbolLayerCount() > 0:
        sym_layer = symbol.symbolLayer(0)
        if sym_layer:
            layer_props = sym_layer.properties() or {}

    if layer_type == 'Point':
        # Marker properties
        props['size'] = _symbol_layer_float(layer_props, ('size',), 6.0)
        props['size_unit'] = layer_props.get('size_unit', 'MM')

    elif layer_type == 'Line':
        # Line properties
        props['width'] = _symbol_layer_float(layer_props, ('line_width', 'width'), 0.5)
        # QGIS simple line layers store their unit under 'line_width_unit'.
        props['width_unit'] = layer_props.get(
            'line_width_unit', layer_props.get('width_unit', 'MM'))
        # The line colour may differ from the symbol colour.
        color_key = 'line_color' if 'line_color' in layer_props else 'color'
        line_color = parse_color_string(layer_props.get(color_key))
        if line_color is not None:
            props['color'] = rgba_to_kml_abgr(line_color)

    elif layer_type == 'Polygon':
        # Fill properties
        props['fill_color'] = props['color']

        # Stroke/outline colour
        if 'outline_color' in layer_props:
            stroke_color = parse_color_string(layer_props['outline_color'])
        else:
            # Some symbol layer types keep their only colour under 'color'.
            stroke_color = parse_color_string(layer_props.get('color'))
        # Fall back to the fill colour when nothing could be parsed, so
        # 'stroke_color' is always present for polygons.
        if stroke_color is not None:
            props['stroke_color'] = rgba_to_kml_abgr(stroke_color)
        else:
            props['stroke_color'] = props['fill_color']

        props['stroke_width'] = _symbol_layer_float(
            layer_props, ('outline_width', 'width'), 0.26)
        props['stroke_width_unit'] = layer_props.get('outline_width_unit', 'MM')

        # A brush style of 'no' means the interior is not filled.
        props['fill'] = layer_props.get('style', 'solid') != 'no'

    return props


def parse_color_string(color_str: str) -> QColor:
    """
    Parse a QGIS colour value into a QColor.

    Accepted forms:
    - QGIS encoded colours 'r,g,b' or 'r,g,b,a' (alpha defaults to 255; any
      further comma-separated components are ignored)
    - hex strings starting with '#', e.g. '#ff0000' or '#80ff0000'
    - an existing QColor, returned as-is when valid

    Args:
        color_str: Colour value, normally a string from symbol layer properties

    Returns:
        A valid QColor, or None if the value is empty or cannot be parsed
        (including comma-separated components outside 0-255).
    """
    if isinstance(color_str, QColor):
        return color_str if color_str.isValid() else None
    if color_str is None:
        return None

    # Property values are normally strings; tolerate other scalars instead of
    # failing with AttributeError on .split().
    text = str(color_str).strip()
    if not text:
        return None

    # Comma-separated format: r,g,b[,a,...]
    parts = text.split(',')
    if len(parts) >= 3:
        try:
            r, g, b = (int(part.strip()) for part in parts[:3])
            a = int(parts[3].strip()) if len(parts) > 3 else 255
        except ValueError:
            pass  # non-numeric components: try the hex form below
        else:
            color = QColor(r, g, b, a)
            # QColor flags out-of-range components as invalid instead of raising.
            return color if color.isValid() else None

    # Hex format
    if text.startswith('#'):
        color = QColor(text)
        if color.isValid():
            return color

    return None


def get_default_style(layer_type: str) -> dict:
    """
    Build the fallback KML style for a geometry type.

    Every style is opaque red at full opacity. Points additionally get a size,
    lines a width, and polygons a semi-transparent red fill with an opaque red
    outline. Any other layer_type gets only the common keys.

    Args:
        layer_type: 'Point', 'Line', or 'Polygon'

    Returns:
        A new dictionary of default style properties
    """
    style = {'type': layer_type, 'color': 'ff0000ff', 'opacity': 1.0}

    if layer_type == 'Point':
        style['size'] = 6
    elif layer_type == 'Line':
        style['width'] = 1.0
    elif layer_type == 'Polygon':
        style.update(
            fill_color='7f0000ff',    # semi-transparent red
            stroke_color='ff0000ff',  # opaque red
            stroke_width=1.0,
            fill=True,
        )

    return style


def get_layer_geometry_type(layer: QgsVectorLayer) -> str:
    """
    Get the geometry type string for a vector layer.

    Args:
        layer: QgsVectorLayer

    Returns:
        'Point', 'Line', or 'Polygon' for geometry type codes 0, 1 and 2;
        'Unknown' for anything else.
    """
    geom_type = layer.geometryType()

    # Depending on the PyQGIS/PyQt version, geometryType() returns either an
    # int-compatible enum or a Python ``enum`` member whose integer lives in
    # ``.value``. A plain Enum member never compares equal to 0/1/2, so
    # normalise to an int before the lookup.
    code = getattr(geom_type, 'value', geom_type)
    try:
        code = int(code)
    except (TypeError, ValueError):
        return 'Unknown'

    return {0: 'Point', 1: 'Line', 2: 'Polygon'}.get(code, 'Unknown')


def extract_layer_styles(layer: QgsVectorLayer) -> dict:
    """
    Extract all styles from a vector layer's renderer.

    Supports:
    - QgsSingleSymbolRenderer
    - QgsCategorizedSymbolRenderer
    - QgsGraduatedSymbolRenderer
    - QgsRuleBasedRenderer (top-level rules that have a symbol)
    Any other renderer falls back to the first symbol it reports.

    Style IDs are the layer name plus a per-style suffix. If two entries would
    produce the same ID, later ones get a numeric suffix ('_1', '_2', ...).
    This happens, for example, with category values that sanitize to the same
    text, such as 'a b' and 'a_b'. No style therefore silently overwrites
    another.

    Args:
        layer: QgsVectorLayer to extract styles from

    Returns:
        None if the layer is missing/invalid or has no renderer; otherwise a
        dictionary with:
        - 'renderer_type': 'single', 'categorized', 'graduated',
          'rule_based' or 'unknown'
        - 'geometry_type': Point, Line, Polygon (or 'Unknown')
        - 'styles': Dict mapping style_id to style properties
        - 'default_style_id': ID of the first style extracted
        - 'category_field': Class attribute (categorized/graduated only)
        - 'categories': categorized: str(value) -> style ID;
          graduated: 'lower_upper' -> {'style_id', 'lower', 'upper'};
          rule-based: rule label (or 'rule_<i>') -> style ID
        - '_renderer': the layer's renderer object, kept for icon rendering
    """
    if not layer or not layer.isValid():
        return None

    renderer = layer.renderer()
    if not renderer:
        return None

    geom_type = get_layer_geometry_type(layer)
    # NOTE(review): layer.name() is used unsanitized in style IDs; if names may
    # contain spaces, confirm the KML writer produces valid id/styleUrl values.
    layer_name = layer.name()

    result = {
        'renderer_type': 'unknown',
        'geometry_type': geom_type,
        'styles': {},
        'default_style_id': None,
        'category_field': None,
        'categories': {},
        # NOTE(review): this is the layer-owned renderer, not a copy; presumably
        # it must not be used after the layer's renderer is replaced - confirm.
        '_renderer': renderer
    }
    styles = result['styles']

    def add_style(base_id, style_props):
        """Store *style_props* under an unused ID derived from *base_id*; return that ID."""
        style_id = base_id
        suffix = 1
        while style_id in styles:
            style_id = f'{base_id}_{suffix}'
            suffix += 1
        styles[style_id] = style_props
        # The first style extracted becomes the default.
        if result['default_style_id'] is None:
            result['default_style_id'] = style_id
        return style_id

    # Handle different renderer types
    if isinstance(renderer, QgsSingleSymbolRenderer):
        result['renderer_type'] = 'single'
        add_style(f'{layer_name}_style',
                  extract_symbol_properties(renderer.symbol(), geom_type))

    elif isinstance(renderer, QgsCategorizedSymbolRenderer):
        result['renderer_type'] = 'categorized'
        result['category_field'] = renderer.classAttribute()

        for category in renderer.categories():
            cat_key = str(category.value())
            style_id = add_style(
                f'{layer_name}_{sanitize_style_id(cat_key)}',
                extract_symbol_properties(category.symbol(), geom_type))
            result['categories'][cat_key] = style_id

    elif isinstance(renderer, QgsGraduatedSymbolRenderer):
        result['renderer_type'] = 'graduated'
        result['category_field'] = renderer.classAttribute()

        for i, range_item in enumerate(renderer.ranges()):
            lower = range_item.lowerValue()
            upper = range_item.upperValue()
            style_id = add_style(
                f'{layer_name}_range_{i}',
                extract_symbol_properties(range_item.symbol(), geom_type))
            # Store range info so features can be matched by value later.
            result['categories'][f'{lower}_{upper}'] = {
                'style_id': style_id,
                'lower': lower,
                'upper': upper
            }

    elif isinstance(renderer, QgsRuleBasedRenderer):
        result['renderer_type'] = 'rule_based'
        # Only the root's direct child rules that carry a symbol are extracted;
        # nested rules are not visited.
        root_rule = renderer.rootRule()
        if root_rule:
            for i, child_rule in enumerate(root_rule.children()):
                if child_rule.symbol():
                    style_id = add_style(
                        f'{layer_name}_rule_{i}',
                        extract_symbol_properties(child_rule.symbol(), geom_type))
                    result['categories'][child_rule.label() or f'rule_{i}'] = style_id
    else:
        # Fallback: use the first symbol the renderer reports, if any.
        result['renderer_type'] = 'unknown'
        symbols = renderer.symbols(QgsRenderContext())
        if symbols:
            add_style(f'{layer_name}_style',
                      extract_symbol_properties(symbols[0], geom_type))

    # Ensure at least a default style exists
    if not styles:
        add_style(f'{layer_name}_default', get_default_style(geom_type))

    return result


def sanitize_style_id(value: str) -> str:
    """
    Turn an arbitrary value into text usable as a KML style ID / file name.

    Spaces, forward slashes and backslashes become underscores. After that,
    only alphanumeric characters, '_' and '-' are kept; everything else is
    dropped.

    Args:
        value: Raw string value

    Returns:
        The cleaned string; 'empty' for a falsy input and 'unknown' when
        nothing survives the cleaning.
    """
    if not value:
        return 'empty'

    # Separators first, so e.g. 'a b/c' keeps its word boundaries as '_'.
    separated = value.translate(str.maketrans(' /\\', '___'))
    cleaned = ''.join(filter(lambda ch: ch.isalnum() or ch in '_-', separated))

    return cleaned or 'unknown'


def get_feature_style_id(feature, layer_styles: dict, layer: QgsVectorLayer) -> str:
    """
    Get the appropriate style ID for a feature based on renderer type.

    - categorized: str(feature[category_field]) is looked up in 'categories'
    - graduated: float(feature[category_field]) is matched against the stored
      ranges (inclusive bounds, first match in stored order wins)
    - single / rule_based / anything else: the default style
    Any unmatched or unreadable value also yields the default style.

    Args:
        feature: QgsFeature
        layer_styles: Style dictionary from extract_layer_styles()
        layer: QgsVectorLayer (currently unused)

    Returns:
        Style ID string, or None if layer_styles is empty/None
    """
    if not layer_styles:
        return None

    renderer_type = layer_styles.get('renderer_type', 'single')
    default_style_id = layer_styles.get('default_style_id')
    cat_field = layer_styles.get('category_field')
    categories = layer_styles.get('categories', {})

    if renderer_type == 'categorized' and cat_field:
        # Exact, case-sensitive field-name lookup; an expression-based
        # classification (no such field) falls through to the default.
        if feature.fields().indexFromName(cat_field) >= 0:
            field_value = str(feature[cat_field])
            if field_value and field_value in categories:
                return categories[field_value]

    elif renderer_type == 'graduated' and cat_field:
        try:
            field_value = float(feature[cat_field])
            for range_info in categories.values():
                if isinstance(range_info, dict):
                    lower = range_info.get('lower', float('-inf'))
                    upper = range_info.get('upper', float('inf'))
                    # Both bounds inclusive: a value on a shared boundary goes
                    # to the earlier range in insertion order.
                    if lower <= field_value <= upper:
                        return range_info['style_id']
        except (ValueError, TypeError, KeyError):
            # Missing field, NULL or non-numeric value: use the default style.
            pass

    # 'single', 'rule_based' (rules are not evaluated here) and unmatched values.
    return default_style_id


def render_marker_to_png(symbol, size: int, output_path: str) -> bool:
    """
    Render a marker symbol to a PNG file for use as KML icon.

    The symbol is drawn centred on a transparent size x size image. The render
    context's scale factor is 96 / 25.4 (pixels per mm at 96 DPI).

    Args:
        symbol: QgsMarkerSymbol to render
        size: Image size in pixels (square)
        output_path: Path to save PNG file (parent directories are created)

    Returns:
        True if the PNG was written. False if *symbol* is empty or not a
        QgsMarkerSymbol, or if rendering/saving failed (the error is printed).
    """
    if not symbol:
        return False

    # Only marker symbols can be rendered as point icons.
    if not isinstance(symbol, QgsMarkerSymbol):
        return False

    try:
        from qgis.PyQt.QtCore import QPointF

        # Ensure output directory exists
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Create image with transparent background
        image = QImage(QSize(size, size), QImage.Format_ARGB32_Premultiplied)
        image.fill(Qt.transparent)

        painter = QPainter(image)
        try:
            painter.setRenderHint(QPainter.Antialiasing, True)
            painter.setRenderHint(QPainter.SmoothPixmapTransform, True)

            # Create render context
            ms = QgsMapSettings()
            ms.setOutputSize(QSize(size, size))
            context = QgsRenderContext.fromMapSettings(ms)
            context.setPainter(painter)
            # Scale factor for rendering (96 DPI / 25.4 mm per inch)
            context.setScaleFactor(96 / 25.4)

            center = QPointF(size / 2.0, size / 2.0)

            symbol.startRender(context)
            try:
                symbol.renderPoint(center, None, context)
            finally:
                # Pair every successful startRender() with stopRender().
                symbol.stopRender(context)
        finally:
            # End painting on every path; Qt expects a QPainter to be ended
            # before its paint device is used or destroyed.
            painter.end()

        return image.save(output_path, 'PNG')

    except Exception as e:
        # Best-effort: report and signal failure rather than aborting export.
        print(f"Error rendering marker to PNG: {e}")
        import traceback
        traceback.print_exc()
        return False


def _graduated_style_id(layer_styles: dict, index: int) -> str:
    """Return the style ID extract_layer_styles() gave graduated range number *index*."""
    suffix = f'_range_{index}'
    # Match the exact '_range_<index>' suffix. A substring test would let
    # 'range_1' match 'range_10', or a layer name that contains 'range_1'.
    for range_info in layer_styles.get('categories', {}).values():
        if isinstance(range_info, dict):
            style_id = range_info.get('style_id')
            if style_id and style_id.endswith(suffix):
                return style_id
    # Ranges with identical bounds share one 'categories' key, so an entry can
    # be missing; rebuild the ID from the default (first range's) style ID.
    default_id = layer_styles.get('default_style_id') or 'style'
    prefix = default_id.rsplit('_range_', 1)[0]
    return f'{prefix}{suffix}'


def render_all_marker_styles(layer_styles: dict, icons_dir: str, size: int = 32) -> dict:
    """
    Render all marker symbols from layer styles to PNG files.

    Only point layers with a single, categorized or graduated renderer are
    handled; other renderer types produce no icons. Returned paths are always
    'icons/<file>', regardless of the basename of *icons_dir*.

    Args:
        layer_styles: Dictionary from extract_layer_styles()
        icons_dir: Directory to save icon PNG files (created if missing)
        size: Icon size in pixels (default 32x32)

    Returns:
        Dictionary mapping style_id to relative icon path, for every icon that
        rendered successfully
    """
    if not layer_styles or layer_styles.get('geometry_type') != 'Point':
        return {}

    icon_paths = {}
    renderer = layer_styles.get('_renderer')

    if not renderer:
        return {}

    os.makedirs(icons_dir, exist_ok=True)

    def render_icon(style_id, symbol):
        """Render *symbol* to '<sanitized style_id>.png' and record it on success."""
        icon_filename = f"{sanitize_style_id(style_id)}.png"
        icon_path = os.path.join(icons_dir, icon_filename)
        if render_marker_to_png(symbol, size, icon_path):
            icon_paths[style_id] = f"icons/{icon_filename}"

    # Get symbols based on renderer type
    renderer_type = layer_styles.get('renderer_type', 'single')

    if renderer_type == 'single' and isinstance(renderer, QgsSingleSymbolRenderer):
        symbol = renderer.symbol()
        if symbol and isinstance(symbol, QgsMarkerSymbol):
            render_icon(layer_styles.get('default_style_id'), symbol)

    elif renderer_type == 'categorized' and isinstance(renderer, QgsCategorizedSymbolRenderer):
        categories = layer_styles.get('categories', {})
        for category in renderer.categories():
            symbol = category.symbol()
            if symbol and isinstance(symbol, QgsMarkerSymbol):
                style_id = categories.get(str(category.value()))
                if style_id:
                    render_icon(style_id, symbol)

    elif renderer_type == 'graduated' and isinstance(renderer, QgsGraduatedSymbolRenderer):
        for i, range_item in enumerate(renderer.ranges()):
            symbol = range_item.symbol()
            if symbol and isinstance(symbol, QgsMarkerSymbol):
                render_icon(_graduated_style_id(layer_styles, i), symbol)

    return icon_paths

