import os
import numpy as np
from osgeo import gdal, gdalconst, osr
from qgis.PyQt import QtCore, QtWidgets, uic
from qgis.PyQt.QtWidgets import QFileDialog, QMessageBox, QListWidgetItem, QProgressDialog
from qgis.PyQt.QtCore import Qt, QThread, pyqtSignal
from qgis.core import QgsProject, QgsRasterLayer, QgsMessageLog, Qgis

# Load the Qt Designer form stored next to this module at import time.
# FORM_CLASS is the generated UI class mixed into MosaicDialog below; the second
# value returned by loadUiType (the Qt base class) is not used.
FORM_CLASS, _ = uic.loadUiType(os.path.join(os.path.dirname(__file__), 'mosaic_dialog.ui'))

class MosaicWorker(QThread):
    """Worker thread that builds a single-band GeoTIFF mosaic from several rasters.

    Band 1 of every input is warped onto a common grid (the union of the input
    extents, at the finest input pixel size, in the first input's CRS). The
    overlapping pixels are then combined with the selected statistic.

    NOTE(review): inputs are assumed to share the first input's CRS and to be
    north-up (rotation terms of the geotransform are ignored); extents of inputs
    in a different CRS are not reprojected.

    Signals:
        progress (int): completion percentage, 0-100.
        message (str): status text for the UI.
        finished (bool, str): success flag and final message. This shadows
            ``QThread.finished``; the name is kept because callers connect to it.
    """
    progress = pyqtSignal(int)
    message = pyqtSignal(str)
    finished = pyqtSignal(bool, str)

    # Sentinel marking "no data" inside the intermediate float64 chunk arrays.
    _WORK_NODATA = -9999.0

    def __init__(self, input_files, output_path, stat_method, data_type, resampling, compression, jpeg_quality):
        """Store the job parameters; no work is done until ``start()``.

        Args:
            input_files: GDAL-openable sources of the rasters to merge.
            output_path: destination GeoTIFF path.
            stat_method: 'first', 'last', 'min', 'max', 'mean', 'median' or 'sum'.
            data_type: 'Auto' (band-1 type of the first input) or a type name
                understood by ``_get_gdal_data_type``.
            resampling: 'nearest', 'bilinear' or 'cubic'.
            compression: 'NONE' or a GTiff COMPRESS value such as 'LZW' or 'JPEG'.
            jpeg_quality: JPEG_QUALITY used when ``compression == 'JPEG'``.
        """
        super().__init__()
        self.input_files = input_files
        self.output_path = output_path
        self.stat_method = stat_method
        self.data_type = data_type
        self.resampling = resampling
        self.compression = compression
        self.jpeg_quality = jpeg_quality
        # Set from the GUI thread via cancel(); polled by the worker loop.
        self.is_canceled = False

    def cancel(self):
        """Request cancellation; processing stops at the next check point."""
        self.is_canceled = True

    def run(self):
        """Thread entry point: build the mosaic and emit ``finished``."""
        try:
            completed = self.create_mosaic()
        except Exception as e:
            self.finished.emit(False, str(e))
            return
        if completed:
            self.finished.emit(True, "Mosaic created successfully!")
        else:
            # A cancelled run must not be reported as a success.
            self.finished.emit(False, "Mosaic canceled by user.")

    def create_mosaic(self):
        """Write the mosaic to ``self.output_path``.

        Returns:
            bool: True when the mosaic was fully written, False when the run was
            cancelled. After a cancel or an error, any partially written output
            file is removed (best effort).

        Raises:
            Exception: if there are no inputs, an input cannot be opened, the
                target grid is empty, or the output file cannot be created.
        """
        if not self.input_files:
            raise Exception("No input files to mosaic.")

        driver = gdal.GetDriverByName('GTiff')
        datasets = []
        out_ds = None
        out_band = None
        output_created = False
        completed = False
        try:
            # Open and validate datasets (first 20 % of the progress bar).
            for i, file_path in enumerate(self.input_files):
                if self.is_canceled:
                    return False
                ds = gdal.Open(file_path, gdalconst.GA_ReadOnly)
                if ds is None:
                    raise Exception(f"Cannot open: {file_path}")
                datasets.append(ds)
                self.progress.emit(int((i + 1) / len(self.input_files) * 20))

            self.message.emit("Calculating mosaic parameters...")

            target_extent, target_resolution, target_srs = self._calculate_target_params(datasets)
            gdal_data_type = self._get_gdal_data_type(self.data_type, datasets[0])

            width = self._grid_size(target_extent[2] - target_extent[0], target_resolution[0])
            height = self._grid_size(target_extent[3] - target_extent[1], target_resolution[1])
            if width <= 0 or height <= 0:
                raise Exception(f"Invalid output size: {width} x {height}")

            self.message.emit(f"Creating output: {width} x {height} with {self.compression} compression")

            out_ds = driver.Create(
                self.output_path, width, height, 1, gdal_data_type,
                options=self._get_creation_options()
            )
            if out_ds is None:
                raise Exception(f"Cannot create output file: {self.output_path}")
            output_created = True

            # North-up geotransform anchored at the top-left corner of the union extent.
            out_ds.SetGeoTransform([
                target_extent[0], target_resolution[0], 0,
                target_extent[3], 0, -target_resolution[1]
            ])
            out_ds.SetProjection(target_srs.ExportToWkt())

            # Chunks are computed with an internal float sentinel; the file gets a
            # nodata value that its data type can actually store.
            nodata_value = self._WORK_NODATA
            out_nodata = self._get_output_nodata(gdal_data_type)
            out_band = out_ds.GetRasterBand(1)
            out_band.SetNoDataValue(out_nodata)

            self.message.emit("Processing mosaic data...")

            chunk_size = 1024
            total_chunks = ((height - 1) // chunk_size + 1) * ((width - 1) // chunk_size + 1)
            processed_chunks = 0

            for y in range(0, height, chunk_size):
                for x in range(0, width, chunk_size):
                    if self.is_canceled:
                        return False

                    chunk_width = min(chunk_size, width - x)
                    chunk_height = min(chunk_size, height - y)

                    # [min_x, min_y, max_x, max_y] of this chunk in target coordinates.
                    chunk_extent = [
                        target_extent[0] + x * target_resolution[0],
                        target_extent[3] - (y + chunk_height) * target_resolution[1],
                        target_extent[0] + (x + chunk_width) * target_resolution[0],
                        target_extent[3] - y * target_resolution[1]
                    ]

                    chunk_data = self._process_chunk(
                        datasets, chunk_extent, chunk_width, chunk_height,
                        self.stat_method, nodata_value, self.resampling
                    )
                    if out_nodata != nodata_value:
                        chunk_data = np.where(chunk_data == nodata_value, out_nodata, chunk_data)

                    out_band.WriteArray(chunk_data, x, y)

                    processed_chunks += 1
                    self.progress.emit(20 + int(processed_chunks / total_chunks * 75))

            out_band.FlushCache()
            completed = True
        finally:
            # Drop every reference (the band included) so GDAL closes the files.
            out_band = None
            out_ds = None
            datasets.clear()
            if output_created and not completed:
                self._delete_partial_output(driver)

        self.progress.emit(100)
        return True

    def _delete_partial_output(self, driver):
        """Best-effort removal of an incomplete output file after cancel or failure."""
        try:
            failed = driver.Delete(self.output_path) != gdal.CE_None
        except RuntimeError:
            # Raised instead of an error code when GDAL exceptions are enabled.
            failed = True
        if failed:
            self.message.emit(f"Could not remove incomplete output: {self.output_path}")

    @staticmethod
    def _grid_size(span, resolution):
        """Return how many pixels of size *resolution* are needed to cover *span*.

        Rounds up so a trailing partial pixel is kept; the small tolerance keeps
        floating-point noise (e.g. 1000.0000000001) from adding an extra pixel.

        Raises:
            Exception: if *resolution* is not positive.
        """
        if resolution <= 0:
            raise Exception(f"Invalid pixel size: {resolution}")
        return int(np.ceil(span / resolution - 1e-6))

    @staticmethod
    def _get_output_nodata(gdal_data_type):
        """Return a nodata value representable in *gdal_data_type*.

        -9999 cannot be stored in unsigned integer bands, so those use 0.
        """
        if gdal_data_type in (gdal.GDT_Byte, gdal.GDT_UInt16, gdal.GDT_UInt32):
            return 0.0
        return -9999.0

    def _get_creation_options(self):
        """Return GTiff creation options (tiling, BigTIFF, compression).

        NOTE(review): GTiff JPEG compression generally requires Byte data; other
        types make ``driver.Create`` fail.
        """
        options = ['TILED=YES', 'BIGTIFF=IF_SAFER']

        if self.compression != 'NONE':
            options.append(f'COMPRESS={self.compression}')

            if self.compression == 'JPEG':
                options.append(f'JPEG_QUALITY={self.jpeg_quality}')
            elif self.compression == 'DEFLATE':
                options.append('PREDICTOR=2')

        return options

    def _calculate_target_params(self, datasets):
        """Compute the union extent, the finest resolution and the reference SRS.

        Returns:
            tuple: ``([min_x, min_y, max_x, max_y], [x_res, y_res], srs)`` where the
            SRS is taken from the first dataset and resolutions are positive.
        """
        target_srs = osr.SpatialReference()
        target_srs.ImportFromWkt(datasets[0].GetProjection())

        extents = []
        resolutions = []

        for ds in datasets:
            gt = ds.GetGeoTransform()
            x0 = gt[0]
            y0 = gt[3]
            x1 = x0 + ds.RasterXSize * gt[1]
            y1 = y0 + ds.RasterYSize * gt[5]

            # min/max keeps the extent ordered for south-up (gt[5] > 0) rasters too.
            extents.append([min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)])
            resolutions.append([abs(gt[1]), abs(gt[5])])

        target_extent = [
            min(ext[0] for ext in extents),
            min(ext[1] for ext in extents),
            max(ext[2] for ext in extents),
            max(ext[3] for ext in extents)
        ]

        target_resolution = [
            min(res[0] for res in resolutions),
            min(res[1] for res in resolutions)
        ]

        return target_extent, target_resolution, target_srs

    def _process_chunk(self, datasets, extent, width, height, stat_method, nodata_value, resampling):
        """Warp every input onto one output chunk and combine them.

        Args:
            extent: chunk bounds ``[min_x, min_y, max_x, max_y]`` in target coordinates.
            width, height: chunk size in pixels.

        Returns:
            float64 array of shape (height, width); pixels with no data hold
            *nodata_value*.
        """
        resample_alg = self._get_gdal_resampling(resampling)
        chunk_arrays = []

        for ds in datasets:
            temp_ds = gdal.Warp(
                '', ds,
                format='MEM',
                outputBounds=extent,
                width=width,
                height=height,
                resampleAlg=resample_alg,
                # Warp to Float64 so the nodata sentinel is representable even for
                # Byte/UInt16 sources; otherwise uncovered areas would be stored as
                # an in-range (and therefore "valid") value.
                outputType=gdal.GDT_Float64,
                dstNodata=nodata_value
            )
            if temp_ds is None:
                continue

            array = temp_ds.GetRasterBand(1).ReadAsArray()
            temp_ds = None
            if array is not None:
                chunk_arrays.append(array.astype(np.float64))

        if not chunk_arrays:
            return np.full((height, width), nodata_value, dtype=np.float64)

        return self._apply_statistic(chunk_arrays, stat_method, nodata_value)

    def _apply_statistic(self, arrays, method, nodata_value):
        """Combine same-shaped 2-D arrays per pixel, ignoring nodata and NaN.

        Args:
            arrays: list of equally shaped float arrays, in input order.
            method: 'first' (earliest valid input wins), 'last' (latest valid
                input wins), 'min', 'max', 'mean', 'median' or 'sum'.
            nodata_value: value marking missing data in the inputs and output.

        Returns:
            float64 array; pixels with no valid input hold *nodata_value*.

        Raises:
            ValueError: for an unknown *method*.
        """
        if not arrays:
            return np.full((1, 1), nodata_value, dtype=np.float64)

        stack = np.stack(arrays, axis=0)
        valid_mask = (stack != nodata_value) & ~np.isnan(stack)
        any_valid = np.any(valid_mask, axis=0)
        result = np.full(stack.shape[1:], nodata_value, dtype=np.float64)

        if method == 'first':
            for layer, mask in zip(stack, valid_mask):
                take = mask & (result == nodata_value)
                result[take] = layer[take]
        elif method == 'last':
            # Forward order: each later valid input overwrites earlier ones.
            for layer, mask in zip(stack, valid_mask):
                result[mask] = layer[mask]
        elif method == 'min':
            lowest = np.min(np.where(valid_mask, stack, np.inf), axis=0)
            result[any_valid] = lowest[any_valid]
        elif method == 'max':
            highest = np.max(np.where(valid_mask, stack, -np.inf), axis=0)
            result[any_valid] = highest[any_valid]
        elif method == 'mean':
            # Average only the inputs that have data at each pixel.
            totals = np.where(valid_mask, stack, 0.0).sum(axis=0)
            counts = valid_mask.sum(axis=0)
            result[any_valid] = totals[any_valid] / counts[any_valid]
        elif method == 'median':
            if np.any(any_valid):
                masked = np.where(valid_mask, stack, np.nan)
                # Only pixels with at least one valid value, so nanmedian never
                # sees an all-NaN column.
                result[any_valid] = np.nanmedian(masked[:, any_valid], axis=0)
        elif method == 'sum':
            totals = np.where(valid_mask, stack, 0.0).sum(axis=0)
            result[any_valid] = totals[any_valid]
        else:
            raise ValueError(f"Unknown mosaic method: {method}")

        return result

    def _get_gdal_data_type(self, data_type, reference_ds):
        """Map a type name to a GDAL type; 'Auto' uses band 1 of *reference_ds*.

        Unknown names fall back to Float32.
        """
        if data_type == 'Auto':
            return reference_ds.GetRasterBand(1).DataType

        type_map = {
            'Byte': gdal.GDT_Byte,
            'UInt16': gdal.GDT_UInt16,
            'Int16': gdal.GDT_Int16,
            'Float32': gdal.GDT_Float32,
            'Float64': gdal.GDT_Float64
        }
        return type_map.get(data_type, gdal.GDT_Float32)

    def _get_gdal_resampling(self, method):
        """Map a resampling name to a GDAL constant; unknown names use bilinear."""
        method_map = {
            'nearest': gdal.GRA_NearestNeighbour,
            'bilinear': gdal.GRA_Bilinear,
            'cubic': gdal.GRA_Cubic
        }
        return method_map.get(method, gdal.GRA_Bilinear)


