import logging
from functools import lru_cache
from typing import Dict, List, Union, Tuple, Optional
from shapely.geometry import shape, mapping
from qgis.core import (  # pylint: disable=import-error
    QgsProject, QgsCoordinateTransform, QgsCoordinateReferenceSystem, QgsPointXY)
from qgis.gui import QgsMapCanvas  # pylint: disable=import-error
from qgis.utils import iface  # pylint: disable=import-error
from .logger import Logger, log

# Module-level logger obtained from the project's Logger wrapper (imported from .logger).
logger = Logger(__name__).get()


def check_crs_mismatch(layer_crs: Union[str, QgsCoordinateReferenceSystem],
                       canvas: Optional[QgsMapCanvas] = None,
                       canvas_crs: Optional[Union[str, QgsCoordinateReferenceSystem]] = None) -> bool:
    """
    Report whether a layer's CRS differs from the map canvas CRS.

    The canvas CRS is read from ``canvas.mapSettings().destinationCrs()`` when a
    canvas is supplied; otherwise ``canvas_crs`` is used (a string is first turned
    into a QgsCoordinateReferenceSystem). A warning is logged for invalid CRSs and
    for mismatches.

    Args:
        layer_crs: CRS of the layer, as a string or QgsCoordinateReferenceSystem.
        canvas: Map canvas whose destination CRS is compared against (optional).
        canvas_crs: CRS to compare against when no canvas is given (optional).

    Returns:
        bool: True if either CRS is invalid or their auth ids differ, else False.

    Raises:
        ValueError: If neither canvas nor canvas_crs is provided (both falsy).
    """
    if not (canvas or canvas_crs):
        raise ValueError("Either canvas or canvas_crs must be provided")

    source_crs = QgsCoordinateReferenceSystem(layer_crs) if isinstance(layer_crs, str) else layer_crs

    # A live canvas takes precedence over an explicitly passed CRS.
    if canvas:
        target_crs = canvas.mapSettings().destinationCrs()
    elif isinstance(canvas_crs, QgsCoordinateReferenceSystem):
        target_crs = canvas_crs
    else:
        target_crs = QgsCoordinateReferenceSystem(canvas_crs)

    # An invalid CRS on either side is reported as a mismatch.
    if not (source_crs.isValid() and target_crs.isValid()):
        logger.warning("Invalid CRS detected")
        return True

    source_id = source_crs.authid()
    target_id = target_crs.authid()
    if source_id == target_id:
        return False

    logger.warning(
        f"CRS mismatch detected!\n"
        f"Layer CRS: {source_crs.description()} ({source_id})\n"
        f"Canvas CRS: {target_crs.description()} ({target_id})\n"
        f"Please reproject your layer to {target_id} for optimal functionality."
    )
    return True


@lru_cache(maxsize=32)
def get_transformer_from_string(from_srs_str: str, to_srs_str: str) -> QgsCoordinateTransform:
    """
    Build a coordinate transform between two CRS strings, memoized per string pair.

    ``lru_cache`` keys on ``(from_srs_str, to_srs_str)``, so repeated calls with the
    same strings return the same QgsCoordinateTransform object (up to 32 pairs).

    Args:
        from_srs_str: Source CRS definition accepted by QgsCoordinateReferenceSystem
            (e.g. an auth id such as ``"EPSG:4326"``).
        to_srs_str: Target CRS definition, same format as ``from_srs_str``.

    Returns:
        QgsCoordinateTransform from the source to the target CRS.

    Raises:
        ValueError: If either string does not yield a valid CRS. Exceptions are not
            cached, so an invalid pair is re-validated on every call.
    """
    from_crs = QgsCoordinateReferenceSystem(from_srs_str)
    to_crs = QgsCoordinateReferenceSystem(to_srs_str)
    if not from_crs.isValid() or not to_crs.isValid():
        raise ValueError("Invalid CRS")
    # Construct the transform here; delegating back to this function would recurse
    # forever. The project instance supplies QGIS's transform context.
    return QgsCoordinateTransform(from_crs, to_crs, QgsProject.instance())

