
from qgis.PyQt.QtWidgets import (QWidget, QTabWidget, QVBoxLayout, QHBoxLayout, QLabel, QComboBox, QLineEdit,
                                 QSpacerItem, QTreeWidget, QSizePolicy, QPushButton, QCheckBox, QGridLayout, QGroupBox,
                                 QRadioButton, QTreeWidgetItem, QAbstractItemView, QDialog, QDialogButtonBox)
from qgis.PyQt.QtCore import Qt, pyqtSignal, QVariant
from qgis.PyQt.QtGui import QColor, QIntValidator, QDoubleValidator
from qgis.core import QgsMapLayer, QgsMapLayerProxyModel, QgsField, QgsFields
from qgis.gui import QgsFileWidget, QgsMapLayerComboBox, QgsFieldComboBox, QgsColorButton

import math
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg
from matplotlib.figure import Figure



class GeoProfileGui(QTabWidget):
    """Tabbed input panel (Topography / Geology / Export) for a geological profile.

    Callers read the user's choices through the accessor methods below and
    react to the public buttons ``topographyButton``, ``profileButton``,
    ``lithologyColorsButton`` and ``saveButton``, which this class creates but
    does not connect itself.
    """

    # Re-emitted whenever the lithology layer combo box selects another layer.
    lithologyLayerChanged = pyqtSignal(QgsMapLayer)
    # Re-emitted with the field name whenever the lithology classification field changes.
    lithologyFieldChanged = pyqtSignal(str)

    def __init__(self, tr, parent=None):
        """Build the tabs.

        :param tr: translation callable applied to every user-visible string.
        :param parent: optional parent widget.
        """
        super().__init__(parent)
        self.tr = tr
        self.__setupUI()

    def __setupUI(self):
        """Create the three tabs, fill the section fields and wire internal signals."""
        self.addTab(self.__createTopographyTab(), self.tr('Topography'))
        self.addTab(self.__createGeologyTab(), self.tr('Geology'))
        self.addTab(self.__createExportTab(), self.tr('Export'))

        # Load data (done before the signals below are connected)
        self.setSectionFields(layer=self.__sectionLayerComboBox.currentLayer())

        # Signals
        self.__sectionLayerComboBox.layerChanged.connect(self.setSectionFields)
        self.__sectionFieldComboBox.fieldChanged.connect(self.setSectionFeatures)
        self.__lithologyLayerComboBox.layerChanged.connect(self.setLithologyFields)
        self.__lithologyLayerComboBox.layerChanged.connect(self.__onLithologyLayerChanged)
        self.__lithologyFieldComboBox.fieldChanged.connect(self.__onLithologyFieldChanged)
        self.__structuralLayerComboBox.layerChanged.connect(self.setStructuralFields)
        self.__beddingLayerComboBox.layerChanged.connect(self.setBeddingFields)

    def __createTopographyTab(self):
        """Return the Topography tab: DEM/depth input, section line selection, preview button."""
        topography_tab = QWidget()
        topography_layout = QVBoxLayout()

        # DEM input
        dem_groupbox = QGroupBox(self.tr('Elevations data'))
        dem_layout = QGridLayout()
        dem_label = QLabel(self.tr('DEM Layer: '))
        dem_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__demLayerComboBox = QgsMapLayerComboBox()
        self.__demLayerComboBox.setFilters(QgsMapLayerProxyModel.RasterLayer)
        depth_label = QLabel(self.tr('Depth (m): '))
        depth_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__depthLineEdit = QLineEdit()
        # Parented to the line edit so the validator lives exactly as long as the widget.
        self.__depthLineEdit.setValidator(QIntValidator(0, 1000, self.__depthLineEdit))
        self.__depthLineEdit.setText('10')
        dem_layout.addWidget(dem_label, 0, 0)
        dem_layout.addWidget(self.__demLayerComboBox, 0, 1, 1, 3)
        dem_layout.addWidget(depth_label, 1, 0)
        dem_layout.addWidget(self.__depthLineEdit, 1, 1, 1, 3)
        dem_groupbox.setLayout(dem_layout)

        # Section input
        section_groupbox = QGroupBox(self.tr('Section'))
        section_layout = QGridLayout()
        section_label = QLabel(self.tr('Section layer: '))
        section_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__sectionLayerComboBox = QgsMapLayerComboBox()
        self.__sectionLayerComboBox.setFilters(QgsMapLayerProxyModel.LineLayer)
        # Section field
        section_field_label = QLabel(self.tr('Field: '))
        section_field_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__sectionFieldComboBox = QgsFieldComboBox()
        # Section feature
        section_feature_label = QLabel(self.tr('Line: '))
        section_feature_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__sectionFeatureComboBox = QComboBox()
        # Reverse section
        self.__invertSectionCheckBox = QCheckBox(self.tr('Invert line'))
        # Add section widgets
        section_layout.addWidget(section_label, 0, 0)
        section_layout.addWidget(self.__sectionLayerComboBox, 0, 1, 1, 3)
        section_layout.addWidget(section_field_label, 1, 0)
        section_layout.addWidget(self.__sectionFieldComboBox, 1, 1, 1, 3)
        section_layout.addWidget(section_feature_label, 2, 0)
        section_layout.addWidget(self.__sectionFeatureComboBox, 2, 1)
        section_layout.addItem(QSpacerItem(25, 10, QSizePolicy.Minimum, QSizePolicy.Minimum), 2, 2)
        section_layout.addWidget(self.__invertSectionCheckBox, 2, 3)
        section_groupbox.setLayout(section_layout)

        # Topography preview button
        self.topographyButton = QPushButton(self.tr('Preview topography'))
        self.topographyButton.setFixedWidth(150)

        # Add topography widgets
        topography_layout.addWidget(dem_groupbox)
        topography_layout.addSpacing(15)
        topography_layout.addWidget(section_groupbox)
        topography_layout.addSpacing(15)
        topography_layout.addWidget(self.topographyButton, alignment=Qt.AlignCenter)

        topography_tab.setLayout(topography_layout)
        return topography_tab

    def __createGeologyTab(self):
        """Return the Geology tab: lithology, structural and bedding inputs, preview button."""
        geology_tab = QWidget()
        geology_layout = QVBoxLayout()

        # Lithology input
        lithology_groupbox = QGroupBox(self.tr('Lithology data'))
        lithology_layout = QGridLayout()
        lithology_label = QLabel(self.tr('Lithology layer: '))
        lithology_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__lithologyLayerComboBox = QgsMapLayerComboBox()
        self.__lithologyLayerComboBox.setFilters(QgsMapLayerProxyModel.PolygonLayer)
        lithology_field_label = QLabel(self.tr('Classification: '))
        lithology_field_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__lithologyFieldComboBox = QgsFieldComboBox()
        self.__lithologyFieldComboBox.setLayer(self.__lithologyLayerComboBox.currentLayer())
        self.lithologyColorsButton = QPushButton(self.tr('Colors'))
        # Add lithology widgets
        lithology_layout.addWidget(lithology_label, 0, 0)
        lithology_layout.addWidget(self.__lithologyLayerComboBox, 0, 1, 1, 3)
        lithology_layout.addWidget(lithology_field_label, 1, 0)
        lithology_layout.addWidget(self.__lithologyFieldComboBox, 1, 1, 1, 2)
        lithology_layout.addWidget(self.lithologyColorsButton, 1, 3)
        lithology_groupbox.setLayout(lithology_layout)

        # Structural input
        structural_groupbox = QGroupBox(self.tr('Structural data'))
        structural_layout = QGridLayout()
        structural_label = QLabel(self.tr('Structural layer: '))
        structural_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__structuralLayerComboBox = QgsMapLayerComboBox()
        self.__structuralLayerComboBox.setFilters(QgsMapLayerProxyModel.LineLayer)
        self.__structuralLayerComboBox.setAdditionalItems(['< None >'])
        structural_field_label = QLabel(self.tr('Labels: '))
        structural_field_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__structuralFieldComboBox = QgsFieldComboBox()
        self.__structuralFieldComboBox.setLayer(self.__structuralLayerComboBox.currentLayer())
        self.__structuralColor = QgsColorButton()
        self.__structuralColor.setColor(QColor(0, 0, 255))
        # Add structural widgets
        structural_layout.addWidget(structural_label, 0, 0)
        structural_layout.addWidget(self.__structuralLayerComboBox, 0, 1, 1, 3)
        structural_layout.addWidget(structural_field_label, 1, 0)
        structural_layout.addWidget(self.__structuralFieldComboBox, 1, 1, 1, 2)
        structural_layout.addWidget(self.__structuralColor, 1, 3)
        structural_groupbox.setLayout(structural_layout)

        # Bedding input
        bedding_groupbox = QGroupBox(self.tr('Bedding data'))
        bedding_layout = QGridLayout()
        bedding_label = QLabel(self.tr('Bedding layer: '))
        bedding_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__beddingLayerComboBox = QgsMapLayerComboBox()
        self.__beddingLayerComboBox.setFilters(QgsMapLayerProxyModel.PointLayer)
        self.__beddingLayerComboBox.setAdditionalItems(['< None >'])
        bedding_azimuth_label = QLabel(self.tr('Azimuth (RHR): '))
        bedding_azimuth_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__beddingAzimuthFieldComboBox = QgsFieldComboBox()
        self.__beddingAzimuthFieldComboBox.setLayer(self.__beddingLayerComboBox.currentLayer())
        bedding_dip_label = QLabel(self.tr('Dip: '))
        bedding_dip_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__beddingDipFieldComboBox = QgsFieldComboBox()
        self.__beddingDipFieldComboBox.setLayer(self.__beddingLayerComboBox.currentLayer())
        bedding_length_label = QLabel(self.tr('Line length (m): '))
        bedding_length_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__beddingLengthLineEdit = QLineEdit()
        self.__beddingLengthLineEdit.setValidator(
            QDoubleValidator(1.0, 100.0, 1, self.__beddingLengthLineEdit))
        self.__beddingLengthLineEdit.setText('10')
        bedding_buffer_label = QLabel(self.tr('Buffer (m): '))
        bedding_buffer_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__beddingBufferLineEdit = QLineEdit()
        self.__beddingBufferLineEdit.setValidator(
            QDoubleValidator(10.0, 2000.0, 1, self.__beddingBufferLineEdit))
        self.__beddingBufferLineEdit.setText('100')
        # Add bedding widgets
        bedding_layout.addWidget(bedding_label, 0, 0)
        bedding_layout.addWidget(self.__beddingLayerComboBox, 0, 1, 1, 4)
        bedding_layout.addWidget(bedding_azimuth_label, 1, 0)
        bedding_layout.addWidget(self.__beddingAzimuthFieldComboBox, 1, 1)
        bedding_layout.addItem(QSpacerItem(5, 15, QSizePolicy.Minimum, QSizePolicy.Minimum), 1, 2)
        bedding_layout.addWidget(bedding_dip_label, 1, 3)
        bedding_layout.addWidget(self.__beddingDipFieldComboBox, 1, 4)
        bedding_layout.addWidget(bedding_length_label, 2, 0)
        bedding_layout.addWidget(self.__beddingLengthLineEdit, 2, 1)
        bedding_layout.addWidget(bedding_buffer_label, 2, 3)
        bedding_layout.addWidget(self.__beddingBufferLineEdit, 2, 4)
        bedding_groupbox.setLayout(bedding_layout)

        # Geology preview button
        self.profileButton = QPushButton(self.tr('Preview profile'))
        self.profileButton.setFixedWidth(150)

        # Add geology widgets
        geology_layout.addWidget(lithology_groupbox)
        geology_layout.addSpacing(15)
        geology_layout.addWidget(structural_groupbox)
        geology_layout.addSpacing(15)
        geology_layout.addWidget(bedding_groupbox)
        geology_layout.addSpacing(15)
        geology_layout.addWidget(self.profileButton, alignment=Qt.AlignCenter)

        geology_tab.setLayout(geology_layout)
        return geology_tab

    def __createExportTab(self):
        """Return the Export tab: graph settings, output path and save button."""
        export_tab = QWidget()
        export_layout = QVBoxLayout()

        # Settings input
        settings_groupbox = QGroupBox(self.tr('Graph settings'))
        settings_layout = QGridLayout()
        title_label = QLabel(self.tr('Section title: '))
        title_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__sectionTitleLineEdit = QLineEdit()
        section_start_label = QLabel(self.tr('Start label: '))
        section_start_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__sectionStartLabelLineEdit = QLineEdit()
        section_end_label = QLabel(self.tr('End label: '))
        section_end_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__sectionEndLabelLineEdit = QLineEdit()
        figure_size_label = QLabel(self.tr('Figure size: '))
        figure_size_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__figureSizeComboBox = QComboBox()
        self.__figureSizeComboBox.addItems(['A5', 'A4', 'A3', 'A2', 'A1', 'A0'])
        self.__figureOrientationButton = QRadioButton(self.tr('Landscape'))
        self.__figureOrientationButton.setChecked(True)
        # Radio buttons sharing a parent are auto-exclusive, and a checked one cannot be
        # unchecked by clicking it. A lone 'Landscape' button could therefore never be
        # turned off, so 'Portrait' needs its own button.
        self.__figurePortraitButton = QRadioButton(self.tr('Portrait'))
        # Add settings widgets
        settings_layout.addWidget(title_label, 0, 0)
        settings_layout.addWidget(self.__sectionTitleLineEdit, 0, 1, 1, 4)
        settings_layout.addWidget(section_start_label, 1, 0)
        settings_layout.addWidget(self.__sectionStartLabelLineEdit, 1, 1)
        settings_layout.addItem(QSpacerItem(15, 10, QSizePolicy.Minimum, QSizePolicy.Minimum), 1, 2)
        settings_layout.addWidget(section_end_label, 1, 3)
        settings_layout.addWidget(self.__sectionEndLabelLineEdit, 1, 4)
        settings_layout.addWidget(figure_size_label, 2, 0)
        settings_layout.addWidget(self.__figureSizeComboBox, 2, 1)
        settings_layout.addWidget(self.__figureOrientationButton, 2, 3)
        settings_layout.addWidget(self.__figurePortraitButton, 2, 4)
        settings_groupbox.setLayout(settings_layout)

        # Output path
        output_groupbox = QGroupBox(self.tr('Output path'))
        output_layout = QHBoxLayout()
        output_label = QLabel(self.tr('Path: '))
        output_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
        self.__outputPathWidget = QgsFileWidget()
        self.__outputPathWidget.setStorageMode(QgsFileWidget.SaveFile)
        self.__outputPathWidget.setFilter('SVG (*.svg);;PDF (*.pdf);;PNG (*.png)')
        output_layout.addWidget(output_label)
        output_layout.addWidget(self.__outputPathWidget)
        output_groupbox.setLayout(output_layout)

        # Save button
        self.saveButton = QPushButton(self.tr('Save'))
        self.saveButton.setFixedWidth(150)

        # Add export widgets
        export_layout.addWidget(settings_groupbox)
        export_layout.addSpacing(15)
        export_layout.addWidget(output_groupbox)
        export_layout.addSpacing(15)
        export_layout.addWidget(self.saveButton, alignment=Qt.AlignCenter)

        export_tab.setLayout(export_layout)
        return export_tab

    @staticmethod
    def __parseNumber(line_edit, default):
        """Return the number typed in *line_edit*, or *default* if it is empty/unparsable.

        A QLineEdit keeps text its validator only rates as "intermediate"
        (e.g. '-' or '1e'), and Qt validators accept locale-formatted numbers
        (e.g. '2,5' with a comma decimal separator). A bare ``float(text)``
        raises ValueError on both, so the validator's own locale is tried first.
        """
        text = line_edit.text().strip()
        if not text:
            return default
        validator = line_edit.validator()
        if validator is not None:
            value, ok = validator.locale().toDouble(text)
            if ok:
                return value
        try:
            return float(text)
        except ValueError:
            return default

    @staticmethod
    def __featureSortKey(value):
        """Sort key: numeric strings first in numeric order ('2' < '10'), then the rest as text."""
        try:
            number = float(value)
        except ValueError:
            return (1, 0.0, value)
        if math.isnan(number):
            # NaN compares unequal to everything and would break the ordering.
            return (1, 0.0, value)
        return (0, number, value)

    def __onLithologyLayerChanged(self, layer):
        """Forward the lithology combo box's layer change through ``lithologyLayerChanged``."""
        self.lithologyLayerChanged.emit(layer)

    def __onLithologyFieldChanged(self, field):
        """Forward the lithology field change through ``lithologyFieldChanged``."""
        self.lithologyFieldChanged.emit(field)

    def setSectionFields(self, layer):
        """Offer a synthetic 'Feature ID' entry plus every field of *layer*, then refresh the lines.

        With no layer, the field and line combo boxes are emptied so nothing
        from a previously selected layer is left behind.
        """
        if not layer:
            self.__sectionFieldComboBox.setFields(QgsFields())
            self.__sectionFeatureComboBox.clear()
            return

        fields = QgsFields()
        # Synthetic first entry: lets the user pick lines by feature id.
        fields.append(QgsField('Feature ID', QVariant.Int))
        for field in layer.fields():
            fields.append(field)

        self.__sectionFieldComboBox.setFields(fields)
        # currentField() is the field name. currentText() is the displayed text,
        # which may be an alias rather than the name.
        self.setSectionFeatures(field=self.__sectionFieldComboBox.currentField())

    def setSectionFeatures(self, field):
        """Fill the line combo box with the distinct values of *field* in the section layer.

        :param field: a field name, or 'Feature ID' to list feature ids.

        Values are listed as strings, with numeric ones in numeric order
        ('2' before '10'). The combo box is left empty when there is no layer
        or the field is empty or unknown.
        """
        self.__sectionFeatureComboBox.clear()

        layer = self.__sectionLayerComboBox.currentLayer()
        if not layer or not field:
            return

        use_feature_id = field == 'Feature ID'
        # feature.attribute(name) raises KeyError for a name the layer lacks
        # (e.g. a name left over from another layer).
        if not use_feature_id and layer.fields().lookupField(field) < 0:
            return

        values = {str(feature.id()) if use_feature_id else str(feature.attribute(field))
                  for feature in layer.getFeatures()}

        self.__sectionFeatureComboBox.addItems(sorted(values, key=self.__featureSortKey))

    def setLithologyFields(self, layer):
        """Point the lithology field combo box at *layer*; None clears it (as for the other layers)."""
        self.__lithologyFieldComboBox.setLayer(layer)

    def setStructuralFields(self, layer):
        """Point the structural label field combo box at *layer*."""
        self.__structuralFieldComboBox.setLayer(layer)

    def setBeddingFields(self, layer):
        """Point both bedding (azimuth and dip) field combo boxes at *layer*."""
        self.__beddingAzimuthFieldComboBox.setLayer(layer)
        self.__beddingDipFieldComboBox.setLayer(layer)

    def demLayer(self):
        """Return the selected DEM raster layer (may be None)."""
        return self.__demLayerComboBox.currentLayer()

    def profileDepth(self):
        """Return the depth entered in metres, 0.0 when empty or unparsable."""
        return self.__parseNumber(self.__depthLineEdit, 0.0)

    def sectionLayer(self):
        """Return the selected section line layer (may be None)."""
        return self.__sectionLayerComboBox.currentLayer()

    def sectionField(self):
        """Return the selected section field name ('Feature ID' for feature ids)."""
        return self.__sectionFieldComboBox.currentField()

    def sectionFeature(self):
        """Return the selected line value as shown in the combo box (a string)."""
        return self.__sectionFeatureComboBox.currentText()

    def invertSection(self):
        """Return True when the 'Invert line' box is checked."""
        return self.__invertSectionCheckBox.isChecked()

    def lithologyLayer(self):
        """Return the selected lithology polygon layer (may be None)."""
        return self.__lithologyLayerComboBox.currentLayer()

    def lithologyField(self):
        """Return the selected lithology classification field name."""
        return self.__lithologyFieldComboBox.currentField()

    def structuralLayer(self):
        """Return the selected structural line layer (None for '< None >')."""
        return self.__structuralLayerComboBox.currentLayer()

    def structuralField(self):
        """Return the selected structural label field name."""
        return self.__structuralFieldComboBox.currentField()

    def structuralColor(self):
        """Return the QColor chosen for structures."""
        return self.__structuralColor.color()

    def beddingLayer(self):
        """Return the selected bedding point layer (None for '< None >')."""
        return self.__beddingLayerComboBox.currentLayer()

    def beddingAzimuthField(self):
        """Return the selected bedding azimuth field name."""
        return self.__beddingAzimuthFieldComboBox.currentField()

    def beddingDipField(self):
        """Return the selected bedding dip field name."""
        return self.__beddingDipFieldComboBox.currentField()

    def beddingLength(self):
        """Return the bedding line length in metres, 10.0 when empty or unparsable."""
        return self.__parseNumber(self.__beddingLengthLineEdit, 10.0)

    def beddingBuffer(self):
        """Return the bedding buffer in metres, 100.0 when empty or unparsable."""
        return self.__parseNumber(self.__beddingBufferLineEdit, 100.0)

    def outputPath(self):
        """Return the output file path entered by the user."""
        return self.__outputPathWidget.filePath()

    def sectionTitle(self):
        """Return the section title text."""
        return self.__sectionTitleLineEdit.text()

    def sectionStartLabel(self):
        """Return the start-of-section label text."""
        return self.__sectionStartLabelLineEdit.text()

    def sectionEndLabel(self):
        """Return the end-of-section label text."""
        return self.__sectionEndLabelLineEdit.text()

    def figureSize(self):
        """Return the selected paper size name ('A5' ... 'A0')."""
        return self.__figureSizeComboBox.currentText()

    def figureOrientation(self):
        """Return 'landscape' or 'portrait' according to the orientation radio buttons."""
        if self.__figureOrientationButton.isChecked():
            return 'landscape'
        else:
            return 'portrait'



