import os
import numpy as np
from osgeo import gdal, gdalconst, osr
from qgis.PyQt.QtCore import QCoreApplication
from qgis.PyQt.QtGui import QIcon
from qgis.core import (
    QgsProcessingAlgorithm, QgsProcessingParameterMultipleLayers,
    QgsProcessingParameterEnum, QgsProcessingParameterRasterDestination,
    QgsProcessingParameterNumber, QgsProcessing, QgsProcessingException, 
    QgsMessageLog, Qgis, QgsProcessingParameterDefinition
)

class MosaicAlgorithm(QgsProcessingAlgorithm):
    """
    Mosaic Tool Algorithm for combining multiple raster layers.

    Part of the MAS Raster Processing Tools group. Band 1 of every input is
    warped, chunk by chunk, onto a common grid (union of the input extents,
    finest input resolution, CRS of the first input). Overlapping pixels are
    combined with the selected statistic and written to a single-band GeoTIFF.
    """

    # Parameter names
    INPUT_LAYERS = 'INPUT_LAYERS'
    STAT = 'STAT'
    DATA_TYPE = 'DATA_TYPE'
    RESAMPLING = 'RESAMPLING'
    COMPRESSION = 'COMPRESSION'
    JPEG_QUALITY = 'JPEG_QUALITY'
    OUTPUT = 'OUTPUT'

    # Enum options. initAlgorithm uses defaultValue=0 for every enum, so the
    # first entry of each list is the default and must stay first.
    STAT_METHODS = ['first', 'last', 'min', 'max', 'mean', 'median', 'sum']
    DATA_TYPES = ['Auto', 'Byte', 'UInt16', 'Int16', 'Float32', 'Float64']
    RESAMPLING_METHODS = ['nearest', 'bilinear', 'cubic']
    COMPRESSION_METHODS = ['NONE', 'LZW', 'DEFLATE', 'PACKBITS', 'JPEG']

    # Sentinel that marks "no data" in the Float64 working arrays.
    _INTERNAL_NODATA = -9999.0

    # numpy equivalents of integer GDAL types, used to check whether a nodata
    # value can be stored in the output band. Float types hold -9999 as-is.
    _INTEGER_NUMPY_TYPES = {
        gdal.GDT_Byte: np.uint8,
        gdal.GDT_UInt16: np.uint16,
        gdal.GDT_Int16: np.int16,
        gdal.GDT_UInt32: np.uint32,
        gdal.GDT_Int32: np.int32,
    }

    def __init__(self):
        super().__init__()

    def tr(self, string):
        """Translation support"""
        return QCoreApplication.translate('MosaicAlgorithm', string)

    def createInstance(self):
        """Create new instance of the algorithm"""
        return MosaicAlgorithm()

    def name(self):
        """Algorithm name - used as unique identifier"""
        return 'mosaic_tool'

    def displayName(self):
        """Algorithm display name shown to users"""
        return self.tr('Mosaic Tool')

    def group(self):
        """Group name that appears in Processing Toolbox"""
        return self.tr('MAS Raster Processing Tools')

    def groupId(self):
        """Group identifier"""
        return 'mas_raster_processing_tools'

    def icon(self):
        """Return mosaic.png next to this file, or the default icon if it is missing."""
        icon_path = os.path.join(os.path.dirname(__file__), 'mosaic.png')
        if os.path.exists(icon_path):
            return QIcon(icon_path)
        return QgsProcessingAlgorithm.icon(self)

    def shortHelpString(self):
        """Help text shown in algorithm dialog"""
        return self.tr('''
        <h3>Mosaic Tool</h3>
        <p>Create advanced mosaics from multiple raster layers with spatial alignment, 
        statistical methods for overlapping areas, and compression options.</p>
        <p>Only the first band of each input is used.</p>
        
        <h4>Parameters:</h4>
        <ul>
        <li><b>Input Raster Layers:</b> Select 2 or more raster layers to mosaic</li>
        <li><b>Mosaic Method:</b> How to handle overlapping pixels (first, last, min, max, mean, etc.)</li>
        <li><b>Output Data Type:</b> Data type of output raster</li>
        <li><b>Resampling:</b> Method for spatial alignment (default: nearest)</li>
        <li><b>Compression:</b> Compression method for output file (default: NONE; JPEG requires Byte output)</li>
        <li><b>JPEG Quality:</b> Quality setting for JPEG compression only (1-100)</li>
        </ul>
        ''')

    def initAlgorithm(self, config=None):
        """Declare the algorithm parameters; every enum defaults to its first option."""

        # Input raster layers
        self.addParameter(
            QgsProcessingParameterMultipleLayers(
                self.INPUT_LAYERS,
                self.tr('Input Raster Layers'),
                QgsProcessing.TypeRaster
            )
        )

        # Mosaic method
        self.addParameter(
            QgsProcessingParameterEnum(
                self.STAT,
                self.tr('Mosaic Method'),
                options=self.STAT_METHODS,
                defaultValue=0  # 'first'
            )
        )

        # Output data type
        self.addParameter(
            QgsProcessingParameterEnum(
                self.DATA_TYPE,
                self.tr('Output Data Type'),
                options=self.DATA_TYPES,
                defaultValue=0  # 'Auto'
            )
        )

        # Resampling method
        self.addParameter(
            QgsProcessingParameterEnum(
                self.RESAMPLING,
                self.tr('Resampling Method'),
                options=self.RESAMPLING_METHODS,
                defaultValue=0  # 'nearest'
            )
        )

        # Compression method
        self.addParameter(
            QgsProcessingParameterEnum(
                self.COMPRESSION,
                self.tr('Compression'),
                options=self.COMPRESSION_METHODS,
                defaultValue=0  # 'NONE'
            )
        )

        # JPEG quality: optional because it only applies to JPEG compression.
        jpeg_quality_param = QgsProcessingParameterNumber(
            self.JPEG_QUALITY,
            self.tr('JPEG Quality (1-100, only for JPEG compression)'),
            type=QgsProcessingParameterNumber.Integer,
            minValue=1,
            maxValue=100,
            defaultValue=85,
            optional=True
        )
        # Advanced parameters are hidden by default in the dialog.
        jpeg_quality_param.setFlags(jpeg_quality_param.flags() | QgsProcessingParameterDefinition.FlagAdvanced)
        self.addParameter(jpeg_quality_param)

        # Output raster
        self.addParameter(
            QgsProcessingParameterRasterDestination(
                self.OUTPUT,
                self.tr('Output Mosaic')
            )
        )

    def processAlgorithm(self, parameters, context, feedback):
        """Read the parameters and build the mosaic.

        Returns:
            dict: ``{OUTPUT: output_path}``.

        Raises:
            QgsProcessingException: for fewer than two inputs or any error
                raised while mosaicking. The original exception is chained
                as ``__cause__`` so its traceback is preserved.
        """
        try:
            # Enum parameters arrive as indices into the option lists.
            layers = self.parameterAsLayerList(parameters, self.INPUT_LAYERS, context)
            stat_method = self.STAT_METHODS[self.parameterAsEnum(parameters, self.STAT, context)]
            data_type = self.DATA_TYPES[self.parameterAsEnum(parameters, self.DATA_TYPE, context)]
            resampling = self.RESAMPLING_METHODS[self.parameterAsEnum(parameters, self.RESAMPLING, context)]
            compression = self.COMPRESSION_METHODS[self.parameterAsEnum(parameters, self.COMPRESSION, context)]

            # JPEG quality is optional; fall back to the parameter default (85).
            if self.JPEG_QUALITY in parameters and parameters[self.JPEG_QUALITY] is not None:
                jpeg_quality = self.parameterAsInt(parameters, self.JPEG_QUALITY, context)
            else:
                jpeg_quality = 85

            output_path = self.parameterAsOutputLayer(parameters, self.OUTPUT, context)

            if len(layers) < 2:
                raise QgsProcessingException(self.tr("At least 2 input layers are required"))

            if compression == 'JPEG':
                feedback.pushInfo(self.tr(f"Processing {len(layers)} raster layers with {compression} compression (quality: {jpeg_quality})..."))
            else:
                feedback.pushInfo(self.tr(f"Processing {len(layers)} raster layers with {compression} compression..."))

            # Map layers to data-source strings that gdal.Open understands.
            input_files = [
                layer.source() if hasattr(layer, 'source') else str(layer)
                for layer in layers
            ]

            self.create_mosaic(input_files, output_path, stat_method, data_type, resampling, compression, jpeg_quality, feedback)

            return {self.OUTPUT: output_path}

        except Exception as e:
            QgsMessageLog.logMessage(f"Mosaic error: {e}", 'MAS Geospatial Tools', Qgis.Critical)
            raise QgsProcessingException(f"Processing failed: {e}") from e

    def _get_creation_options(self, compression, jpeg_quality, gdal_data_type=None):
        """Build GTiff creation options for the chosen compression.

        Args:
            compression: One of COMPRESSION_METHODS.
            jpeg_quality: Quality (1-100), used only for JPEG.
            gdal_data_type: Optional output GDAL type. When it is a float type,
                DEFLATE uses the floating-point predictor (3) instead of
                horizontal differencing (2). When None, predictor 2 is used
                as before.

        Returns:
            list[str]: ``KEY=VALUE`` creation options.
        """
        options = ['TILED=YES', 'BIGTIFF=IF_SAFER']
        if compression == 'NONE':
            return options

        options.append(f'COMPRESS={compression}')
        if compression == 'JPEG':
            options.append(f'JPEG_QUALITY={jpeg_quality}')
        elif compression == 'DEFLATE':
            is_float = gdal_data_type in (gdal.GDT_Float32, gdal.GDT_Float64)
            options.append('PREDICTOR=3' if is_float else 'PREDICTOR=2')
        return options

    def create_mosaic(self, input_files, output_path, stat_method, data_type, resampling, compression, jpeg_quality, feedback):
        """Mosaic band 1 of ``input_files`` into a single-band GeoTIFF at ``output_path``.

        Progress is reported as 0-20 % while opening inputs, 20-95 % while
        writing chunks, and 100 % on success. All GDAL datasets are released
        even if an error occurs. When canceled, the partially written output
        is left in place and no success message is logged.

        Raises:
            QgsProcessingException: if there are no inputs, an input cannot be
                opened, the output grid is empty, JPEG is requested for a
                non-Byte output, or the output file cannot be created.
        """
        nodata_value = self._INTERNAL_NODATA
        datasets = []
        out_ds = None
        out_band = None
        try:
            for i, file_path in enumerate(input_files):
                ds = gdal.Open(file_path, gdalconst.GA_ReadOnly)
                if ds is None:
                    raise QgsProcessingException(f"Cannot open: {file_path}")
                datasets.append(ds)
                feedback.setProgress(int((i + 1) / len(input_files) * 20))
            if not datasets:
                raise QgsProcessingException("No input rasters to mosaic")

            feedback.pushInfo(self.tr("Calculating mosaic parameters..."))
            target_extent, target_resolution, target_srs = self._calculate_target_params(datasets)
            target_wkt = target_srs.ExportToWkt()

            gdal_data_type = self._get_gdal_data_type(data_type, datasets[0])
            if compression == 'JPEG' and gdal_data_type != gdal.GDT_Byte:
                # GTiff JPEG compression only supports 8-bit samples here.
                raise QgsProcessingException(
                    "JPEG compression requires the Byte output data type "
                    f"(got {gdal.GetDataTypeName(gdal_data_type)})"
                )

            # Round rather than truncate so float error does not drop a row/column.
            width = int(round((target_extent[2] - target_extent[0]) / target_resolution[0]))
            height = int(round((target_extent[3] - target_extent[1]) / target_resolution[1]))
            if width <= 0 or height <= 0:
                raise QgsProcessingException(f"Invalid output size: {width} x {height}")

            feedback.pushInfo(self.tr(f"Output size: {width} x {height} with {compression} compression"))

            driver = gdal.GetDriverByName('GTiff')
            creation_options = self._get_creation_options(compression, jpeg_quality, gdal_data_type)
            out_ds = driver.Create(
                output_path, width, height, 1, gdal_data_type,
                options=creation_options
            )
            if out_ds is None:
                raise QgsProcessingException(f"Cannot create output file: {output_path}")

            # North-up geotransform anchored at the top-left of the union extent.
            out_ds.SetGeoTransform([
                target_extent[0], target_resolution[0], 0,
                target_extent[3], 0, -target_resolution[1]
            ])
            out_ds.SetProjection(target_wkt)

            # The output band may not be able to store the internal sentinel
            # (e.g. -9999 in Byte), so pick a representable nodata value.
            out_nodata = self._get_output_nodata(gdal_data_type, datasets[0], nodata_value)
            out_band = out_ds.GetRasterBand(1)
            out_band.SetNoDataValue(out_nodata)

            feedback.pushInfo(self.tr("Processing mosaic data..."))
            completed = self._write_chunks(
                out_band, datasets, target_extent, target_resolution, width, height,
                stat_method, resampling, nodata_value, out_nodata, target_wkt, feedback
            )
            out_band.FlushCache()
        finally:
            # Dropping every reference closes the GDAL datasets and flushes the output.
            out_band = None
            out_ds = None
            datasets.clear()

        if not completed:
            feedback.pushInfo(self.tr("Mosaic canceled - output is incomplete."))
            return

        feedback.pushInfo(self.tr("Mosaic completed successfully!"))
        feedback.setProgress(100)

    def _write_chunks(self, out_band, datasets, extent, resolution, width, height,
                      stat_method, resampling, nodata_value, out_nodata, dst_srs_wkt, feedback):
        """Compute the mosaic in 1024x1024 tiles and write them to ``out_band``.

        Returns:
            bool: True if every tile was written, False if the user canceled.
        """
        chunk_size = 1024
        total_chunks = ((height - 1) // chunk_size + 1) * ((width - 1) // chunk_size + 1)
        processed_chunks = 0

        for y in range(0, height, chunk_size):
            chunk_height = min(chunk_size, height - y)
            for x in range(0, width, chunk_size):
                if feedback.isCanceled():
                    return False

                chunk_width = min(chunk_size, width - x)
                # [min_x, min_y, max_x, max_y] of this tile in target coordinates.
                chunk_extent = [
                    extent[0] + x * resolution[0],
                    extent[3] - (y + chunk_height) * resolution[1],
                    extent[0] + (x + chunk_width) * resolution[0],
                    extent[3] - y * resolution[1]
                ]

                chunk_data = self._process_chunk(
                    datasets, chunk_extent, chunk_width, chunk_height,
                    stat_method, nodata_value, resampling, dst_srs_wkt
                )
                if out_nodata != nodata_value:
                    chunk_data[chunk_data == nodata_value] = out_nodata

                out_band.WriteArray(chunk_data, x, y)

                processed_chunks += 1
                feedback.setProgress(20 + int((processed_chunks / total_chunks) * 75))
        return True

    def _calculate_target_params(self, datasets):
        """Compute the common output grid.

        The CRS is taken from the first dataset. Inputs in a different CRS are
        measured in that CRS through an in-memory warped VRT. The extent is the
        union of all input footprints; the resolution is the finest input pixel
        size on each axis.

        Returns:
            tuple: ``([min_x, min_y, max_x, max_y], [res_x, res_y], osr.SpatialReference)``.

        Raises:
            QgsProcessingException: if an input has a zero pixel size.
        """
        target_srs = osr.SpatialReference()
        target_wkt = datasets[0].GetProjection()
        if target_wkt:
            target_srs.ImportFromWkt(target_wkt)

        extents = []
        resolutions = []
        for ds in datasets:
            gt, xsize, ysize = self._geometry_in_target_srs(ds, target_srs, target_wkt)
            # Evaluate all four corners so south-up / rotated transforms still
            # produce a correct bounding box.
            corners = [(px, py) for px in (0, xsize) for py in (0, ysize)]
            xs = [gt[0] + px * gt[1] + py * gt[2] for px, py in corners]
            ys = [gt[3] + px * gt[4] + py * gt[5] for px, py in corners]
            extents.append([min(xs), min(ys), max(xs), max(ys)])
            resolutions.append([abs(gt[1]), abs(gt[5])])

        target_extent = [
            min(ext[0] for ext in extents),
            min(ext[1] for ext in extents),
            max(ext[2] for ext in extents),
            max(ext[3] for ext in extents)
        ]
        target_resolution = [
            min(res[0] for res in resolutions),
            min(res[1] for res in resolutions)
        ]
        if target_resolution[0] <= 0 or target_resolution[1] <= 0:
            raise QgsProcessingException("Input rasters have an invalid (zero) pixel size")

        return target_extent, target_resolution, target_srs

    def _geometry_in_target_srs(self, ds, target_srs, target_wkt):
        """Return ``(geotransform, xsize, ysize)`` of ``ds`` in the target CRS.

        Datasets already in the target CRS, or lacking a CRS on either side,
        are described as-is. Other datasets are described through an in-memory
        warped VRT, so no pixels are computed.
        """
        src_wkt = ds.GetProjection()
        if target_wkt and src_wkt:
            src_srs = osr.SpatialReference()
            src_srs.ImportFromWkt(src_wkt)
            if not src_srs.IsSame(target_srs):
                vrt = gdal.Warp('', ds, format='VRT', dstSRS=target_wkt)
                if vrt is not None:
                    return vrt.GetGeoTransform(), vrt.RasterXSize, vrt.RasterYSize
        return ds.GetGeoTransform(), ds.RasterXSize, ds.RasterYSize

    def _process_chunk(self, datasets, extent, width, height, stat_method, nodata_value, resampling, dst_srs_wkt=None):
        """Warp every dataset onto one tile and combine them.

        Args:
            extent: ``[min_x, min_y, max_x, max_y]`` of the tile.
            width, height: Tile size in pixels.
            dst_srs_wkt: Optional target CRS. When given, each dataset that has
                a CRS is reprojected into it.

        Returns:
            numpy.ndarray: ``(height, width)`` float64 array, with
            ``nodata_value`` where no input has data.
        """
        resample_alg = self._get_gdal_resampling(resampling)
        chunk_arrays = []

        for ds in datasets:
            warp_kwargs = {}
            if dst_srs_wkt and ds.GetProjection():
                warp_kwargs['dstSRS'] = dst_srs_wkt
            # Warp into Float64 so the nodata sentinel is representable even
            # for Byte/UInt16 sources. Otherwise GDAL clamps it to 0 and
            # uncovered pixels look like valid zeros.
            temp_ds = gdal.Warp(
                '', ds,
                format='MEM',
                outputBounds=extent,
                width=width,
                height=height,
                outputType=gdal.GDT_Float64,
                resampleAlg=resample_alg,
                dstNodata=nodata_value,
                **warp_kwargs
            )

            if temp_ds:
                array = temp_ds.GetRasterBand(1).ReadAsArray()
                if array is not None:
                    chunk_arrays.append(array.astype(np.float64))
                temp_ds = None

        if not chunk_arrays:
            return np.full((height, width), nodata_value, dtype=np.float64)

        return self._apply_statistic(chunk_arrays, stat_method, nodata_value)

    def _apply_statistic(self, arrays, method, nodata_value):
        """Combine same-shaped arrays pixel by pixel, ignoring ``nodata_value``.

        Args:
            arrays: List of equally shaped arrays, in input order.
            method: One of STAT_METHODS. 'first'/'last' take the value of the
                first/last array that has data at that pixel.
            nodata_value: Marker for missing data, in both input and output.

        Returns:
            numpy.ndarray: float64 array. Pixels where no array has data are
            set to ``nodata_value``.

        Raises:
            ValueError: for an unknown ``method``.
        """
        if not arrays:
            return np.full((1, 1), nodata_value, dtype=np.float64)

        stack = np.stack(arrays, axis=0)
        valid_mask = stack != nodata_value
        any_valid = valid_mask.any(axis=0)
        result = np.full(stack.shape[1:], nodata_value, dtype=np.float64)

        if method == 'first':
            filled = np.zeros(stack.shape[1:], dtype=bool)
            for layer, valid in zip(stack, valid_mask):
                take = valid & ~filled
                result[take] = layer[take]
                filled |= take
        elif method == 'last':
            # Later arrays overwrite earlier ones, so the last valid value wins.
            for layer, valid in zip(stack, valid_mask):
                result[valid] = layer[valid]
        elif method == 'min':
            result[any_valid] = np.where(valid_mask, stack, np.inf).min(axis=0)[any_valid]
        elif method == 'max':
            result[any_valid] = np.where(valid_mask, stack, -np.inf).max(axis=0)[any_valid]
        elif method == 'mean':
            # Average only the valid contributions. A plain np.mean over
            # NaN-masked data would blank every partially covered pixel.
            counts = valid_mask.sum(axis=0)
            totals = np.where(valid_mask, stack, 0.0).sum(axis=0)
            result[any_valid] = totals[any_valid] / counts[any_valid]
        elif method == 'median':
            # nanmedian is only evaluated on pixels with at least one valid
            # value, which avoids all-NaN-slice warnings.
            if any_valid.any():
                masked = np.where(valid_mask, stack, np.nan)
                result[any_valid] = np.nanmedian(masked[:, any_valid], axis=0)
        elif method == 'sum':
            result[any_valid] = np.where(valid_mask, stack, 0.0).sum(axis=0)[any_valid]
        else:
            raise ValueError(f"Unknown mosaic method: {method}")

        return result

    def _get_output_nodata(self, gdal_data_type, reference_ds, nodata_value):
        """Choose a nodata value that the output band can store exactly.

        For the listed integer types, candidates are tried in this order and
        the first integral, in-range value is used:
        1. the internal sentinel;
        2. the nodata value of band 1 of ``reference_ds``;
        3. the type minimum (e.g. 0 for Byte).

        Float and unlisted types return ``nodata_value`` unchanged.
        """
        np_type = self._INTEGER_NUMPY_TYPES.get(gdal_data_type)
        if np_type is None:
            return nodata_value

        info = np.iinfo(np_type)
        candidates = (nodata_value, reference_ds.GetRasterBand(1).GetNoDataValue())
        for candidate in candidates:
            # NaN fails both comparisons and is skipped.
            if (candidate is not None and info.min <= candidate <= info.max
                    and float(candidate).is_integer()):
                return float(candidate)
        return float(info.min)

    def _get_gdal_data_type(self, data_type, reference_ds):
        """Map a DATA_TYPES name to a GDAL type; 'Auto' uses band 1 of ``reference_ds``."""
        if data_type == 'Auto':
            return reference_ds.GetRasterBand(1).DataType

        type_map = {
            'Byte': gdal.GDT_Byte,
            'UInt16': gdal.GDT_UInt16,
            'Int16': gdal.GDT_Int16,
            'Float32': gdal.GDT_Float32,
            'Float64': gdal.GDT_Float64
        }
        return type_map.get(data_type, gdal.GDT_Float32)

    def _get_gdal_resampling(self, method):
        """Map a RESAMPLING_METHODS name to a GDAL resampling constant.

        Unknown names fall back to nearest neighbour, which matches the UI default.
        """
        method_map = {
            'nearest': gdal.GRA_NearestNeighbour,
            'bilinear': gdal.GRA_Bilinear,
            'cubic': gdal.GRA_Cubic
        }
        return method_map.get(method, gdal.GRA_NearestNeighbour)