def get_transformer(from_srs, to_srs) -> QgsCoordinateTransform:
    """
    Create a coordinate transform between two CRSs.

    Each argument may independently be a CRS string (anything
    QgsCoordinateReferenceSystem accepts, e.g. ``"EPSG:4326"``) or a
    QgsCoordinateReferenceSystem instance. Instances are used as-is rather than
    being round-tripped through ``authid()``, so custom CRSs without an auth id
    keep working even when mixed with a string argument.

    Args:
        from_srs: Source CRS (string or QgsCoordinateReferenceSystem).
        to_srs: Target CRS (string or QgsCoordinateReferenceSystem).

    Returns:
        A new QgsCoordinateTransform bound to the current QgsProject instance.

    Raises:
        ValueError: If either resolved CRS is invalid.
    """
    from_crs = QgsCoordinateReferenceSystem(from_srs) if isinstance(from_srs, str) else from_srs
    to_crs = QgsCoordinateReferenceSystem(to_srs) if isinstance(to_srs, str) else to_srs
    if not from_crs.isValid() or not to_crs.isValid():
        raise ValueError("Invalid CRS")
    return QgsCoordinateTransform(from_crs, to_crs, QgsProject.instance())

# Convenience wrappers that reproject a single coordinate pair or a single point between two CRSs.
@log(logging.DEBUG, print_args=True, print_return=True)
def transform_coordinate(x, y, from_srs, to_srs):
    """
    Reproject the coordinate pair (x, y) from ``from_srs`` to ``to_srs``.

    The transformer is obtained from get_transformer, so each CRS may be a string
    or a QgsCoordinateReferenceSystem.

    Returns:
        The result of ``QgsCoordinateTransform.transform(x, y)``.
    """
    return get_transformer(from_srs, to_srs).transform(x, y)

@log(logging.DEBUG, print_args=True, print_return=True)
def transform_point(point, from_srs, to_srs):
    """
    Reproject ``point`` from ``from_srs`` to ``to_srs``.

    The transformer is obtained from get_transformer, so each CRS may be a string
    or a QgsCoordinateReferenceSystem.

    Args:
        point: Anything ``QgsCoordinateTransform.transform`` accepts as a single
            argument (presumably a QgsPointXY — confirm against callers).

    Returns:
        The result of ``QgsCoordinateTransform.transform(point)``.
    """
    return get_transformer(from_srs, to_srs).transform(point)