class ProfileWidget(QWidget):
    """Top-level window that draws a topographic or geological profile with matplotlib.

    The figure is a plain ``matplotlib.figure.Figure`` embedded in a
    ``FigureCanvasQTAgg``. It is deliberately not created through pyplot,
    because pyplot keeps a reference to every figure it creates until
    ``plt.close`` is called. Each opened profile window would otherwise leak
    its figure.
    """

    def __init__(self, tr, parent=None):
        """Create the window with an empty, pre-styled axes.

        :param tr: translation callable applied to user-visible strings.
        :param parent: optional parent widget.
        """
        super().__init__(parent)
        # Qt.Window makes this widget its own top-level window even when it has a parent.
        self.setWindowFlags(Qt.Window)
        self.tr = tr

        # Layout
        layout = QVBoxLayout()

        # Canvas
        self.figure = Figure(figsize=(8, 5))
        self.ax = self.figure.subplots()
        self.ax.set_ylabel(self.tr('Elevation'), fontsize=9)
        self.ax.tick_params(axis='both', which='both', labelsize=8, labelright=True, right=True)
        self.ax.spines.top.set_visible(False)
        self.ax.minorticks_on()
        self.ax.set_axisbelow(True)
        self.canvas = FigureCanvasQTAgg(self.figure)

        # Add widgets
        layout.addWidget(self.canvas)
        self.setLayout(layout)

    def setTopographicProfile(self, distances, elevations, depth):
        """Plot the ground-surface line.

        :param distances: x values along the section. They must be non-empty,
            because ``min``/``max`` raise ValueError otherwise.
        :param elevations: y values, one per distance.
        :param depth: extra room shown below the lowest elevation. The top
            margin is 10 % of the largest distance.
        """
        self.ax.set_title(self.tr('Topographic Profile'), fontsize=11, pad=25)
        self.ax.plot(distances, elevations, color='black')
        self.ax.set_xlim(min(distances), max(distances))
        self.ax.set_ylim(min(elevations) - depth, max(elevations) + (0.1 * max(distances)))
        # Same scale on both axes.
        self.ax.set_aspect('equal', adjustable='box')
        self.figure.tight_layout()
        # Schedule a repaint in case the window is already visible.
        self.canvas.draw_idle()

    def setGeologicalProfile(self, lithology, depth, x_limits, y_limits, structures=None,
                             bedding=None, bedding_length=None):
        """Plot lithology lines, structure markers and bedding ticks.

        :param lithology: iterable of objects with ``distances``, ``elevations``,
            ``name`` and ``color``. Each is drawn as one labelled line.
        :param depth: extra room below ``y_limits[0]``.
        :param x_limits: ``(xmin, xmax)`` of the section.
        :param y_limits: ``(ymin, ymax)`` elevation range. 10 % of ``xmax`` is
            added above ``ymax``.
        :param structures: optional sequence of objects with ``distance``,
            ``elevation``, ``color`` and ``name``. It is iterated several times,
            so pass a list, not a generator. Each is drawn as a '+' marker
            annotated with its name above the profile.
        :param bedding: optional objects with ``distance``, ``elevation``,
            ``apparent_dip`` (degrees) and ``dip_inclination``. Each is drawn as
            a segment going down for a positive dip, to the right unless
            ``dip_inclination == 'left'``.
        :param bedding_length: segment length in data units. A falsy value
            falls back to 50.
        """
        self.ax.set_title(self.tr('Geological Profile'), fontsize=11, pad=25)

        for unit in lithology:
            self.ax.plot(unit.distances, unit.elevations, label=unit.name, color=unit.color)

        self.ax.set_xlim(x_limits[0], x_limits[1])
        self.ax.set_ylim(y_limits[0] - depth, y_limits[1] + (0.1 * x_limits[1]))

        if structures:
            s_x = [s.distance for s in structures]
            s_y = [s.elevation for s in structures]
            s_colors = [s.color for s in structures]
            s_names = [s.name for s in structures]

            # A scatter is a single artist with a single label. Passing the list of
            # names would give one legend entry reading "['a', 'b', ...]". The names
            # are written by the annotations below, so keep the markers out of a legend.
            self.ax.scatter(s_x, s_y, color=s_colors, marker='+', s=80, label='_nolegend_')

            for name, x, y in zip(s_names, s_x, s_y):
                # Arrow from a label above the profile top down to just above the marker.
                self.ax.annotate(name, xy=(x, y + (0.015 * x_limits[1])),
                                 xytext=(x, y_limits[1] + (0.05 * x_limits[1])), fontsize=9, ha='center',
                                 arrowprops=dict(arrowstyle="->, head_width=0.25", lw=1, shrinkA=0, shrinkB=0))

        if bedding:

            length = bedding_length if bedding_length else 50

            for bed in bedding:
                beta = math.radians(bed.apparent_dip)
                dx = length * math.cos(beta)
                dz = - length * math.sin(beta)

                if bed.dip_inclination == 'left':
                    dx = - dx

                self.ax.plot([bed.distance, bed.distance + dx], [bed.elevation, bed.elevation + dz], color='black')

        self.ax.set_aspect('equal', adjustable='box')
        # Same finishing steps as setTopographicProfile.
        self.figure.tight_layout()
        self.canvas.draw_idle()



