import os
import numpy as np
import warnings
warnings.filterwarnings('ignore')

from qgis.core import (
    QgsProcessingAlgorithm,
    QgsProcessingParameterRasterLayer,
    QgsProcessingParameterVectorLayer,
    QgsProcessingParameterCrs,
    QgsProcessingParameterFolderDestination,
    QgsProcessingParameterBoolean,
    QgsProcessingException,
    QgsProcessing,
    QgsProject,
    QgsRasterLayer,
    QgsCoordinateReferenceSystem
)

try:
    import rasterio
    from rasterio.mask import mask
    from rasterio.warp import calculate_default_transform, reproject, Resampling
    import geopandas as gpd
    from shapely.geometry import mapping
except ImportError:
    pass

class DemSlopeProcessing(QgsProcessingAlgorithm):
    """Reproject a DEM, crop it to a catchment polygon and derive percent slope.

    Rasters written to the chosen output folder:
      * ``dem_reprojected.tif`` - DEM warped to the target CRS (bilinear).
      * ``dem_cropped.tif``     - reprojected DEM masked/cropped to the catchment.
      * ``slope_percent.tif``   - slope (rise / run * 100) of the cropped DEM.
      * ``slope_cropped.tif``   - slope masked/cropped to the catchment.

    Slope is computed from elevation differences divided by pixel size in the
    target CRS units, so it is only meaningful when the DEM's vertical units
    match those horizontal units (e.g. metres in a UTM zone).
    """

    # Parameter identifiers.
    DEM = 'DEM'
    CATCHMENT = 'CATCHMENT'
    TARGET_CRS = 'TARGET_CRS'
    OUTPUT_FOLDER = 'OUTPUT_FOLDER'
    LOAD_IN_QGIS = 'LOAD_IN_QGIS'

    def initAlgorithm(self, config=None):
        """Declare the algorithm's input and output parameters."""
        self.addParameter(QgsProcessingParameterRasterLayer(
            self.DEM, 'Input DEM Raster'))

        self.addParameter(QgsProcessingParameterVectorLayer(
            self.CATCHMENT, 'Catchment Boundary',
            types=[QgsProcessing.TypeVectorPolygon]))

        self.addParameter(QgsProcessingParameterCrs(
            self.TARGET_CRS, 'Target CRS for Processing',
            defaultValue='EPSG:32643'))  # Default to UTM 43N

        self.addParameter(QgsProcessingParameterBoolean(
            self.LOAD_IN_QGIS, 'Load processed layers in QGIS',
            defaultValue=True))

        self.addParameter(QgsProcessingParameterFolderDestination(
            self.OUTPUT_FOLDER, 'Output Folder'))

    def processAlgorithm(self, parameters, context, feedback):
        """Run the reproject -> crop -> slope -> crop pipeline.

        Returns:
            dict: ``{'OUTPUT_FOLDER': <folder containing the rasters>}``, or an
            empty dict if the user cancels part-way through.

        Raises:
            QgsProcessingException: if a required Python package is missing,
                an input is invalid, or a processing step cannot proceed.
        """
        # The module-level imports are best-effort; this is where a missing
        # package is reported. The helpers below import what they need locally.
        try:
            import rasterio  # noqa: F401
            from rasterio.mask import mask  # noqa: F401
            from rasterio.warp import calculate_default_transform, reproject, Resampling  # noqa: F401
            import geopandas as gpd  # noqa: F401
            from shapely.geometry import mapping  # noqa: F401
        except ImportError as e:
            raise QgsProcessingException(f"Required packages not available: {str(e)}") from e

        dem_layer = self.parameterAsRasterLayer(parameters, self.DEM, context)
        catchment_layer = self.parameterAsVectorLayer(parameters, self.CATCHMENT, context)
        target_crs = self.parameterAsCrs(parameters, self.TARGET_CRS, context)
        load_in_qgis = self.parameterAsBoolean(parameters, self.LOAD_IN_QGIS, context)
        output_folder = self.parameterAsString(parameters, self.OUTPUT_FOLDER, context)

        if dem_layer is None or not dem_layer.isValid():
            raise QgsProcessingException("Invalid input DEM raster layer")
        if catchment_layer is None or not catchment_layer.isValid():
            raise QgsProcessingException("Invalid catchment boundary layer")
        if not target_crs.isValid():
            raise QgsProcessingException("Invalid target CRS")
        if target_crs.isGeographic():
            feedback.pushWarning(
                "Target CRS is geographic (degrees); slope percent expects a "
                "projected CRS whose units match the DEM elevation units")

        # Create output directory
        os.makedirs(output_folder, exist_ok=True)

        dst_crs = self._rasterio_crs(target_crs)
        # Used only if the raster file itself carries no CRS.
        dem_qgis_crs = dem_layer.crs()
        dem_crs_fallback = self._rasterio_crs(dem_qgis_crs) if dem_qgis_crs.isValid() else None

        feedback.pushInfo("Starting DEM and slope processing...")

        # Step 1: Reproject DEM to target CRS
        feedback.pushInfo(f"Step 1: Reprojecting DEM to {target_crs.authid()}...")
        utm_dem_path = os.path.join(output_folder, "dem_reprojected.tif")
        self._reproject_raster(dem_layer.source(), utm_dem_path, dst_crs, dem_crs_fallback)
        feedback.pushInfo("DEM reprojection completed")
        feedback.setProgress(25)
        if feedback.isCanceled():
            return {}

        # Step 2: Crop DEM to catchment
        feedback.pushInfo("Step 2: Cropping DEM to catchment...")
        cropped_dem_path = os.path.join(output_folder, "dem_cropped.tif")
        catchment_geoms = self._catchment_geometries(catchment_layer, dst_crs)
        self._crop_raster(utm_dem_path, cropped_dem_path, catchment_geoms)
        feedback.pushInfo("DEM cropping completed")
        feedback.setProgress(50)
        if feedback.isCanceled():
            return {}

        # Step 3: Calculate slope
        feedback.pushInfo("Step 3: Calculating slope...")
        slope_path = os.path.join(output_folder, "slope_percent.tif")
        self.calculate_slope(cropped_dem_path, slope_path, feedback)
        feedback.setProgress(75)
        if feedback.isCanceled():
            return {}

        # Step 4: Crop slope to catchment
        feedback.pushInfo("Step 4: Cropping slope to catchment...")
        cropped_slope_path = os.path.join(output_folder, "slope_cropped.tif")
        self._crop_raster(slope_path, cropped_slope_path, catchment_geoms)
        feedback.pushInfo("Slope processing completed")
        feedback.setProgress(100)

        if load_in_qgis:
            # Processing may execute processAlgorithm on a background thread,
            # where touching QgsProject.instance() is unsafe. Hand the layers
            # to the framework, which loads them once the algorithm finishes.
            from qgis.core import QgsProcessingContext
            for path, layer_name in ((cropped_dem_path, "Processed DEM"),
                                     (cropped_slope_path, "Slope Percent")):
                if QgsRasterLayer(path, layer_name).isValid():
                    context.addLayerToLoadOnCompletion(
                        path, QgsProcessingContext.LayerDetails(layer_name, context.project()))
                    feedback.pushInfo(f"{layer_name} will be loaded in QGIS on completion")
                else:
                    feedback.pushWarning(f"Failed to load {layer_name}")

        return {'OUTPUT_FOLDER': output_folder}

    def calculate_slope(self, dem_path, output_path, feedback):
        """Calculate slope in percent (rise / run * 100) from band 1 of a DEM.

        Args:
            dem_path: raster whose band 1 holds elevations.
            output_path: destination GeoTIFF; written as a single float32
                band with NaN as nodata.
            feedback: accepted for interface compatibility; not used here.

        Raises:
            QgsProcessingException: if the DEM is smaller than 2 x 2 pixels,
                which is too small for a finite-difference gradient.
        """
        import rasterio

        with rasterio.open(dem_path) as src:
            raw = src.read(1)
            transform = src.transform
            nodata = src.nodata
            profile = src.profile.copy()

        if min(raw.shape) < 2:
            raise QgsProcessingException(
                f"DEM is too small to compute slope ({raw.shape[0]} x {raw.shape[1]} pixels)")

        # Work in float64 so nodata can become NaN even for integer DEMs; the
        # nodata test runs on the raw array so it compares in the raster's dtype.
        dem = raw.astype(np.float64)
        if nodata is not None:
            dem[raw == nodata] = np.nan

        # Only the pixel-size terms (a, e) of the geotransform are used;
        # rotated/sheared rasters are not handled.
        slope_percent = self._slope_percent(dem, abs(transform.a), abs(transform.e))

        # The output holds a single band regardless of the input band count.
        profile.update(driver='GTiff', dtype=rasterio.float32, nodata=np.nan, count=1)
        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(slope_percent.astype(np.float32), 1)

    @staticmethod
    def _slope_percent(dem, xres, yres):
        """Return percent slope of a 2-D elevation array.

        ``np.gradient`` takes spacings and returns derivatives in axis order
        (rows, then columns), so the row spacing is the pixel height ``yres``
        and the column spacing is the pixel width ``xres``. NaN cells
        propagate NaN to their neighbours' slope.
        """
        dz_drow, dz_dcol = np.gradient(dem, yres, xres)
        return np.hypot(dz_dcol, dz_drow) * 100.0

    @staticmethod
    def _rasterio_crs(crs):
        """Return a CRS string rasterio/pyproj can parse for a QGIS CRS.

        EPSG codes are passed through unchanged; anything else (a ``USER:``
        custom CRS, or one with no authority id) is exported as WKT.
        """
        authid = crs.authid()
        if authid.upper().startswith("EPSG:"):
            return authid
        return crs.toWkt()

    @staticmethod
    def _default_nodata(dtype):
        """Pick a nodata value for a raster dtype that has none assigned.

        NaN for floats, the minimum for signed integers and the maximum for
        unsigned integers; ``None`` for any other dtype.
        """
        dt = np.dtype(dtype)
        if np.issubdtype(dt, np.floating):
            return float('nan')
        if np.issubdtype(dt, np.signedinteger):
            return int(np.iinfo(dt).min)
        if np.issubdtype(dt, np.unsignedinteger):
            return int(np.iinfo(dt).max)
        return None

    @staticmethod
    def _split_ogr_source(source):
        """Split a QGIS OGR data-source string into ``(path, layer)``.

        QGIS appends options such as ``|layername=rivers`` or ``|layerid=0``
        to the file path for multi-layer formats (e.g. GeoPackage), which
        geopandas cannot open directly. ``layer`` is the layer name (preferred),
        else the integer layer index, else ``None``.
        """
        path, *options = source.split('|')
        layer = None
        for option in options:
            key, _, value = option.partition('=')
            key = key.strip().lower()
            if key == 'layername':
                layer = value
            elif key == 'layerid' and layer is None and value.strip().isdigit():
                layer = int(value)
        return path, layer

    @classmethod
    def _reproject_raster(cls, src_path, dst_path, dst_crs, fallback_src_crs=None):
        """Warp every band of ``src_path`` into ``dst_crs`` as a GeoTIFF.

        A nodata value is always recorded on the output (a dtype default when
        the source has none), so pixels outside the source footprint are not
        mistaken for valid data.

        Raises:
            QgsProcessingException: if neither the raster nor
                ``fallback_src_crs`` provides a source CRS.
        """
        import rasterio
        from rasterio.warp import calculate_default_transform, reproject, Resampling

        with rasterio.open(src_path) as src:
            src_crs = src.crs or fallback_src_crs
            if not src_crs:
                raise QgsProcessingException("Input DEM has no coordinate reference system")

            transform, width, height = calculate_default_transform(
                src_crs, dst_crs, src.width, src.height, *src.bounds)

            nodata = src.nodata if src.nodata is not None else cls._default_nodata(src.dtypes[0])
            kwargs = src.meta.copy()
            kwargs.update({
                'driver': 'GTiff',  # output is always .tif, whatever the input format
                'crs': dst_crs,
                'transform': transform,
                'width': width,
                'height': height,
                'nodata': nodata,
            })

            with rasterio.open(dst_path, 'w', **kwargs) as dst:
                for band_index in range(1, src.count + 1):
                    reproject(
                        source=rasterio.band(src, band_index),
                        destination=rasterio.band(dst, band_index),
                        src_transform=src.transform,
                        src_crs=src_crs,
                        dst_transform=transform,
                        dst_crs=dst_crs,
                        resampling=Resampling.bilinear)

    @classmethod
    def _crop_raster(cls, src_path, dst_path, geoms):
        """Mask ``src_path`` to ``geoms``, crop to their extent and write a GeoTIFF.

        Pixels outside the geometries are set to the raster's nodata value (a
        dtype default when it has none), and that value is recorded on the
        output.

        Raises:
            QgsProcessingException: if the geometries do not overlap the raster.
        """
        import rasterio
        from rasterio.mask import mask

        with rasterio.open(src_path) as src:
            nodata = src.nodata if src.nodata is not None else cls._default_nodata(src.dtypes[0])
            try:
                out_image, out_transform = mask(src, geoms, crop=True, nodata=nodata)
            except ValueError as err:
                raise QgsProcessingException(
                    f"Could not crop {os.path.basename(src_path)} to catchment: {err}") from err
            out_meta = src.meta.copy()

        out_meta.update({
            "driver": "GTiff",
            "height": out_image.shape[1],
            "width": out_image.shape[2],
            "transform": out_transform,
            "nodata": nodata,
        })

        with rasterio.open(dst_path, "w", **out_meta) as dest:
            dest.write(out_image)

    @classmethod
    def _catchment_geometries(cls, catchment_layer, dst_crs):
        """Read the catchment polygons as GeoJSON-like dicts in ``dst_crs``.

        Raises:
            QgsProcessingException: if the layer's data source cannot be read
                or yields no non-empty geometries.
        """
        import geopandas as gpd
        from shapely.geometry import mapping

        source = catchment_layer.source()
        path, layer = cls._split_ogr_source(source)
        read_kwargs = {} if layer is None else {'layer': layer}
        try:
            catchment_gdf = gpd.read_file(path, **read_kwargs)
        except Exception as err:  # reader/driver error types vary by backend
            raise QgsProcessingException(
                f"Could not read catchment layer '{source}': {err}") from err

        if catchment_gdf.crs is None:
            # The file carries no CRS; fall back to the CRS QGIS assigned to the layer.
            catchment_gdf = catchment_gdf.set_crs(cls._rasterio_crs(catchment_layer.crs()))

        catchment_gdf = catchment_gdf.to_crs(dst_crs)
        geoms = [mapping(geom) for geom in catchment_gdf.geometry
                 if geom is not None and not geom.is_empty]
        if not geoms:
            raise QgsProcessingException("Catchment layer contains no geometries")
        return geoms

    def name(self):
        return 'dem_slope_processing'

    def displayName(self):
        return 'DEM and Slope Processing'

    def group(self):
        return 'Hydrology'

    def groupId(self):
        return 'hydrology'

    def createInstance(self):
        return DemSlopeProcessing()