@log(logging.DEBUG, print_args=True, print_return=True)
def transform_json_feature_collection(fc: Dict, from_srs: Union[str, QgsCoordinateReferenceSystem],
                                      to_srs: Union[str, QgsCoordinateReferenceSystem]) -> Dict:
    """
    Reproject a GeoJSON-like FeatureCollection dict from ``from_srs`` to ``to_srs``.

    The collection is modified in place and also returned. The following are
    transformed: feature geometries (including members of GeometryCollections),
    per-feature ``bbox`` members, and the collection-level ``bbox``. Only X/Y are
    reprojected; extra ordinates in a position (e.g. Z or M) are carried over
    unchanged.

    Args:
        fc: FeatureCollection dict; features are read from ``fc['features']`` when present.
        from_srs: Source CRS as a string or QgsCoordinateReferenceSystem.
        to_srs: Target CRS as a string or QgsCoordinateReferenceSystem.

    Returns:
        ``fc`` itself. It is returned untouched in three cases: either CRS is
        invalid, both CRSs are the same, or no transformer can be built.

    Notes:
        A bbox that cannot be reprojected is removed and a warning is logged.
        Errors raised while transforming geometry positions are not caught.
    """
    from_crs = QgsCoordinateReferenceSystem(from_srs) if isinstance(from_srs, str) else from_srs
    to_crs = QgsCoordinateReferenceSystem(to_srs) if isinstance(to_srs, str) else to_srs

    # Early returns for invalid or identical CRS
    if not from_crs.isValid() or not to_crs.isValid():
        logger.warning("Invalid CRS provided, returning original feature collection")
        return fc

    # Custom CRSs may have an empty authid(); comparing auth ids alone would treat
    # any two of them as identical and silently skip the transformation.
    from_id = from_crs.authid()
    if (from_id and from_id == to_crs.authid()) or from_crs == to_crs:
        return fc

    try:
        # Pass the already-resolved CRS objects so they are used as-is.
        transformer = get_transformer(from_crs, to_crs)
    except ValueError as exc:
        logger.warning("CRS transformation error: %s", exc)
        return fc

    # Nesting depth of the position arrays under 'coordinates' per GeoJSON type.
    position_depth = {
        'Point': 0,
        'LineString': 1,
        'MultiPoint': 1,
        'Polygon': 2,
        'MultiLineString': 2,
        'MultiPolygon': 3,
    }

    def transform_coords(coords: List[float]) -> List[float]:
        """Reproject one position's X/Y, keeping any extra ordinates (Z, M)."""
        if len(coords) < 2:
            # Empty positions (e.g. an empty Point) have nothing to transform.
            return list(coords)
        point = transformer.transform(coords[0], coords[1])
        return [point.x(), point.y(), *coords[2:]]

    def transform_positions(coords, depth: int):
        """Apply transform_coords to every position nested ``depth`` lists deep."""
        if depth == 0:
            return transform_coords(coords)
        return [transform_positions(item, depth - 1) for item in coords]

    def transform_geometry(geometry):
        """Transform a geometry dict in place and return it; other values pass through."""
        if not isinstance(geometry, dict):
            return geometry
        geom_type = geometry.get('type')
        if geom_type == 'GeometryCollection':
            members = geometry.get('geometries')
            if isinstance(members, list):
                geometry['geometries'] = [transform_geometry(member) for member in members]
        elif geom_type in position_depth and geometry.get('coordinates') is not None:
            geometry['coordinates'] = transform_positions(geometry['coordinates'],
                                                          position_depth[geom_type])
        return geometry

    def transform_bbox(bbox) -> List[float]:
        """
        Reproject a 2D (4-value) or 3D (6-value) GeoJSON bbox.

        Transforming only the min/max corners can produce a box that does not
        contain the data when the projection rotates or skews the axes, so all
        four X/Y corners are transformed and their envelope is returned (still
        an approximation for strongly curved edges). Z bounds are copied as-is.
        """
        if len(bbox) not in (4, 6):
            raise ValueError(f"unsupported bbox length {len(bbox)}")
        half = len(bbox) // 2
        min_x, min_y, max_x, max_y = bbox[0], bbox[1], bbox[half], bbox[half + 1]
        corners = [transform_coords([x, y]) for x in (min_x, max_x) for y in (min_y, max_y)]
        xs = [corner[0] for corner in corners]
        ys = [corner[1] for corner in corners]
        return [min(xs), min(ys), *bbox[2:half], max(xs), max(ys), *bbox[half + 2:]]

    def update_bbox(container: Dict, failure_msg: str) -> None:
        """Reproject container['bbox'] if present; drop it with a warning on failure."""
        if 'bbox' not in container:
            return
        try:
            container['bbox'] = transform_bbox(container['bbox'])
        except Exception as exc:  # pylint: disable=broad-except
            # Best effort: a stale bbox in the wrong CRS is worse than none.
            logger.warning(failure_msg, exc)
            container.pop('bbox', None)

    # Transform features in place
    for feature in fc.get('features') or []:
        if not isinstance(feature, dict):
            continue
        if 'geometry' in feature:
            feature['geometry'] = transform_geometry(feature['geometry'])
        update_bbox(feature, "Failed to transform bbox: %s")

    update_bbox(fc, "Failed to transform dataset bbox: %s")
    return fc

# Geometry simplification helpers (shapely-based).


def _simplify_feature(feature, tolerance):
    """Transforms a feature"""
    geom = feature['geometry']
    print(geom)
    geom = shape(geom)
    simplified_geom = geom.simplify(tolerance, preserve_topology=False)
    print(mapping(simplified_geom))
    feature['geometry'] = mapping(simplified_geom)
    return feature


@log(logging.DEBUG, print_args=True, print_return=True)
def simplify_json_feature_collection(fc, tolerance):
    """
    Simplify every feature geometry in a GeoJSON-like FeatureCollection dict.

    Each feature is passed through ``_simplify_feature``, which modifies it in
    place. ``fc['features']`` is then replaced with the resulting list.

    Args:
        fc: FeatureCollection dict.
        tolerance: Simplification tolerance; a falsy value (e.g. None or 0)
            disables simplification.

    Returns:
        ``fc`` itself. It is returned unchanged when ``tolerance`` is falsy or the
        collection has no features (missing, None, or empty).
    """
    if not tolerance:
        return fc
    features = fc.get('features')
    if not features:
        return fc
    fc['features'] = [_simplify_feature(feature, tolerance) for feature in features]
    return fc