from PyQt5.QtWidgets import (QDialog, QVBoxLayout, QComboBox, QSpinBox, QGroupBox, 
                             QFormLayout, QLabel, QDoubleSpinBox, QListWidget, QPushButton, 
                             QListWidgetItem, QLineEdit, QFileDialog, QCheckBox, QHBoxLayout,
                             QTableWidget, QTableWidgetItem, QHeaderView, QWidget, QProgressBar)
from PyQt5.QtCore import Qt
from osgeo import gdal
import os
from qgis.core import QgsRasterLayer

# Check for sklearn availability
try:
    from sklearn.cluster import AgglomerativeClustering, DBSCAN, SpectralClustering
    sklearn_available = True
except ImportError:
    sklearn_available = False

class UnsupervisedClassifierDialog(QDialog):
    """Dialog that collects settings for batch unsupervised raster classification.

    The dialog gathers user input only: a table of input rasters with
    per-row output file names, an output folder, the clustering method and
    its ISODATA options, and the band selection. ``runButton`` is created
    here but not connected in this class; the code that performs the
    clustering is expected to connect it.
    """

    def __init__(self, iface, parent=None):
        super().__init__(parent)
        self.iface = iface
        self.setWindowTitle("Unsupervised Classifier")
        self.setGeometry(100, 100, 900, 700)
        self.setMinimumWidth(800)
        self.setWindowFlags(Qt.Dialog)

        # True only while the table has rows and every row is checked; kept in
        # sync with the Select All button label by _refresh_select_all_button.
        self.all_selected = False

        # NOTE: this attribute shadows QWidget.layout(); it is kept because
        # other code may already reference ``dialog.layout`` directly.
        self.layout = QVBoxLayout(self)

        # ===== RASTER LAYERS TABLE SECTION =====
        self.rasterTableLabel = QLabel("Raster Layers for Batch Processing:", self)
        self.layout.addWidget(self.rasterTableLabel)

        # Columns: 0 = selection checkbox, 1 = raster name (full path stored in
        # Qt.UserRole), 2 = editable output file name.
        self.rasterTable = QTableWidget(self)
        self.rasterTable.setColumnCount(3)
        self.rasterTable.setHorizontalHeaderLabels(["Select", "Raster Name", "Output File Name"])
        self.rasterTable.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeToContents)
        self.rasterTable.horizontalHeader().setSectionResizeMode(1, QHeaderView.Stretch)
        self.rasterTable.horizontalHeader().setSectionResizeMode(2, QHeaderView.Stretch)
        self.rasterTable.setMinimumHeight(150)
        self.rasterTable.setRowCount(0)
        self.layout.addWidget(self.rasterTable)

        # Select All / Unselect All button
        self.selectAllButton = QPushButton("Select All", self)
        self.selectAllButton.clicked.connect(self.toggle_select_all)
        self.selectAllButton.setMaximumWidth(150)
        self.layout.addWidget(self.selectAllButton)

        # Input selection
        self.inputFileLabel = QLabel("Add Input Raster:", self)
        self.layout.addWidget(self.inputFileLabel)

        self.inputFileLineEdit = QLineEdit(self)
        self.inputFileButton = QPushButton("...", self)
        self.inputFileButton.setMaximumWidth(50)
        self.inputFileButton.clicked.connect(self.select_input_file)
        self.addRasterButton = QPushButton("Add to List", self)
        self.addRasterButton.setMaximumWidth(100)
        self.addRasterButton.clicked.connect(self.add_raster_to_table)

        self.inputFileLayout = QHBoxLayout()
        self.inputFileLayout.addWidget(self.inputFileLineEdit)
        self.inputFileLayout.addWidget(self.inputFileButton)
        self.inputFileLayout.addWidget(self.addRasterButton)
        self.layout.addLayout(self.inputFileLayout)

        # ===== OUTPUT FOLDER SECTION =====
        self.outputFolderLabel = QLabel("Output Folder:", self)
        self.layout.addWidget(self.outputFolderLabel)

        self.outputFolderLineEdit = QLineEdit(self)
        self.outputFolderButton = QPushButton("...", self)
        self.outputFolderButton.setMaximumWidth(50)
        self.outputFolderButton.clicked.connect(self.select_output_folder)

        self.outputFolderLayout = QHBoxLayout()
        self.outputFolderLayout.addWidget(self.outputFolderLineEdit)
        self.outputFolderLayout.addWidget(self.outputFolderButton)
        self.layout.addLayout(self.outputFolderLayout)

        # Checkbox for "Save output same as input"
        self.sameAsInputCheckBox = QCheckBox("Save output same as input folder?", self)
        self.sameAsInputCheckBox.stateChanged.connect(self.toggle_output_folder)
        self.layout.addWidget(self.sameAsInputCheckBox)

        # Algorithm selection
        self.algorithmLabel = QLabel("Select Clustering Method:", self)
        self.layout.addWidget(self.algorithmLabel)
        self.algorithmComboBox = QComboBox(self)
        self.algorithmComboBox.addItem("Kmeans (Best Method)")
        self.algorithmComboBox.addItem("ISODATA (Time Taking)")

        # The sklearn-backed methods are offered only when sklearn imported.
        if sklearn_available:
            self.algorithmComboBox.addItem("Agglomerative Clustering")
            self.algorithmComboBox.addItem("DBSCAN")
            self.algorithmComboBox.addItem("Spectral Clustering")

        self.layout.addWidget(self.algorithmComboBox)

        # Number of clusters
        self.numClustersSpinBox = QSpinBox(self)
        self.numClustersSpinBox.setMinimum(2)
        self.numClustersSpinBox.setMaximum(10)
        self.numClustersSpinBox.setValue(5)
        self.numClustersSpinBox.setPrefix("Number of Clusters: ")
        self.layout.addWidget(self.numClustersSpinBox)

        # ISODATA options (shown only while ISODATA is selected)
        self.isodataOptionsGroupBox = QGroupBox("ISODATA Options", self)
        self.isodataOptionsLayout = QFormLayout(self.isodataOptionsGroupBox)

        self.maxIterLabel = QLabel("Max Iterations", self)
        self.maxIterSpinBox = QSpinBox(self)
        self.maxIterSpinBox.setMaximum(1000)
        self.maxIterSpinBox.setValue(100)
        self.isodataOptionsLayout.addRow(self.maxIterLabel, self.maxIterSpinBox)

        self.maxMergeLabel = QLabel("Max Merge", self)
        self.maxMergeDoubleSpinBox = QDoubleSpinBox(self)
        self.maxMergeDoubleSpinBox.setMaximum(10.0)
        self.maxMergeDoubleSpinBox.setValue(0.5)
        self.isodataOptionsLayout.addRow(self.maxMergeLabel, self.maxMergeDoubleSpinBox)

        self.minSplitStdLabel = QLabel("Min Split Std", self)
        self.minSplitStdDoubleSpinBox = QDoubleSpinBox(self)
        self.minSplitStdDoubleSpinBox.setMaximum(10.0)
        self.minSplitStdDoubleSpinBox.setValue(0.5)
        self.isodataOptionsLayout.addRow(self.minSplitStdLabel, self.minSplitStdDoubleSpinBox)

        self.maxStdLabel = QLabel("Max Std", self)
        self.maxStdDoubleSpinBox = QDoubleSpinBox(self)
        self.maxStdDoubleSpinBox.setMaximum(10.0)
        self.maxStdDoubleSpinBox.setValue(1.0)
        self.isodataOptionsLayout.addRow(self.maxStdLabel, self.maxStdDoubleSpinBox)

        self.minSamplesLabel = QLabel("Min Samples", self)
        self.minSamplesSpinBox = QSpinBox(self)
        self.minSamplesSpinBox.setMaximum(1000)
        self.minSamplesSpinBox.setValue(10)
        self.isodataOptionsLayout.addRow(self.minSamplesLabel, self.minSamplesSpinBox)

        self.layout.addWidget(self.isodataOptionsGroupBox)

        # Band selection options: when the checkbox is ticked the per-band list
        # is shown, otherwise only the band-count spin box.
        self.useNumBandsCheckBox = QCheckBox("Do you want to select available bands?", self)
        self.layout.addWidget(self.useNumBandsCheckBox)

        self.numBandsSpinBox = QSpinBox(self)
        self.numBandsSpinBox.setMinimum(1)
        self.numBandsSpinBox.setMaximum(10)
        self.numBandsSpinBox.setValue(4)
        self.numBandsSpinBox.setPrefix("Number of Bands: ")
        self.layout.addWidget(self.numBandsSpinBox)

        self.selectedBandsListWidget = QListWidget(self)
        self.layout.addWidget(self.selectedBandsListWidget)

        # Open output in QGIS
        self.openInQgisCheckBox = QCheckBox("Open the output in QGIS", self)
        self.layout.addWidget(self.openInQgisCheckBox)

        # ===== PROGRESS BAR SECTION =====
        self.progressLabel = QLabel("", self)
        self.layout.addWidget(self.progressLabel)

        self.progressBar = QProgressBar(self)
        self.progressBar.setMinimum(0)
        self.progressBar.setMaximum(100)
        self.progressBar.setValue(0)
        self.progressBar.setTextVisible(True)
        self.progressBar.setFormat("%p% - %v/%m files")
        self.progressBar.hide()  # Hidden by default
        self.layout.addWidget(self.progressBar)

        # Run button (connected by the caller, not here)
        self.runButton = QPushButton("Run Clustering", self)
        self.layout.addWidget(self.runButton)

        # Connect signals
        self.algorithmComboBox.currentIndexChanged.connect(self.toggle_options)
        self.useNumBandsCheckBox.stateChanged.connect(self.toggle_band_selection)
        self.numBandsSpinBox.valueChanged.connect(self.populate_band_options)
        self.inputFileLineEdit.textChanged.connect(self.update_band_options)

        # Initial setup: sync widget visibility with the default selections.
        # toggle_options also hides the ISODATA group for the default method.
        self.toggle_options()
        self.toggle_band_selection()
        self.populate_band_options()

    # ----- progress -----

    def update_progress(self, current, total, message=""):
        """Show the progress bar at ``current`` of ``total`` with ``message``.

        Pumps the Qt event loop so the change is painted even while the caller
        is busy on the GUI thread.
        """
        from PyQt5.QtWidgets import QApplication

        self.progressBar.setMaximum(total)
        self.progressBar.setValue(current)
        self.progressLabel.setText(message)
        self.progressBar.show()
        # Force GUI update
        QApplication.processEvents()

    def hide_progress(self):
        """Hide the progress bar and clear the progress label."""
        self.progressBar.hide()
        self.progressLabel.setText("")

    # ----- raster table helpers -----

    @staticmethod
    def _default_output_name(input_path):
        """Return the default output file name ``<basename>_classified.tif``."""
        base_name = os.path.splitext(os.path.basename(input_path))[0]
        return f"{base_name}_classified.tif"

    @staticmethod
    def _path_key(path):
        """Normalize ``path`` so equivalent spellings compare equal."""
        return os.path.normcase(os.path.abspath(path))

    def _row_checkbox(self, row):
        """Return the selection QCheckBox of ``row``, or None if it has none."""
        container = self.rasterTable.cellWidget(row, 0)
        if container is None:
            return None
        return container.findChild(QCheckBox)

    def _row_checkboxes(self):
        """Return the selection checkboxes of all rows that have one."""
        boxes = (self._row_checkbox(row) for row in range(self.rasterTable.rowCount()))
        return [box for box in boxes if box is not None]

    def _refresh_select_all_button(self, *_):
        """Recompute ``all_selected`` from the checkboxes and relabel the button.

        Extra (signal) arguments are ignored so this can be connected directly
        to ``QCheckBox.stateChanged``.
        """
        boxes = self._row_checkboxes()
        self.all_selected = bool(boxes) and all(box.isChecked() for box in boxes)
        self.selectAllButton.setText("Unselect All" if self.all_selected else "Select All")

    # ----- slots -----

    def toggle_select_all(self):
        """Check every row, or uncheck every row if all are already checked.

        Does nothing when the table has no rows. The decision is taken from
        the actual checkbox states, so the button label always matches the
        action it performs.
        """
        boxes = self._row_checkboxes()
        if not boxes:
            return
        target = not all(box.isChecked() for box in boxes)
        for box in boxes:
            box.setChecked(target)
        self._refresh_select_all_button()

    def toggle_output_folder(self):
        """Disable (and clear) the output-folder inputs when saving beside inputs."""
        same_as_input = self.sameAsInputCheckBox.isChecked()
        self.outputFolderLineEdit.setEnabled(not same_as_input)
        self.outputFolderButton.setEnabled(not same_as_input)
        if same_as_input:
            self.outputFolderLineEdit.clear()

    def select_output_folder(self):
        """Ask the user for an output directory and store it in the line edit."""
        folder = QFileDialog.getExistingDirectory(self, "Select Output Folder", "")
        if folder:
            self.outputFolderLineEdit.setText(folder)

    def add_raster_to_table(self):
        """Append the raster named in the input line edit to the table.

        The new row is checked and gets ``<basename>_classified.tif`` as its
        output name. Returns silently when the field is empty, the path does
        not exist, or the same file (after path normalization) is already
        listed. The input field is cleared after a successful add.
        """
        input_file = self.inputFileLineEdit.text().strip()
        if not input_file or not os.path.exists(input_file):
            return

        new_key = self._path_key(input_file)
        for row in range(self.rasterTable.rowCount()):
            existing = self.rasterTable.item(row, 1)
            existing_path = existing.data(Qt.UserRole) if existing is not None else None
            if existing_path and self._path_key(existing_path) == new_key:
                return

        base_name = os.path.splitext(os.path.basename(input_file))[0]

        row_position = self.rasterTable.rowCount()
        self.rasterTable.insertRow(row_position)

        # Wrap the checkbox in a container widget so it can be centred in the cell.
        checkbox = QCheckBox()
        checkbox.setChecked(True)
        checkbox.stateChanged.connect(self._refresh_select_all_button)
        checkbox_widget = QWidget()
        checkbox_layout = QHBoxLayout(checkbox_widget)
        checkbox_layout.addWidget(checkbox)
        checkbox_layout.setAlignment(Qt.AlignCenter)
        checkbox_layout.setContentsMargins(0, 0, 0, 0)
        self.rasterTable.setCellWidget(row_position, 0, checkbox_widget)

        # Read-only display name; the full path rides along in Qt.UserRole.
        raster_item = QTableWidgetItem(base_name)
        raster_item.setFlags(raster_item.flags() & ~Qt.ItemIsEditable)
        raster_item.setData(Qt.UserRole, input_file)
        self.rasterTable.setItem(row_position, 1, raster_item)

        output_item = QTableWidgetItem(self._default_output_name(input_file))
        self.rasterTable.setItem(row_position, 2, output_item)

        self.rasterTable.setRowHeight(row_position, 30)
        self.inputFileLineEdit.clear()
        self._refresh_select_all_button()

    def get_selected_rasters(self):
        """Return ``[{'input': path, 'output': path}, ...]`` for checked rows.

        The output goes to the raster's own folder when "same as input" is
        ticked or no output folder is given; otherwise it goes to the chosen
        output folder. A blank output-name cell falls back to the default
        ``<basename>_classified.tif`` rather than yielding a directory path.
        """
        output_folder = self.outputFolderLineEdit.text().strip()
        use_input_dir = self.sameAsInputCheckBox.isChecked() or not output_folder

        selected = []
        for row in range(self.rasterTable.rowCount()):
            checkbox = self._row_checkbox(row)
            if checkbox is None or not checkbox.isChecked():
                continue

            raster_item = self.rasterTable.item(row, 1)
            output_item = self.rasterTable.item(row, 2)
            if raster_item is None or output_item is None:
                continue

            input_path = raster_item.data(Qt.UserRole)
            output_name = output_item.text().strip() or self._default_output_name(input_path)
            target_dir = os.path.dirname(input_path) if use_input_dir else output_folder

            selected.append({
                'input': input_path,
                'output': os.path.join(target_dir, output_name),
            })
        return selected

    def toggle_options(self):
        """Show the ISODATA option group only while ISODATA is the chosen method."""
        is_isodata = self.algorithmComboBox.currentText() == "ISODATA (Time Taking)"
        self.isodataOptionsGroupBox.setVisible(is_isodata)
        self.adjustSize()

    def toggle_band_selection(self):
        """Show either the per-band checklist (ticked) or the band-count spin box."""
        pick_bands = self.useNumBandsCheckBox.isChecked()
        self.numBandsSpinBox.setVisible(not pick_bands)
        self.selectedBandsListWidget.setVisible(pick_bands)
        self.adjustSize()

    def populate_band_options(self):
        """Fill the band list with checked generic "Band i" entries, one per
        band in the spin box's current value."""
        self.selectedBandsListWidget.clear()
        for i in range(1, self.numBandsSpinBox.value() + 1):
            item = QListWidgetItem(f"Band {i}")
            item.setCheckState(Qt.Checked)
            self.selectedBandsListWidget.addItem(item)

    def select_input_file(self):
        """Ask the user for an input raster and load its bands."""
        filename, _ = QFileDialog.getOpenFileName(self, "Select Input File", "", "GeoTIFF Files (*.tif *.tiff);;All Files (*.*)")
        if not filename:
            return
        if filename == self.inputFileLineEdit.text():
            # setText with unchanged text emits no textChanged, so refresh here.
            self.update_band_options()
        else:
            # textChanged is connected to update_band_options; it runs once.
            self.inputFileLineEdit.setText(filename)

    def update_band_options(self):
        """Load band count and descriptions of the raster in the input field.

        Sets the band-count spin box to the raster's band count and lists one
        checked item per band (its GDAL description, or "Band i"). Does nothing
        if the path is empty, missing, or GDAL cannot open it; read errors are
        printed and leave the widgets unchanged.
        """
        input_file = self.inputFileLineEdit.text()
        if not input_file or not os.path.exists(input_file):
            return

        try:
            dataset = gdal.Open(input_file)
            if not dataset:
                return
            num_bands = dataset.RasterCount
            descriptions = [
                dataset.GetRasterBand(i).GetDescription() or f"Band {i}"
                for i in range(1, num_bands + 1)
            ]
            dataset = None  # release the GDAL handle
        except Exception as e:
            print(f"Error reading bands: {e}")
            return

        # Update the spin box BEFORE filling the list: setMaximum/setValue may
        # emit valueChanged, which runs populate_band_options and refills the
        # list with generic names. Filling afterwards guarantees exactly one
        # entry per band instead of generic + described duplicates.
        self.numBandsSpinBox.setMaximum(num_bands)
        self.numBandsSpinBox.setValue(num_bands)

        self.selectedBandsListWidget.clear()
        for description in descriptions:
            item = QListWidgetItem(description)
            item.setCheckState(Qt.Checked)
            self.selectedBandsListWidget.addItem(item)
