import os
from contextlib import ExitStack

import numpy as np
import rasterio
from rasterio.warp import reproject, transform_bounds, Resampling
from rasterio.windows import from_bounds

from qgis.core import QgsVectorLayer, QgsSymbol, QgsSingleSymbolRenderer, QgsProject
from qgis.PyQt.QtGui import QColor

def apply_point_layer_style(layer, color, opacity=1.0, size=2.0):
    """
    Style a point vector layer with a single uniform marker symbol.

    Invalid layers and layers whose geometry type is not the point code (0)
    are left untouched; a message is printed instead.

    Parameters
    ----------
    layer : QgsVectorLayer
        Layer whose renderer is replaced.
    color : str or QColor
        Any value accepted by the ``QColor`` constructor.
    opacity : float, optional
        Value passed to the symbol's ``setOpacity`` (default 1.0).
    size : float, optional
        Value passed to the symbol's ``setSize`` (default 2.0).

    Returns
    -------
    None
    """
    if not layer.isValid():
        print("Invalid layer provided to style.")
        return

    kind = layer.geometryType()
    # Geometry type code 0 denotes point layers, the only kind handled here.
    if kind != 0:
        print("This function is currently for point layers only.")
        return

    marker = QgsSymbol.defaultSymbol(kind)
    if not marker:
        print("Failed to create default point symbol.")
        return

    marker.setColor(QColor(color))
    marker.setOpacity(opacity)
    marker.setSize(size)
    layer.setRenderer(QgsSingleSymbolRenderer(marker))
    # Redraw so the new style is visible immediately.
    layer.triggerRepaint()
    print(f"Style applied to point layer '{layer.name()}'")

def generate_common_mask(raster_paths, output_dir, mask_name="common_mask.tif"):
    """
    Generate a common valid-area mask (where all input raster layers have valid values)
    from multiple environmental raster layers.

    Band 1 of every raster is resampled (nearest neighbour) onto one grid.
    That grid covers the intersection of all extents and uses the first
    raster's CRS and pixel size. An output pixel is 1 when every input holds
    a value that is neither its nodata value nor NaN, and 0 otherwise.

    Parameters
    ----------
    raster_paths : list[str]
        List of file paths to environmental raster layers. Only band 1 of
        each raster is considered.
    output_dir : str
        Directory path where the generated mask raster will be saved
        (created if it does not exist).
    mask_name : str, optional
        File name for the output mask raster (default: 'common_mask.tif').

    Returns
    -------
    str
        File path to the generated mask raster (single-band uint8 GeoTIFF).

    Raises
    ------
    ValueError
        If no raster paths are given, or if the rasters' extents do not
        overlap by at least one pixel of the reference resolution.
    """
    if not raster_paths:
        raise ValueError("No raster paths provided.")

    # ExitStack closes every dataset already opened, even when opening a later
    # path or any processing step raises.
    with ExitStack() as stack:
        datasets = [stack.enter_context(rasterio.open(p)) for p in raster_paths]

        # Use the first raster as reference for resolution and CRS
        ref = datasets[0]
        res_x, res_y = ref.res
        crs = ref.crs

        # Express every extent in the reference CRS before intersecting:
        # bounds of rasters in different CRSs are not directly comparable.
        # Each entry is (left, bottom, right, top).
        all_bounds = [
            ds.bounds if ds.crs == crs else transform_bounds(ds.crs, crs, *ds.bounds)
            for ds in datasets
        ]
        left = max(b[0] for b in all_bounds)
        bottom = max(b[1] for b in all_bounds)
        right = min(b[2] for b in all_bounds)
        top = min(b[3] for b in all_bounds)

        # Count the whole pixels that fit inside the intersection. The small
        # tolerance prevents float error (e.g. 2.9999999999999996) from
        # dropping a row/column that actually fits.
        width = int(np.floor((right - left) / res_x + 1e-6))
        height = int(np.floor((top - bottom) / res_y + 1e-6))
        if width <= 0 or height <= 0:
            raise ValueError(
                "Input rasters do not share a common extent of at least one pixel."
            )

        # Define transform for the common extent (north-up, origin at top-left)
        transform = rasterio.transform.from_origin(left, top, res_x, res_y)

        # AND each raster's validity into one mask. This is equivalent to
        # np.all over a stack of masks, without holding N arrays at once.
        valid = np.ones((height, width), dtype=bool)
        for ds in datasets:
            arr = np.full((height, width), np.nan, dtype=np.float32)
            reproject(
                source=rasterio.band(ds, 1),
                destination=arr,
                src_transform=ds.transform,
                src_crs=ds.crs,
                dst_transform=transform,
                dst_crs=crs,
                resampling=Resampling.nearest,
            )
            # NOTE(review): destination pixels no source pixel maps onto may
            # be initialised by the warper (source nodata, or possibly 0 when
            # the source has none). Confirm against rasterio's dst_nodata
            # defaults if inputs use different CRSs.
            invalid = np.isnan(arr)
            if ds.nodata is not None:
                invalid |= arr == ds.nodata
            valid &= ~invalid

    valid_mask = valid.astype(np.uint8)

    # Save the mask raster
    os.makedirs(output_dir, exist_ok=True)
    mask_path = os.path.join(output_dir, mask_name)

    with rasterio.open(
        mask_path,
        "w",
        driver="GTiff",
        height=valid_mask.shape[0],
        width=valid_mask.shape[1],
        count=1,
        dtype="uint8",
        crs=crs,
        transform=transform,
    ) as dst:
        dst.write(valid_mask, 1)

    print(f"Common mask saved to: {mask_path}")
    return mask_path