class MosaicDialog(QtWidgets.QDialog, FORM_CLASS):
    """Dialog for choosing input rasters and options, then running MosaicWorker.

    Widgets such as ``raster_list`` and ``compression_combo`` are created from
    ``mosaic_dialog.ui`` by ``setupUi``.
    """

    def __init__(self, iface, parent=None):
        """Build the UI, wire up signals and list the project's raster layers."""
        super(MosaicDialog, self).__init__(parent)
        self.setupUi(self)
        self.iface = iface
        self.worker = None

        # Keep dialog on top of QGIS application only
        self.setWindowFlags(self.windowFlags() | Qt.Tool)

        # Connect signals
        self.open_bands_btn.clicked.connect(self.open_bands)
        self.select_all_btn.clicked.connect(self.select_all)
        self.unselect_all_btn.clicked.connect(self.unselect_all)
        self.remove_btn.clicked.connect(self.remove_selected)
        self.output_browse_btn.clicked.connect(self.browse_output)
        self.mosaic_btn.clicked.connect(self.run_mosaic)
        self.cancel_btn.clicked.connect(self.reject)

        # Connect compression combo to enable/disable JPEG quality
        self.compression_combo.currentTextChanged.connect(self.on_compression_changed)

        # Load current raster layers
        self.load_project_rasters()

        # Default selections: the first entry of every combo box
        self.mosaic_method_combo.setCurrentIndex(0)  # first
        self.data_type_combo.setCurrentIndex(0)      # Auto
        self.resampling_combo.setCurrentIndex(0)     # nearest
        self.compression_combo.setCurrentIndex(0)    # NONE

        # Initialize JPEG quality state
        self.on_compression_changed()

    def on_compression_changed(self):
        """Enable the JPEG quality controls only when JPEG compression is selected."""
        is_jpeg = self.compression_combo.currentText() == 'JPEG'

        self.jpeg_quality_spin.setEnabled(is_jpeg)

        # The label is optional in the .ui file; style it only if present.
        if hasattr(self, 'quality_label'):
            if is_jpeg:
                self.quality_label.setStyleSheet("color: black; font-weight: bold;")
            else:
                self.quality_label.setStyleSheet("color: gray;")

    def load_project_rasters(self):
        """Fill the list with the valid raster layers of the current project (unchecked)."""
        self.raster_list.clear()
        project = QgsProject.instance()

        for layer in project.mapLayers().values():
            if isinstance(layer, QgsRasterLayer) and layer.isValid():
                item = QListWidgetItem(layer.name())
                item.setCheckState(Qt.Unchecked)
                item.setData(Qt.UserRole, layer.source())
                self.raster_list.addItem(item)

    def open_bands(self):
        """Let the user pick raster files and add those not already listed."""
        files, _ = QFileDialog.getOpenFileNames(
            self,
            "Select Raster Files",
            "",
            "Raster files (*.tif *.tiff *.img *.jpg *.png *.bmp);;All files (*.*)"
        )

        known = {
            self.raster_list.item(i).data(Qt.UserRole)
            for i in range(self.raster_list.count())
        }
        for file_path in files:
            if file_path in known:
                continue
            item = QListWidgetItem(os.path.basename(file_path))
            item.setCheckState(Qt.Unchecked)
            item.setData(Qt.UserRole, file_path)
            self.raster_list.addItem(item)
            known.add(file_path)

    def _checked_rows(self):
        """Return the row indices of all checked list items, in ascending order."""
        return [
            i for i in range(self.raster_list.count())
            if self.raster_list.item(i).checkState() == Qt.Checked
        ]

    def select_all(self):
        """Check every item in the list."""
        for i in range(self.raster_list.count()):
            self.raster_list.item(i).setCheckState(Qt.Checked)

    def unselect_all(self):
        """Uncheck every item in the list."""
        for i in range(self.raster_list.count()):
            self.raster_list.item(i).setCheckState(Qt.Unchecked)

    def remove_selected(self):
        """Remove all checked items from the list."""
        # Remove from the bottom up so the remaining row indices stay valid.
        for i in reversed(self._checked_rows()):
            self.raster_list.takeItem(i)

    def browse_output(self):
        """Ask for the output path, ensuring it has a GeoTIFF extension."""
        file_path, _ = QFileDialog.getSaveFileName(
            self,
            "Save Mosaic As",
            "",
            "GeoTIFF (*.tif);;All files (*.*)"
        )

        if file_path:
            if not file_path.lower().endswith(('.tif', '.tiff')):
                file_path += '.tif'
            self.output_path_edit.setText(file_path)

    def run_mosaic(self):
        """Validate the form, then start a MosaicWorker with a progress dialog."""
        if self.worker is not None and self.worker.isRunning():
            QMessageBox.warning(self, "Mosaic Tool", "A mosaic is already being processed.")
            return

        selected_files = [self.raster_list.item(i).data(Qt.UserRole) for i in self._checked_rows()]

        if len(selected_files) < 2:
            QMessageBox.warning(self, "Mosaic Tool", "Please select at least 2 raster files.")
            return

        output_path = self.output_path_edit.text().strip()
        if not output_path:
            QMessageBox.warning(self, "Mosaic Tool", "Please specify an output file path.")
            return

        # Writing over an input while it is being read would corrupt the result.
        output_key = os.path.normcase(os.path.abspath(output_path))
        if any(os.path.normcase(os.path.abspath(path)) == output_key for path in selected_files):
            QMessageBox.warning(self, "Mosaic Tool", "The output file must not be one of the input rasters.")
            return

        # These lists must match the item order of the combo boxes in the .ui file.
        stat_methods = ['first', 'last', 'min', 'max', 'mean', 'median', 'sum']
        data_types = ['Auto', 'Byte', 'UInt16', 'Int16', 'Float32', 'Float64']
        resampling_methods = ['nearest', 'bilinear', 'cubic']

        stat_method = stat_methods[self.mosaic_method_combo.currentIndex()]
        data_type = data_types[self.data_type_combo.currentIndex()]
        resampling = resampling_methods[self.resampling_combo.currentIndex()]

        # Compression is read from the combo text (e.g. 'NONE', 'LZW', 'JPEG').
        compression = self.compression_combo.currentText()
        jpeg_quality = self.jpeg_quality_spin.value()

        progress_dlg = None
        try:
            progress_dlg = QProgressDialog("Processing mosaic...", "Cancel", 0, 100, self)
            progress_dlg.setWindowModality(Qt.WindowModal)
            progress_dlg.show()

            QgsMessageLog.logMessage(
                f"Starting mosaic with: Method={stat_method}, DataType={data_type}, "
                f"Resampling={resampling}, Compression={compression}, Quality={jpeg_quality}",
                'MAS Geospatial Tools', Qgis.Info
            )

            self.worker = MosaicWorker(
                selected_files, output_path, stat_method, data_type,
                resampling, compression, jpeg_quality
            )

            self.worker.progress.connect(progress_dlg.setValue)
            self.worker.message.connect(progress_dlg.setLabelText)
            self.worker.finished.connect(lambda success, msg: self.on_finished(success, msg, progress_dlg))
            progress_dlg.canceled.connect(self.worker.cancel)

            self.worker.start()

        except Exception as e:
            self.worker = None
            if progress_dlg is not None:
                progress_dlg.close()
            QMessageBox.critical(self, "Mosaic Tool", f"Error starting mosaic: {str(e)}")
            QgsMessageLog.logMessage(f"Mosaic start error: {str(e)}", 'MAS Geospatial Tools', Qgis.Critical)

    def on_finished(self, success, message, progress_dlg):
        """Handle the worker's ``finished(success, message)`` signal.

        Closes the progress dialog, optionally loads the result as a layer, and
        reports the outcome. Cancelled runs are only logged.
        """
        progress_dlg.close()
        progress_dlg.deleteLater()

        worker = self.worker
        self.worker = None
        if worker is not None:
            # The custom finished signal is emitted from inside run(); wait for
            # run() to return so the QThread is not destroyed while running.
            worker.wait()

        if worker is not None and worker.is_canceled:
            QgsMessageLog.logMessage("Mosaic processing canceled by user.", 'MAS Geospatial Tools', Qgis.Info)
            return

        if success:
            if self.open_output_check.isChecked():
                output_path = worker.output_path if worker is not None else self.output_path_edit.text().strip()
                layer = QgsRasterLayer(output_path, "Mosaic Result")
                if layer.isValid():
                    QgsProject.instance().addMapLayer(layer)
                    QgsMessageLog.logMessage(f"Mosaic result loaded: {output_path}", 'MAS Geospatial Tools', Qgis.Info)
                else:
                    QgsMessageLog.logMessage(f"Failed to load result: {output_path}", 'MAS Geospatial Tools', Qgis.Warning)

            QMessageBox.information(self, "Mosaic Tool", message)
            self.accept()
        else:
            QMessageBox.critical(self, "Mosaic Tool", f"Error: {message}")
            QgsMessageLog.logMessage(f"Mosaic processing error: {message}", 'MAS Geospatial Tools', Qgis.Critical)

    def _confirm_stop_worker(self):
        """Return True if the dialog may close.

        If a mosaic is running, the user is asked first; on confirmation the
        worker is cancelled and waited for.
        """
        if not (self.worker and self.worker.isRunning()):
            return True

        reply = QMessageBox.question(
            self, 'Mosaic Tool',
            'Processing is still running. Do you want to cancel and close?',
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No
        )
        if reply != QMessageBox.Yes:
            return False

        self.worker.cancel()
        self.worker.wait()  # Wait for thread to finish
        return True

    def reject(self):
        """Cancel button / Esc: use the same confirmation as closing the window."""
        # QDialog.reject() does not go through closeEvent, so guard it here too.
        if self._confirm_stop_worker():
            super().reject()

    def closeEvent(self, event):
        """Window close: confirm and stop a running worker before closing."""
        if self._confirm_stop_worker():
            event.accept()
        else:
            event.ignore()