class ClassificationColorsWidget(QDialog):
    """Dialog listing classification values, each paired with an editable color button."""

    def __init__(self, tr, parent=None):
        """Build the dialog.

        :param tr: translation callable applied to the column headers.
        :param parent: optional parent widget.
        """
        super().__init__(parent)
        self.tr = tr
        self.setupUI()

    def setupUI(self):
        """Create the two-column tree (value / color) above an OK/Cancel button box."""
        tree = QTreeWidget()
        self.treeWidget = tree
        tree.setColumnCount(2)
        tree.setHeaderLabels([self.tr('Name'), self.tr('Color')])
        tree.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        tree.setDragEnabled(False)
        tree.setDragDropMode(QAbstractItemView.NoDragDrop)
        tree.setAlternatingRowColors(True)
        tree.setTextElideMode(Qt.ElideLeft)

        button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        button_box.accepted.connect(self.accept)
        button_box.rejected.connect(self.reject)

        main_layout = QVBoxLayout()
        for widget in (tree, button_box):
            main_layout.addWidget(widget)
        self.setLayout(main_layout)

    def setTitle(self, name: str):
        """Use *name* as the dialog's window title."""
        self.setWindowTitle(name)

    def insertItem(self, item, color):
        """Append a row showing ``str(item)`` with a color button preset to ``QColor(color)``.

        The original *item* object is stored under ``Qt.UserRole`` so
        ``getColors`` can key its result by it.
        """
        row = QTreeWidgetItem(self.treeWidget)
        row.setText(0, str(item))
        row.setData(0, Qt.UserRole, item)

        swatch = QgsColorButton()
        swatch.setColor(QColor(color))
        self.treeWidget.setItemWidget(row, 1, swatch)

    def getColors(self):
        """Return ``{item: QColor}`` for every top-level row that has a color button."""
        tree = self.treeWidget
        rows = (tree.topLevelItem(index) for index in range(tree.topLevelItemCount()))

        colors = {}
        for row in rows:
            key = row.data(0, Qt.UserRole)
            swatch = tree.itemWidget(row, 1)
            if not swatch:
                continue
            colors[key] = swatch.color()
        return colors



