# -*- coding: utf-8 -*-
"""
/***************************************************************************
 SGTool
                                 A QGIS plugin
 Simple Potential Field Processing
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2024-11-17
        git sha              : $Format:%H$
        copyright            : (C) 2024 by Mark Jessell
        email                : mark.jessell@uwa.edu.au
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

from qgis.PyQt.QtCore import QSettings, QTranslator, QCoreApplication, Qt
from qgis.core import QgsMapLayerProxyModel
from qgis.PyQt.QtCore import QUrl
from qgis.PyQt.QtGui import QIcon, QDesktopServices
from qgis.PyQt.QtWidgets import QAction
from qgis.core import (
    Qgis,
    QgsCoordinateReferenceSystem,
    QgsVectorLayer,
    QgsProject,
    QgsRasterLayer,
    QgsSingleBandGrayRenderer,
    QgsFeature,
    QgsField,
    QgsProcessingFeedback,
)

from qgis.PyQt.QtWidgets import QAction, QFileDialog
from qgis.PyQt.QtCore import (
    QSettings,
    QTranslator,
    QCoreApplication,
    QFileInfo,
    QVariant,
    Qt,
)
from qgis.PyQt.QtWidgets import QDockWidget
from qgis.PyQt.QtCore import Qt

from qgis.core import QgsRasterLayer
from qgis.core import (
    QgsProject,
    QgsVectorLayer,
    QgsFeature,
    QgsGeometry,
    QgsField,
    QgsFields,
    QgsPointXY,
)
from qgis.PyQt.QtCore import QVariant
import re

# Initialize Qt resources from file resources.py
from .resources import *

# Import the code for the DockWidget
from .SGTool_dockwidget import SGToolDockWidget
from .calcs.GeophysicalProcessor import GeophysicalProcessor
from .calcs.geosoft_grid_parser import *
from .calcs.PSplot import PowerSpectrumDock
from .calcs.ConvolutionFilter import ConvolutionFilter
from .calcs.ConvolutionFilter import OddPositiveIntegerValidator
from .calcs.GridData_no_pandas import GridData
from .calcs.GridData_no_pandas import QGISGridData

from .calcs.SG_Util import SG_Util
from .igrf.igrf_utils import igrf_utils as IGRF
import os.path
import numpy as np
import subprocess
from scipy.spatial import cKDTree
from scipy import interpolate
import tempfile

from osgeo import gdal, osr

from datetime import datetime
from pyproj import Transformer
import processing
from osgeo import gdal, osr
import platform


class SGTool:
    """QGIS Plugin Implementation."""

    def __init__(self, iface):
        """Constructor.

        Sets up translation, the plugin toolbar and default state, and makes
        sure the optional third-party libraries ``shapely`` and
        ``scikit-learn`` are importable (attempting a pip install if not).

        :param iface: An interface instance that will be passed to this class
            which provides the hook by which you can manipulate the QGIS
            application at run time.
        :type iface: QgsInterface
        """
        # Save reference to the QGIS interface
        self.iface = iface

        # initialize plugin directory
        self.plugin_dir = os.path.dirname(__file__)

        # initialize locale
        locale = str(QSettings().value("locale/userLocale"))[0:2]
        locale_path = os.path.join(
            self.plugin_dir, "i18n", "SGTool_{}.qm".format(locale)
        )

        if os.path.exists(locale_path):
            self.translator = QTranslator()
            self.translator.load(locale_path)
            QCoreApplication.installTranslator(self.translator)

        # Declare instance attributes
        self.actions = []
        self.menu = self.tr("&SGTool")
        # TODO: We are going to let the user set this up in a future iteration
        self.toolbar = self.iface.addToolBar("SGTool")
        self.toolbar.setObjectName("SGTool")

        self.pluginIsActive = False
        self.dlg = None
        self.last_directory = None

        def install_library(library_name):
            """Best-effort ``pip install`` of *library_name*; failures are printed."""
            # platform.system must be *called*; comparing the function object
            # itself to "Windows" was always False.
            python_exe = "python" if platform.system() == "Windows" else "python3"
            try:
                # "pip" is the module name on every platform ("-m pip3" fails).
                subprocess.check_call(
                    [python_exe, "-m", "pip", "install", library_name]
                )
                print(f"Successfully installed {library_name}")
            except (subprocess.CalledProcessError, OSError) as e:
                # OSError covers the interpreter executable not being found.
                print(f"Error installing {library_name}: {e}")

        # Library shapely
        try:
            import shapely  # noqa: F401
        except ImportError:
            install_library("shapely")
            import shapely  # noqa: F401

        # Library scikit-learn
        try:
            import sklearn  # noqa: F401
        except ImportError:
            install_library("scikit-learn")
            import sklearn  # noqa: F401

    # noinspection PyMethodMayBeStatic
    def tr(self, message):
        """Translate *message* in the "SGTool" context via Qt.

        The plugin class is not a QObject, so it cannot use ``QObject.tr``;
        this helper calls ``QCoreApplication.translate`` instead.

        :param message: Source text to translate.
        :type message: str, QString
        :returns: The translated text (or *message* if no translation exists).
        :rtype: QString
        """
        translation_context = "SGTool"
        # noinspection PyTypeChecker,PyArgumentList,PyCallByClass
        translated = QCoreApplication.translate(translation_context, message)
        return translated

    def add_action(
        self,
        icon_path,
        text,
        callback,
        enabled_flag=True,
        add_to_menu=True,
        add_to_toolbar=True,
        status_tip=None,
        whats_this=None,
        parent=None,
    ):
        """Create a QAction, register it with the plugin and return it.

        :param icon_path: Icon location; either a Qt resource path
            (e.g. ':/plugins/foo/bar.png') or a normal file-system path.
        :type icon_path: str

        :param text: Label shown for the action in menus.
        :type text: str

        :param callback: Slot connected to the action's ``triggered`` signal.
        :type callback: function

        :param enabled_flag: Initial enabled state. Defaults to True.
        :type enabled_flag: bool

        :param add_to_menu: Also add the action to the plugin menu
            (``self.menu``). Defaults to True.
        :type add_to_menu: bool

        :param add_to_toolbar: Also add the action to ``self.toolbar``.
            Defaults to True.
        :type add_to_toolbar: bool

        :param status_tip: Optional text passed to ``QAction.setStatusTip``
            (shown in the status bar while the pointer hovers the action).
        :type status_tip: str

        :param whats_this: Optional text passed to ``QAction.setWhatsThis``
            (the "What's This?" help popup).
        :type whats_this: str

        :param parent: Parent object for the new action. Defaults to None.
        :type parent: QWidget

        :returns: The new action; it is also appended to ``self.actions``.
        :rtype: QAction
        """
        new_action = QAction(QIcon(icon_path), text, parent)
        new_action.triggered.connect(callback)
        new_action.setEnabled(enabled_flag)

        # Apply each optional help text only when one was supplied.
        optional_texts = (
            (status_tip, new_action.setStatusTip),
            (whats_this, new_action.setWhatsThis),
        )
        for value, setter in optional_texts:
            if value is not None:
                setter(value)

        if add_to_toolbar:
            self.toolbar.addAction(new_action)
        if add_to_menu:
            self.iface.addPluginToMenu(self.menu, new_action)

        self.actions.append(new_action)
        return new_action

    def initGui(self):
        """Register the SGTool toolbar button / plugin-menu entry with QGIS."""
        self.add_action(
            ":/plugins/SGTool/icon.png",
            text=self.tr("SGTool"),
            callback=self.run,
            parent=self.iface.mainWindow(),
        )

    # --------------------------------------------------------------------------

    def onClosePlugin(self):
        """Tidy up after the plugin dock widget has been closed."""
        # Stop listening for the dock's close signal.
        self.dlg.closingPlugin.disconnect(self.onClosePlugin)

        # self.dlg is intentionally kept alive (not reset to None) so the dock
        # widget can be reused; resetting it here crashed QGIS when the docked
        # window was closed.
        self.pluginIsActive = False

    def unload(self):
        """Remove the plugin's menu entries, toolbar icons and toolbar from QGIS."""
        for action in self.actions:
            # Use the same menu title the actions were registered under.
            self.iface.removePluginMenu(self.menu, action)
            self.iface.removeToolBarIcon(action)
        self.actions = []

        # `del self.toolbar` alone only drops the Python reference; the
        # QToolBar created by iface.addToolBar() would stay in the main window.
        # Schedule the Qt object for deletion so the toolbar really disappears.
        self.toolbar.deleteLater()
        del self.toolbar

    def define_tips(self):
        """Attach hover tooltips to the widgets of the SGTool dock widget.

        Each pair in ``tips`` is (widget attribute name on ``self.dlg``,
        tooltip text); the tooltips are applied in the listed order.
        """
        tips = (
            ("mMapLayerComboBox_selectGrid", "File selected for processing"),
            ("mMapLayerComboBox_selectGrid_Conv", "File selected for processing"),
            ("pushButton_2_selectGrid", "Load new file for processing"),
            (
                "checkBox_3_DirClean",
                "Filter [DirC] a specific direction and wavelength,\nUseful for filtering flight line noise",
            ),
            (
                "lineEdit_3_azimuth",
                "Azimuth of high frequency noise to be filtered (degrees clockwise from North)",
            ),
            (
                "checkBox_4_RTE_P",
                "Reduction to pole or equator\nThe reduction to the pole (RTP) or to Equator (RTE) is a process in geophysics\nwhere magnetic data are transformed to look as though\n they were measured at the magnetic pole/equator\nCorrects the asymmetry of magnetic anomalies caused by\n the Earth's field, making them appear directly above their sources",
            ),
            (
                "pushButton_4_calcIGRF",
                "Calculate IGRF Inc & Dec based on centroid of selected grid and specified survey height and date",
            ),
            (
                "comboBox_3_rte_p_list",
                "Choose Pole(high mag latitudes)\n or Equator (low mag latitudes)",
            ),
            (
                "lineEdit_6_inc",
                "Manually define magnetic inclination [degrees from horizontal]",
            ),
            (
                "lineEdit_5_dec",
                "Manually define magnetic declination [degrees clockwise from North]",
            ),
            ("lineEdit_6_int", "Survey intensity in nT"),
            ("dateEdit", "Survey date (1900-2025)"),
            (
                "checkBox_4_PGrav",
                "Vertical Integration:\nWhen applied to RTE/P result converts magnetic anomalies into gravity-like anomalies (i.e. same decay with distance from source) for comparison or joint interpretation\nAlso good for stitched grids with very different line spacing.",
            ),
            ("checkBox_5_regional", "Remove regional (RR) based on wavelength"),
            (
                "lineEdit_9_removeReg_wavelength",
                "Wavelength to define regional [m or other length unit]",
            ),
            (
                "checkBox_6_derivative",
                "Calculate derivative (d+power+direction) parallel to x, y or z\nHighlights near-surface/short-wavelength features",
            ),
            ("comboBox_derivDirection", "Select derivative direction"),
            ("lineEdit_9_derivePower", "Power of derivative"),
            (
                "checkBox_7_tiltDerivative",
                "Tilt Derivative (TD)\nIt is often applied to magnetic or gravity data to enhance edges and detect shallow sources\nTends to overconnect structural features",
            ),
            (
                "checkBox_8_analyticSignal",
                "Analytic Signal (AS)\nIt combines horizontal and vertical derivatives to highlight anomaly edges and amplitude variations, independent of direction",
            ),
            (
                "checkBox_9_continuation",
                "Upward or downward continuation\nUpward Continuation (UC) continues data to a higher altitude, attenuating high-frequency noise and shallow features\nDownward Continuation (DC) enhances shallow or high-frequency anomalies by continuing the field to a lower altitude",
            ),
            ("comboBox_2_continuationDirection", "Select direction of continuation"),
            ("lineEdit_10_continuationHeight", "Select amount of continuation [m only]"),
            (
                "checkBox_10_bandPass",
                "Band pass filter (BP)\nIsolates specific wavelength features.",
            ),
            (
                "lineEdit_12_bandPassLow",
                "Low wavelength cutoff [m or other length unit]",
            ),
            (
                "lineEdit_11_bandPassHigh",
                "High wavelength cutoff [m or other length unit]",
            ),
            (
                "checkBox_10_freqCut",
                "High or Low pass filter\nIsolates specific short wavelength (HP) or long wavelength (LP) features.",
            ),
            ("comboBox_2_FreqCutType", "Cut off type"),
            ("lineEdit_12_FreqPass", "Cutoff wavelength [m or other length unit]"),
            (
                "checkBox_11_1vd_agc",
                "Automatic Gain Control (AGC) or Amplitude Normalisation\nHighlights short wavelength/low amplitude features",
            ),
            ("lineEdit_13_agc_window", "Window size for normalisation"),
            (
                "pushButton_3_applyProcessing",
                "Apply selected processing steps in parallel to selected grid",
            ),
            (
                "pushButton_3_applyProcessing_Conv",
                "Apply selected processing steps in parallel to selected grid",
            ),
            (
                "lineEdit_13_max_buffer",
                "Maximum buffer to be applied to grid to reduce edge effects",
            ),
            ("checkBox_11_tot_hz_grad", "Total Horizontal Gradient Calculation"),
            (
                "pushButton_rad_power_spectrum",
                "Provides pop-up display of grid plus Radial Averaged Power Spectrum (needs testing!)",
            ),
            ("checkBox_Mean", "Mean of values around central pixel"),
            ("checkBox_Median", "Median of values around central pixel"),
            ("checkBox_Gaussian", "Gaussian smoothing of image"),
            ("checkBox_Directional", "Directional enhancement"),
            ("checkBox_SunShading", "Sun Shading"),
            ("pushButton_selectPoints", "Select csv or xyz format points file"),
            ("comboBox_grid_x", "Define X coordinate column (for csv files)"),
            ("comboBox_grid_y", "Define Y coordinate column (for csv files)"),
            (
                "mQgsProjectionSelectionWidget",
                "Define Coordinate System of point data",
            ),
            (
                "checkBox_load_tie_lines",
                "For xyz format files optionally load Tie lines",
            ),
            (
                "pushButton_load_point_data",
                "Load points file and convert to layer\nWith polyline layer of lines for xyz format files",
            ),
            (
                "mMapLayerComboBox_selectGrid_3",
                "Select from currently loaded points files for gridding",
            ),
            ("comboBox_select_grid_data_field", "Select field to grid"),
            ("doubleSpinBox_cellsize", "Define cell size in layer units"),
            (
                "pushButton_2_selectGrid_RGB",
                "Select RGB image that you want to attempt to convert to a monotonic grayscale image",
            ),
            (
                "textEdit_2_colour_list",
                "Comma separated list of CSS colours, leave blank to get link to list of colours",
            ),
            (
                "groupBox_7",
                "1) Load a RGB raster image,\n2) Define a Look Up Table by defining a comma separated sequence of colours using CSS colour names and\n3) Convert to monotonically increasing greyscale image\n\nDo not use if any shading has been applied to the image!",
            ),
            (
                "mQgsDoubleSpinBox_LUT_min",
                "Define min and max values for rescaling of grid values",
            ),
            (
                "mQgsDoubleSpinBox_LUT_max",
                "Define min and max values for rescaling of grid values",
            ),
            (
                "pushButton_CSSS_Colours",
                "See full suite of CSS Colours (requires network connection)",
            ),
            ("lineEdit", "Example colour sequence, can be copy pasted"),
            ("spinBox_levels", "Number of levels"),
            ("doubleSpinBox_base", "Lowest height to worm"),
            ("doubleSpinBox_inc", "Increment in metres between levels"),
            ("groupBox_8", "Create csv file of worms using bsdwormer code"),
            (
                "mMapLayerComboBox_selectGrid_worms",
                "Grid to be wormed\n[Must be gravity or RTP_E + Vertical Integration of mag]",
            ),
            ("checkBox_NaN", "Threshold background values to NaN"),
            (
                "radioButton_NaN_Above",
                "Threshold background values above set value to NaN",
            ),
            (
                "radioButton_NaN_Below",
                "Threshold background values below set value to NaN",
            ),
            (
                "radioButton_NaN_Both",
                "Threshold background values between set values to NaN",
            ),
            ("doubleSpinBox_NaN_Above", "Upper threshold value"),
            (
                "checkBox_worms_shp",
                "Convert worms to polyline shapefile\n(Can be slow and best to start worms at >=2000m)",
            ),
            (
                "pushButton_select_normalise_in",
                "Directory with geotiffs to be normalised",
            ),
            (
                "pushButton_select_normalise_out",
                "Directory where normalised geotiffs will be stored",
            ),
            (
                "pushButton_normalise",
                "Normalise geotiffs:\n1. First order regional removed\n2. Normalise standard deviation using alphabetical first file in input directory as reference\n3. Remove mean",
            ),
            (
                "groupBox_10",
                "Normalise directory of geotiffs\nOnce normalised the QGIS merge tool produces a reasonable stitch of the grids\nAssumes same processing level for grids\nAssumes flight heights have been normalised by continuation\nMerge uses alphabetical first file to define cell size\nAll grids in merge have to be same projection",
            ),
            (
                "radioButton_normalise_1st",
                "1st order (flat plane) regional removed from grid",
            ),
            (
                "radioButton_normalise_2nd",
                "2nd order polynomial regional removed from grid (slower)",
            ),
        )

        for widget_name, tip in tips:
            getattr(self.dlg, widget_name).setToolTip(tip)

    def initParams(self):
        """Reset all processing parameters to their default values.

        Defaults are provided both under the legacy names (``mean_size``,
        ``median_size``, ``Direction``) and under the names that
        ``parseParams`` actually populates (``mean_conv_size``,
        ``median_conv_size``, ``directional_dir``, ``NaN*``), so code that
        reads those attributes before the first ``parseParams`` call does not
        hit an AttributeError.
        """
        self.localGridName = ""
        self.diskGridPath = ""
        self.diskPointsPath = ""
        self.buffer = 0
        self.DirClean = False
        self.DC_azimuth = 0
        self.DC_lineSpacing = 400
        self.RTE_P = False
        self.RTE_P_type = "Pole"
        self.RTE_P_inc = 0
        self.RTE_P_dec = 0
        self.RTE_P_height = 0
        self.RTE_P_date = [1, 1, 2000]  # [day, month, year]
        self.RemRegional = False
        self.remReg_wavelength = 5000
        self.Derivative = False
        self.derive_direction = "z"
        self.derive_power = 1.0
        self.TDR = False
        self.AS = False
        self.Continuation = False
        self.cont_direction = "up"
        self.cont_height = 500
        self.BandPass = False
        self.band_low = 5
        self.band_high = 50
        self.AGC = False
        self.agc_window = 10
        self.FreqCut = False
        self.FreqCut_type = "Low"
        self.FreqCut_cut = 1000
        self.VI = False
        self.THG = False

        self.Mean = False
        self.mean_size = 3
        self.mean_conv_size = self.mean_size  # name used by parseParams
        self.Median = False
        self.median_size = 3
        self.median_conv_size = self.median_size  # name used by parseParams
        self.Gaussian = False
        self.gauss_rad = 1
        # NOTE(review): parseParams stores the Directional checkbox state
        # (a bool) in `Direction` and the direction text in `directional_dir`;
        # the legacy "N" default is kept here for backward compatibility.
        self.Direction = "N"
        self.directional_dir = "N"
        self.SunShade = False
        self.sun_shade_az = 45
        self.sun_shade_zn = 45

        # NaN thresholding (populated from the dialog by parseParams).
        self.NaN = False
        self.NaN_Condition = "above"
        self.NaN_Above = 0.0
        self.NaN_Below = 0.0

        self.pointType = "point"
        self.input_directory = ""
        self.output_directory = ""

    def parseParams(self):
        """Copy the current dock-widget settings into instance attributes.

        Most numeric settings are kept as the raw widget text and converted
        by the individual ``proc*`` methods. The convolution parameters,
        sun-shading angles and NaN thresholds are converted here.

        Raises:
            ValueError: if a field converted here (band-pass low cut,
                mean/median window size, Gaussian sigma, sun-shading angles)
                does not contain a valid number.
        """
        self.DirClean = self.dlg.checkBox_3_DirClean.isChecked()
        self.DC_azimuth = self.dlg.lineEdit_3_azimuth.text()

        self.RTE_P = self.dlg.checkBox_4_RTE_P.isChecked()
        self.RTE_P_type = self.dlg.comboBox_3_rte_p_list.currentText()
        self.RTE_P_inc = self.dlg.lineEdit_6_inc.text()
        self.RTE_P_dec = self.dlg.lineEdit_5_dec.text()
        self.RTE_P_int = self.dlg.lineEdit_6_int.text()
        # Read the date fields directly instead of round-tripping through a
        # "YYYY-MM-DD" string. Stored as [day, month, year].
        survey_date = self.dlg.dateEdit.date().toPyDate()
        self.RTE_P_date = [survey_date.day, survey_date.month, survey_date.year]

        self.RemRegional = self.dlg.checkBox_5_regional.isChecked()
        self.remReg_wavelength = self.dlg.lineEdit_9_removeReg_wavelength.text()

        self.Derivative = self.dlg.checkBox_6_derivative.isChecked()
        self.derive_direction = self.dlg.comboBox_derivDirection.currentText()
        self.derive_power = self.dlg.lineEdit_9_derivePower.text()

        self.TDR = self.dlg.checkBox_7_tiltDerivative.isChecked()

        self.AS = self.dlg.checkBox_8_analyticSignal.isChecked()

        self.Continuation = self.dlg.checkBox_9_continuation.isChecked()
        self.cont_direction = self.dlg.comboBox_2_continuationDirection.currentText()
        self.cont_height = self.dlg.lineEdit_10_continuationHeight.text()

        self.BandPass = self.dlg.checkBox_10_bandPass.isChecked()
        self.band_low = self.dlg.lineEdit_12_bandPassLow.text()
        # A non-positive low cut is replaced by a tiny positive value.
        if float(self.band_low) <= 0.0:
            self.band_low = 1e-10
        self.band_high = self.dlg.lineEdit_11_bandPassHigh.text()

        self.AGC = self.dlg.checkBox_11_1vd_agc.isChecked()
        self.agc_window = self.dlg.lineEdit_13_agc_window.text()

        self.FreqCut = self.dlg.checkBox_10_freqCut.isChecked()
        self.FreqCut_type = self.dlg.comboBox_2_FreqCutType.currentText()
        self.FreqCut_cut = self.dlg.lineEdit_12_FreqPass.text()

        self.VI = self.dlg.checkBox_4_PGrav.isChecked()
        self.THG = self.dlg.checkBox_11_tot_hz_grad.isChecked()

        self.Mean = self.dlg.checkBox_Mean.isChecked()
        self.mean_conv_size = int(self.dlg.lineEdit_Mean_size.text())

        self.Median = self.dlg.checkBox_Median.isChecked()
        self.median_conv_size = int(self.dlg.lineEdit_Median_size.text())

        self.Gaussian = self.dlg.checkBox_Gaussian.isChecked()
        self.gauss_rad = float(self.dlg.lineEdit_Gaussian_Sigma.text())

        self.Direction = self.dlg.checkBox_Directional.isChecked()
        self.directional_dir = self.dlg.comboBox_Dir_dir.currentText()

        self.SunShade = self.dlg.checkBox_SunShading.isChecked()
        self.sun_shade_az = float(self.dlg.lineEdit_SunSh_Az.text())
        self.sun_shade_zn = float(self.dlg.lineEdit_SunSh_Zn.text())

        self.NaN = self.dlg.checkBox_NaN.isChecked()
        if self.dlg.radioButton_NaN_Above.isChecked():
            self.NaN_Condition = "above"
        elif self.dlg.radioButton_NaN_Below.isChecked():
            self.NaN_Condition = "below"
        else:
            self.NaN_Condition = "between"
        # Use the spin boxes' numeric value(): text() includes any prefix/
        # suffix and the locale's decimal separator (e.g. "1,5"), which
        # float() cannot parse.
        self.NaN_Above = float(self.dlg.doubleSpinBox_NaN_Above.value())
        self.NaN_Below = float(self.dlg.doubleSpinBox_NaN_Below.value())

    def loadGrid(self):
        """Load the raster at ``self.diskGridPath`` and cache band 1 as NumPy.

        Side effects: sets ``self.layer``, ``self.dx``, ``self.dy`` and
        ``self.raster_array`` (float64, shape ``(rows, cols)``). The layer is
        added to the project unless ``is_layer_loaded`` reports that a layer
        with the same base name is already present. Pixels equal to the
        source NoData value are replaced with NaN.
        """
        fileInfo = QFileInfo(self.diskGridPath)
        baseName = fileInfo.baseName()

        self.layer = QgsRasterLayer(self.diskGridPath, baseName)
        if not self.is_layer_loaded(baseName):
            QgsProject.instance().addMapLayer(self.layer)

        self.dx = self.layer.rasterUnitsPerPixelX()
        self.dy = self.layer.rasterUnitsPerPixelY()
        provider = self.layer.dataProvider()

        band = 1  # 1-based band index
        extent = self.layer.extent()
        rows, cols = self.layer.height(), self.layer.width()
        # Read the band once (a second, immediately discarded read was removed).
        raster_block = provider.block(band, extent, cols, rows)
        # QgsRasterBlock.value takes (row, column).
        self.raster_array = np.array(
            [[raster_block.value(i, j) for j in range(cols)] for i in range(rows)],
            dtype=float,
        ).reshape(rows, cols)

        # sourceNoDataValue() always returns a float, so the old
        # `is not None` test was always true; ask whether one is defined.
        if provider.sourceHasNoDataValue(band):
            no_data_value = provider.sourceNoDataValue(band)
            self.raster_array[self.raster_array == no_data_value] = np.nan

    def insert_text_before_extension(self, file_path, insert_text):
        """Return *file_path* with *insert_text* appended to the file stem.

        The text goes after the file name and before its (last) extension,
        e.g. ``dir/grid.tif`` + ``"_RTP"`` -> ``dir/grid_RTP.tif``.

        Parameters:
            file_path (str): Full path of the file.
            insert_text (str): Text to insert before the file extension.

        Returns:
            str: The modified file path.
        """
        folder, filename = os.path.split(file_path)
        stem, extension = os.path.splitext(filename)
        return os.path.join(folder, stem + insert_text + extension)

    def procRSTGridding(self):
        """Open the r.surf.rst gridding dialog for the selected point layer."""
        rst_gridder = QGISGridData(self.iface)

        chosen_layer = self.dlg.mMapLayerComboBox_selectGrid_3.currentText()
        source_path = self.get_layer_path_by_name(chosen_layer)
        value_field = self.dlg.comboBox_select_grid_data_field.currentText()
        cell = self.dlg.doubleSpinBox_cellsize.text()

        # No mask layer is supported from the dock yet.
        rst_gridder.launch_r_surf_rst_dialog(source_path, value_field, cell, None)

    def procIDWGridding(self):
        """Open the inverse-distance-weighting gridding dialog for the selected point layer."""
        idw_gridder = QGISGridData(self.iface)

        chosen_layer = self.dlg.mMapLayerComboBox_selectGrid_3.currentText()
        source_path = self.get_layer_path_by_name(chosen_layer)
        value_field = self.dlg.comboBox_select_grid_data_field.currentText()
        cell = self.dlg.doubleSpinBox_cellsize.text()

        # No mask layer is supported from the dock yet.
        idw_gridder.launch_idw_dialog(source_path, value_field, cell, None)

    def procbsplineGridding(self):
        """Open the B-spline gridding dialog for the selected point layer.

        The layer's file path, the chosen value field and the cell size
        (as widget text) are handed to ``QGISGridData.launch_bspline_dialog``.
        """
        gridder = QGISGridData(self.iface)

        layer_name = self.dlg.mMapLayerComboBox_selectGrid_3.currentText()
        input_path = self.get_layer_path_by_name(layer_name)
        zcolumn = self.dlg.comboBox_select_grid_data_field.currentText()
        cell_size = self.dlg.doubleSpinBox_cellsize.text()

        # The layer/provider/extent lookups that used to be here were unused,
        # and `mapLayersByName(...)[0]` could raise IndexError; removed.
        mask = None  # no mask layer is supported from the dock yet
        gridder.launch_bspline_dialog(input_path, zcolumn, cell_size, mask)

    def get_layer_path_by_name(self, layer_name):
        """
        Return the data-source path of the first project layer named *layer_name*.

        For providers whose URI carries options after a ``|`` (e.g.
        ``file.gpkg|layername=x``) only the part before the first ``|`` is
        returned.

        :param layer_name: The name of the layer in the QGIS project.
        :return: File path of the layer, or None if no matching layer with a
            data-source URI is found.
        """
        for candidate in QgsProject.instance().mapLayers().values():
            if candidate.name() != layer_name:
                continue
            if not hasattr(candidate, "dataProvider"):
                continue
            provider = candidate.dataProvider()
            if hasattr(provider, "dataSourceUri"):
                return provider.dataSourceUri().partition("|")[0]
        return None

    def procDirClean(self):
        """Apply the combined directional-cosine / high-pass cleaning filter.

        The cutoff wavelength is four times ``self.DC_lineSpacing``, and the
        filter's centre direction is ``self.DC_azimuth`` rotated by 90
        degrees. Nothing is changed if ``unit_check`` rejects the cutoff.
        """
        wavelength = float(self.DC_lineSpacing) * 4
        if not self.unit_check(wavelength):
            return

        self.new_grid = self.processor.combined_BHP_DirCos_filter(
            self.raster_array,
            cutoff_wavelength=wavelength,
            center_direction=90 + float(self.DC_azimuth),
            degree=2.0,
            buffer_size=self.buffer,
        )
        self.suffix = "_DirC"

    def procRTP_E(self):
        """Reduce ``self.raster_array`` to the magnetic pole or equator.

        Uses ``self.RTE_P_inc`` / ``self.RTE_P_dec`` (degrees, as text or
        numbers). ``self.RTE_P_type == "Pole"`` selects reduction to the pole
        (suffix ``_RTP``); any other value selects reduction to the equator
        (suffix ``_RTE``). If the inclination/declination are not numbers, or
        are both zero (i.e. not yet defined), a warning is pushed to the QGIS
        message bar and nothing is computed.
        """
        # Compare numerically: the old `== "0"` string test missed "0.0",
        # the numeric 0 defaults, and non-numeric text (which then crashed
        # in float()).
        try:
            inclination = float(self.RTE_P_inc)
            declination = float(self.RTE_P_dec)
        except (TypeError, ValueError):
            inclination = declination = None

        if inclination is None or (inclination == 0 and declination == 0):
            self.iface.messageBar().pushMessage(
                "You need to define Inc and Dec first!", level=Qgis.Warning, duration=15
            )
            return

        if self.RTE_P_type == "Pole":
            self.new_grid = self.processor.reduction_to_pole(
                self.raster_array,
                inclination=inclination,
                declination=declination,
                buffer_size=self.buffer,
            )
            self.suffix = "_RTP"
        else:
            self.new_grid = self.processor.reduction_to_equator(
                self.raster_array,
                inclination=inclination,
                declination=declination,
                buffer_size=self.buffer,
            )
            self.suffix = "_RTE"

    def procRemRegional(self):
        """Subtract a Fourier-filtered regional field from ``self.raster_array``.

        ``self.remReg_wavelength`` sets the cutoff wavelength. The
        processor's output is subtracted from the input grid. Presumably it
        returns the regional (long-wavelength) component; confirm against
        ``GeophysicalProcessor``. Nothing is changed if ``unit_check``
        rejects the wavelength.
        """
        wavelength = float(self.remReg_wavelength)
        if not self.unit_check(wavelength):
            return

        self.new_grid = self.processor.remove_regional_trend_fourier(
            self.raster_array,
            cutoff_wavelength=wavelength,
            buffer_size=self.buffer,
        )
        self.new_grid = self.raster_array - self.new_grid
        self.suffix = f"_RR_{self.remReg_wavelength}"

    def procDerivative(self):
        """Compute a directional derivative of ``self.raster_array``.

        Direction comes from ``self.derive_direction`` and the (possibly
        fractional) order from ``self.derive_power``. The suffix is
        ``_d<power><direction>``.
        """
        power = self.derive_power
        direction = self.derive_direction
        self.new_grid = self.processor.compute_derivative(
            self.raster_array,
            direction=direction,
            order=float(power),
            buffer_size=self.buffer,
        )
        self.suffix = "".join(("_d", str(power), direction))

    def procTiltDerivative(self):
        """Compute the tilt derivative of ``self.raster_array`` (suffix ``_TDR``)."""
        tilt = self.processor.tilt_derivative(self.raster_array, buffer_size=self.buffer)
        self.new_grid = tilt
        self.suffix = "_TDR"

    def procAnalyticSignal(self):
        """Compute the analytic signal of ``self.raster_array`` (suffix ``_AS``)."""
        signal = self.processor.analytic_signal(self.raster_array, buffer_size=self.buffer)
        self.new_grid = signal
        self.suffix = "_AS"

    def procContinuation(self):
        """Apply upward or downward continuation to ``self.raster_array``.

        ``self.cont_height`` is the continuation distance in metres (text or
        number). ``self.cont_direction == "up"`` selects upward continuation;
        any other value selects downward. If the selected layer's CRS is
        geographic, the height is divided by 110000 (approx. metres per
        degree), presumably so it is in the same degree units as the grid
        spacing. A notice is then pushed to the message bar. Sets
        ``self.new_grid`` and ``self.suffix``.
        """
        selected_layer = QgsProject.instance().mapLayersByName(self.localGridName)[0]
        metres_per_degree = 110000  # rough; ~111 km per degree at the equator

        if selected_layer.crs().isGeographic():
            height = float(self.cont_height) / metres_per_degree
            # The height is converted *to degrees* here (the old message
            # wrongly said "to metres"); this is a notice, not a success.
            self.iface.messageBar().pushMessage(
                "Height roughly converted to degrees, since this is a geographic projection",
                level=Qgis.Info,
                duration=15,
            )
        else:
            height = float(self.cont_height)

        if self.cont_direction == "up":
            self.new_grid = self.processor.upward_continuation(
                self.raster_array, height=height, buffer_size=self.buffer
            )
            self.suffix = "_UC" + "_" + str(self.cont_height)
        else:
            self.new_grid = self.processor.downward_continuation(
                self.raster_array, height=height, buffer_size=self.buffer
            )
            self.suffix = "_DC" + "_" + str(self.cont_height)

    def procBandPass(self):
        """Band-pass filter the grid between ``self.band_low`` and ``self.band_high``.

        Both cutoffs must pass :meth:`unit_check`; otherwise nothing changes.
        The suffix is ``_BP_<low>_<high>``.
        """
        low_cut, high_cut = float(self.band_low), float(self.band_high)
        if not (self.unit_check(low_cut) and self.unit_check(high_cut)):
            return
        self.new_grid = self.processor.band_pass_filter(
            self.raster_array,
            low_cut=low_cut,
            high_cut=high_cut,
            buffer_size=self.buffer,
        )
        self.suffix = "_".join(("_BP", str(self.band_low), str(self.band_high)))

    def procAGC(self):
        """Apply automatic gain control with window ``self.agc_window`` (suffix ``_AGC``)."""
        window = float(self.agc_window)
        self.new_grid = self.processor.automatic_gain_control(
            self.raster_array, window_size=window
        )
        self.suffix = "_AGC"

    def procTHG(self):
        """Store the processor's total horizontal gradient of the grid (suffix ``_THG``)."""
        gradient = self.processor.total_hz_grad(self.raster_array)
        self.new_grid, self.suffix = gradient, "_THG"

    def procvInt(self):
        """Vertically integrate the grid; refused for geographic CRSs.

        For a projected layer this calls the processor's vertical integration
        with ``min_wavenumber=1e-4``, no maximum wavenumber and mirror
        buffering, and sets the suffix ``_VI``. For a geographic layer it only
        shows a warning.
        """
        layer = QgsProject.instance().mapLayersByName(self.localGridName)[0]
        if layer.crs().isGeographic():
            self.iface.messageBar().pushMessage(
                "Vertical integration requires a metre-based projection system",
                level=Qgis.Warning,
                duration=15,
            )
            return
        self.new_grid = self.processor.vertical_integration(
            self.raster_array,
            max_wavenumber=None,
            min_wavenumber=1e-4,
            buffer_size=self.buffer,
            buffer_method="mirror",
        )
        self.suffix = "_VI"

    def procFreqCut(self):
        """Low- or high-pass filter the grid at ``self.FreqCut_cut``.

        ``self.FreqCut_type == "Low"`` selects the low-pass filter (suffix
        ``_LP_<cut>``). Any other value selects the high-pass filter (suffix
        ``_HP_<cut>``). Nothing happens if :meth:`unit_check` rejects the cutoff.
        """
        cutoff = float(self.FreqCut_cut)
        if not self.unit_check(cutoff):
            return
        if self.FreqCut_type == "Low":
            apply_filter, tag = self.processor.low_pass_filter, "_LP"
        else:
            apply_filter, tag = self.processor.high_pass_filter, "_HP"
        self.new_grid = apply_filter(
            self.raster_array, cutoff_wavelength=cutoff, buffer_size=self.buffer
        )
        self.suffix = tag + "_" + str(self.FreqCut_cut)

    def procMean(self):
        """Apply the convolution mean filter of size ``self.mean_conv_size`` (suffix ``_Mn``)."""
        kernel_size = self.mean_conv_size
        self.new_grid = self.convolution.mean_filter(kernel_size)
        self.suffix = "_Mn"

    def procMedian(self):
        """Apply the convolution median filter of size ``self.median_conv_size`` (suffix ``_Md``)."""
        kernel_size = self.median_conv_size
        self.new_grid = self.convolution.median_filter(kernel_size)
        self.suffix = "_Md"

    def procGaussian(self):
        """Apply the convolution Gaussian filter with radius ``self.gauss_rad`` (suffix ``_Gs``)."""
        radius = self.gauss_rad
        self.new_grid = self.convolution.gaussian_filter(radius)
        self.suffix = "_Gs"

    def procDirectional(self):
        """Apply the directional convolution filter along ``self.directional_dir`` (n=3, suffix ``_Dr``)."""
        direction = self.directional_dir
        self.new_grid = self.convolution.directional_filter(direction, n=3)
        self.suffix = "_Dr"

    def procSunShade(self):
        """Sun-shade the grid (suffix ``_Sh``).

        ``self.sun_shade_zn`` is passed as the sun altitude. The azimuth
        passed on is ``180 - self.sun_shade_az``.
        """
        azimuth = 180 - self.sun_shade_az
        self.new_grid = self.convolution.sun_shading_filter(
            self.raster_array, sun_alt=self.sun_shade_zn, sun_az=azimuth
        )
        self.suffix = "_Sh"

    def procNaN(self):
        """Replace values by NaN using ``SG_Util.Threshold2Nan`` (suffix ``_Clean``).

        The condition and the above/below thresholds come from
        ``self.NaN_Condition``, ``self.NaN_Above`` and ``self.NaN_Below``.
        """
        thresholds = {
            "condition": self.NaN_Condition,
            "above_threshold_value": self.NaN_Above,
            "below_threshold_value": self.NaN_Below,
        }
        self.new_grid = self.SG_Util.Threshold2Nan(self.raster_array, **thresholds)
        self.suffix = "_Clean"

    def procNormalise(self):
        """Run ``normalise_geotiffs`` from the input folder to the output folder.

        Uses ``self.input_directory`` and ``self.output_directory``. Whether
        ``radioButton_normalise_1st`` is checked is passed as the ``order``
        argument. Does nothing unless both folders are non-empty and exist.
        """
        processor = GeophysicalProcessor(None, None, None)
        source_dir, target_dir = self.input_directory, self.output_directory
        first_order = self.dlg.radioButton_normalise_1st.isChecked()
        both_exist = os.path.exists(source_dir) and os.path.exists(target_dir)
        if both_exist and source_dir != "" and target_dir != "":
            processor.normalise_geotiffs(source_dir, target_dir, first_order)

    def procBSDworms(self):
        """Generate worms for the selected grid and save them beside it.

        The number of levels, base level and level increment are read from the
        dialog. The layer's source file path and EPSG code are passed to
        ``GeophysicalProcessor.bsdwormer``. Layers in a geographic CRS are
        rejected with a warning.

        Returns:
            False when the layer's CRS is geographic, otherwise None.
        """
        num_levels = int(self.dlg.spinBox_levels.value())
        # Read the numeric value, not text(): the spin-box text is
        # locale-formatted (e.g. "100.00" or "100,00"), which int() rejects.
        bottom_level = int(self.dlg.doubleSpinBox_base.value())
        delta_z = float(self.dlg.doubleSpinBox_inc.value())
        selected_layer = QgsProject.instance().mapLayersByName(self.localGridName)[0]

        if selected_layer.crs().isGeographic():
            self.iface.messageBar().pushMessage(
                "This is a geographic projection, you need to convert it to a projected CRS",
                level=Qgis.Warning,
                duration=15,
            )
            return False

        if not selected_layer.isValid():
            return None

        self.diskGridPath = selected_layer.dataProvider().dataSourceUri()
        self.dx = selected_layer.rasterUnitsPerPixelX()
        self.dy = selected_layer.rasterUnitsPerPixelY()
        # Assumes an "EPSG:<code>" authid, as the rest of the plugin does.
        epsg_code = int(selected_layer.crs().authid().split(":")[1])

        self.processor = GeophysicalProcessor(self.dx, self.dy, self.buffer)
        save_shapefiles = self.dlg.checkBox_worms_shp.isChecked()
        self.processor.bsdwormer(
            self.diskGridPath, num_levels, bottom_level, delta_z, save_shapefiles, epsg_code
        )
        self.iface.messageBar().pushMessage(
            "Worms saved to same directory as original grid",
            level=Qgis.Success,
            duration=15,
        )
        return None

    def set_normalise_in(self):
        """Ask for the normalisation input folder and show it in the dialog.

        Cancelling the folder dialog leaves any previous selection untouched
        and shows no error. A chosen folder that does not exist is reported
        as "Error: Path Incorrect".
        """
        chosen = QFileDialog.getExistingDirectory(None, "Select Input Folder")
        if chosen == "":
            # Cancelled: keep the earlier choice, but make sure the attribute
            # exists so procNormalise can still read it.
            if not hasattr(self, "input_directory"):
                self.input_directory = ""
            return

        self.input_directory = chosen
        if os.path.exists(chosen):
            self.dlg.lineEdit_loadPointsPath_normalise_in.setText(chosen)
        else:
            self.iface.messageBar().pushMessage(
                "Error: Path Incorrect",
                level=Qgis.Critical,
                duration=15,
            )

    def set_normalise_out(self):
        """Ask for the normalisation output folder and show it in the dialog.

        Cancelling the folder dialog leaves any previous selection untouched
        and shows no error. A chosen folder that does not exist is reported
        as "Error: Path Incorrect".
        """
        chosen = QFileDialog.getExistingDirectory(None, "Select Output Folder")
        if chosen == "":
            # Cancelled: keep the earlier choice, but make sure the attribute
            # exists so procNormalise can still read it.
            if not hasattr(self, "output_directory"):
                self.output_directory = ""
            return

        self.output_directory = chosen
        if os.path.exists(chosen):
            self.dlg.lineEdit_loadPointsPath_normalise_out.setText(chosen)
        else:
            self.iface.messageBar().pushMessage(
                "Error: Path Incorrect",
                level=Qgis.Critical,
                duration=15,
            )

    def util_display_grid(self, grid):
        """Debug helper: show ``grid`` in a blocking matplotlib window."""
        from matplotlib import pyplot

        pyplot.imshow(grid, cmap="viridis", origin="lower")
        pyplot.colorbar(label="Levels")
        pyplot.title("Grid")
        pyplot.show()

    def unit_check(self, length):
        """Return False (with a warning) if ``length`` looks like metres on a geographic grid.

        For the layer named ``self.localGridName``, a geographic CRS combined
        with a length above 100 is taken to mean the user entered metres
        instead of degrees. Every other case returns True.
        """
        layer = QgsProject.instance().mapLayersByName(self.localGridName)[0]
        if not (layer.crs().isGeographic() and length > 100):
            return True
        self.iface.messageBar().pushMessage(
            "Since this is a geographic projection, you need to specify lengths in degrees",
            level=Qgis.Warning,
            duration=15,
        )
        return False

    def addNewGrid(self):
        """Save ``self.new_grid`` next to the source grid and add it to the project.

        Does nothing when ``self.suffix`` is empty. An existing layer named
        ``base_name + suffix`` is removed first so its file can be
        overwritten. The grid is written with ``self.layer`` as the
        georeference. On success, a single-band grey renderer is stretched to
        the band's min/max statistics.
        """
        if not self.suffix:
            return

        layer_name = self.base_name + self.suffix
        project = QgsProject.instance()
        if self.is_layer_loaded(layer_name):
            project.removeMapLayer(project.mapLayersByName(layer_name)[0].id())

        self.diskNewGridPath = self.insert_text_before_extension(
            self.diskGridPath, self.suffix
        )
        err = self.numpy_array_to_raster(
            self.new_grid,
            self.diskNewGridPath,
            dx=None,
            xmin=None,
            ymax=None,
            reference_layer=self.layer,
            no_data_value=np.nan,
        )
        if err == -1:
            return

        con_raster_layer = QgsRasterLayer(self.diskNewGridPath, layer_name)
        if not con_raster_layer.isValid():
            return

        # Add the layer exactly once.
        project.addMapLayer(con_raster_layer)

        stats = con_raster_layer.dataProvider().bandStatistics(1)
        renderer = con_raster_layer.renderer()
        if isinstance(renderer, QgsSingleBandGrayRenderer):
            contrast_enhancement = renderer.contrastEnhancement()
            if contrast_enhancement is not None:
                contrast_enhancement.setMinimumValue(stats.minimumValue)
                contrast_enhancement.setMaximumValue(stats.maximumValue)
            con_raster_layer.triggerRepaint()
        else:
            print("Renderer is not a QgsSingleBandGrayRenderer.")

    def processGeophysics_fft(self):
        """Run the processing chain on the grid chosen in the FFT tab's layer combo."""
        combo = self.dlg.mMapLayerComboBox_selectGrid
        self.localGridName = combo.currentText()
        self.processGeophysics()

    def processGeophysics_conv(self):
        """Run the processing chain on the grid chosen in the convolution tab's layer combo."""
        combo = self.dlg.mMapLayerComboBox_selectGrid_Conv
        self.localGridName = combo.currentText()
        self.processGeophysics()

    def processGeophysics(self):
        """Run every enabled filter on the layer named ``self.localGridName``.

        Band 1 of the layer is read into ``self.raster_array``, with the
        provider's source NoData value replaced by NaN. The processing
        helpers are then built with a buffer of
        ``min(rows, cols, max_buffer)``. For each enabled option, in a fixed
        order, the matching ``proc*`` method is called, followed by
        :meth:`addNewGrid`. Finally all options are cleared.

        Does nothing if no grid name is set or the layer is missing/invalid.
        """
        if not self.localGridName:
            return
        self.parseParams()

        layers = QgsProject.instance().mapLayersByName(self.localGridName)
        if not layers:
            return
        self.layer = layers[0]
        if not self.layer.isValid():
            return

        self.base_name = self.localGridName
        provider = self.layer.dataProvider()
        self.diskGridPath = provider.dataSourceUri()
        self.dx = self.layer.rasterUnitsPerPixelX()
        self.dy = self.layer.rasterUnitsPerPixelY()

        # Read band 1 over the full layer extent at native resolution.
        rows, cols = self.layer.height(), self.layer.width()
        raster_block = provider.block(1, self.layer.extent(), cols, rows)
        self.raster_array = np.zeros((rows, cols))
        for i in range(rows):
            for j in range(cols):
                self.raster_array[i, j] = raster_block.value(i, j)

        no_data_value = provider.sourceNoDataValue(1)
        if no_data_value is not None:
            self.raster_array[self.raster_array == no_data_value] = np.nan

        max_buffer = int(self.dlg.lineEdit_13_max_buffer.text())
        self.buffer = min(rows, cols, max_buffer)
        self.processor = GeophysicalProcessor(self.dx, self.dy, self.buffer)
        self.convolution = ConvolutionFilter(self.raster_array)
        self.SG_Util = SG_Util(self.raster_array)
        self.suffix = ""

        # (option flag attribute, processing step), in execution order.
        steps = (
            ("DirClean", self.procDirClean),
            ("RTE_P", self.procRTP_E),
            ("RemRegional", self.procRemRegional),
            ("Derivative", self.procDerivative),
            ("TDR", self.procTiltDerivative),
            ("AS", self.procAnalyticSignal),
            ("Continuation", self.procContinuation),
            ("BandPass", self.procBandPass),
            ("FreqCut", self.procFreqCut),
            ("AGC", self.procAGC),
            ("VI", self.procvInt),
            ("THG", self.procTHG),
            ("Mean", self.procMean),
            ("Median", self.procMedian),
            ("Gaussian", self.procGaussian),
            ("Direction", self.procDirectional),
            ("SunShade", self.procSunShade),
            ("NaN", self.procNaN),
        )
        for flag_name, step in steps:
            if getattr(self, flag_name):
                # Clear the suffix first: a step that declines to run (e.g. a
                # failed unit_check) would otherwise make addNewGrid re-save
                # the previous step's grid.
                self.suffix = ""
                step()
                self.addNewGrid()

        self.resetCheckBoxes()

    def resetCheckBoxes(self):
        """Untick every processing checkbox and clear the matching option flags.

        Runs after a processing pass so the next run starts with nothing
        selected.
        """
        checkbox_names = (
            "checkBox_4_RTE_P",
            "checkBox_7_tiltDerivative",
            "checkBox_8_analyticSignal",
            "checkBox_4_PGrav",
            "checkBox_9_continuation",
            "checkBox_3_DirClean",
            "checkBox_5_regional",
            "checkBox_10_bandPass",
            "checkBox_10_freqCut",
            "checkBox_11_1vd_agc",
            "checkBox_6_derivative",
            "checkBox_11_tot_hz_grad",
            "checkBox_Mean",
            "checkBox_Median",
            "checkBox_Gaussian",
            "checkBox_Directional",
            "checkBox_SunShading",
            "checkBox_NaN",
        )
        for name in checkbox_names:
            getattr(self.dlg, name).setChecked(False)

        # NaN is included so the flag matches its (now unticked) checkbox.
        flag_names = (
            "RTE_P",
            "TDR",
            "AS",
            "DirClean",
            "RemRegional",
            "Derivative",
            "Continuation",
            "BandPass",
            "AGC",
            "FreqCut",
            "VI",
            "THG",
            "Mean",
            "Median",
            "Gaussian",
            "Direction",
            "SunShade",
            "NaN",
        )
        for name in flag_names:
            setattr(self, name, False)

    def is_layer_loaded(self, layer_name):
        """
        Tell whether the current QGIS project holds a layer called ``layer_name``.

        Parameters:
            layer_name (str): Exact layer name to look for.

        Returns:
            bool: True as soon as one project layer has that name, else False.
        """
        project_layers = QgsProject.instance().mapLayers().values()
        return any(layer.name() == layer_name for layer in project_layers)

    def select_RGBgrid_file(self):
        """Let the user pick an RGB GeoTIFF and show its path in the dialog.

        The file dialog opens in ``self.last_directory`` (or the current
        working directory). ``self.diskRGBGridPath`` always receives the
        dialog result ("" when cancelled). An existing file is shown in
        ``lineEdit_2_loadGridPath_2`` and its folder is remembered.
        """
        start_directory = self.last_directory if self.last_directory else os.getcwd()

        # Qt name filters separate patterns with spaces; ";" only happened to
        # work with the native Windows dialog.
        self.diskRGBGridPath, _filter = QFileDialog.getOpenFileName(
            None,
            "Select RGB Image File",
            start_directory,
            "Grids (*.TIF *.tif *.TIFF *.tiff)",
        )
        if self.diskRGBGridPath and os.path.exists(self.diskRGBGridPath):
            self.dlg.lineEdit_2_loadGridPath_2.setText(self.diskRGBGridPath)
            self.last_directory = os.path.dirname(self.diskRGBGridPath)

    def processRGB(self):
        """Convert the chosen RGB image to a grey-scale grid and load it.

        The image path comes from ``lineEdit_2_loadGridPath_2`` and the CSS
        colour look-up list from ``textEdit_2_colour_list``. An empty list
        triggers an informational message. A failed conversion warns the user
        unless ``convert_RGB_to_grey`` returned -3 as its path. On success the
        grey grid is loaded as a layer unless one with that name already
        exists.
        """
        self.diskRGBGridPath = self.dlg.lineEdit_2_loadGridPath_2.text()
        if self.diskRGBGridPath == "" or not os.path.exists(self.diskRGBGridPath):
            return

        LUT_list = self.dlg.textEdit_2_colour_list.toPlainText()
        if LUT_list == "":
            self.iface.messageBar().pushMessage(
                "First define a CSS Colour list <a href='https://matplotlib.org/stable/gallery/color/named_colors.html#css-colors'> (See here for list of colours)</a>",
                level=Qgis.Info,
                duration=15,
            )
            return

        result, RGBGridPath_gray = self.convert_RGB_to_grey(
            self.diskRGBGridPath, LUT_list
        )
        if not result:
            if RGBGridPath_gray != -3:
                self.iface.messageBar().pushMessage(
                    "Conversion failed, check CSS colour names",
                    level=Qgis.Warning,
                    duration=15,
                )
            return

        layer_name = os.path.splitext(os.path.basename(RGBGridPath_gray))[0]
        self.layer = QgsRasterLayer(RGBGridPath_gray, layer_name)
        if not self.is_layer_loaded(layer_name):
            QgsProject.instance().addMapLayer(self.layer)

    def select_grid_file(self):
        """Let the user pick a grid file and load it into the project.

        Oasis Montaj ``.grd`` files are converted to GeoTIFF via
        :meth:`save_a_grid`. The CRS is read from a ``<file>.grd.xml``
        sidecar when present, falling back to EPSG:4326 with a warning.
        ``.tif``/``.tiff``/``.ers`` files are loaded directly unless a layer
        with the same name is already in the project.
        ``self.diskGridPath`` always receives the dialog result.
        """
        start_directory = self.last_directory if self.last_directory else os.getcwd()

        # Qt name filters separate patterns with spaces (";" only worked with
        # the native Windows dialog).
        self.diskGridPath, _filter = QFileDialog.getOpenFileName(
            None,
            "Select Data File",
            start_directory,
            "Grids (*.TIF *.tif *.TIFF *.tiff *.grd *.GRD *.ERS *.ers)",
        )
        suffix = self.diskGridPath.split(".")[-1].lower()
        epsg = None
        if not (self.diskGridPath and os.path.exists(self.diskGridPath)):
            return

        self.dlg.lineEdit_2_loadGridPath.setText(self.diskGridPath)
        self.dlg.pushButton_3_applyProcessing.setEnabled(True)
        self.last_directory = os.path.dirname(self.diskGridPath)

        if suffix == "grd":
            if os.path.exists(self.diskGridPath + ".xml"):
                epsg = extract_proj_str(self.diskGridPath + ".xml")
            if epsg is None:
                epsg = 4326
                self.iface.messageBar().pushMessage(
                    "No CRS found in XML, default to 4326",
                    level=Qgis.Warning,
                    duration=15,
                )
            else:
                self.iface.messageBar().pushMessage(
                    "CRS Read from XML as " + str(epsg),
                    level=Qgis.Info,
                    duration=15,
                )
            self.save_a_grid(epsg)
        elif suffix in ("tif", "tiff", "ers"):
            layer_name = os.path.splitext(os.path.basename(self.diskGridPath))[0]
            self.layer = QgsRasterLayer(self.diskGridPath, layer_name)
            # Compare against the layer *name*; the full path never matches.
            if not self.is_layer_loaded(layer_name):
                QgsProject.instance().addMapLayer(self.layer)

    # save grd file as geotiff
    def save_a_grid(self, epsg):
        """Convert the Oasis Montaj grid at ``self.diskGridPath`` to GeoTIFF and load it.

        The GeoTIFF is written next to the source with a ``.tif`` extension.
        ``self.diskGridPath`` is updated to point at it. An existing ``.tif``
        (and its ``.aux.xml`` sidecar) is deleted first unless a layer with
        that name is already loaded.

        Parameters:
            epsg (int | str): EPSG code assigned to the output raster.
        """
        if self.diskGridPath == "":
            self.iface.messageBar().pushMessage(
                "You need to select a file first", level=Qgis.Warning, duration=3
            )
            return

        if not os.path.exists(self.diskGridPath):
            self.iface.messageBar().pushMessage(
                "File: " + self.diskGridPath + " not found",
                level=Qgis.Warning,
                duration=3,
            )
            return

        grid, header, Gdata_type = load_oasis_montaj_grid_optimized(self.diskGridPath)
        if Gdata_type == -1:
            self.iface.messageBar().pushMessage(
                "Sorry, can't read 'SHORT' or 'INT' data types at the moment",
                level=Qgis.Warning,
                duration=30,
            )
            return

        directory_path = os.path.dirname(self.diskGridPath)
        filename_without_extension = os.path.splitext(
            os.path.basename(self.diskGridPath)
        )[0]
        self.diskGridPath = directory_path + "/" + filename_without_extension + ".tif"
        fn = self.diskGridPath

        if os.path.exists(fn) and not self.is_layer_loaded(filename_without_extension):
            os.remove(fn)
            # GDAL's sidecar is "<file>.aux.xml" (with the dot).
            if os.path.exists(fn + ".aux.xml"):
                os.remove(fn + ".aux.xml")

        # The output name always ends in ".tif", so GTiff is the only driver needed.
        driver = gdal.GetDriverByName("GTiff")
        if header["ordering"] == 1:
            xsize, ysize = header["shape_e"], header["shape_v"]
        else:
            xsize, ysize = header["shape_v"], header["shape_e"]
        ds = driver.Create(fn, xsize=xsize, ysize=ysize, bands=1, eType=Gdata_type)
        if ds is None:
            self.iface.messageBar().pushMessage(
                "File: " + fn + " could not be created",
                level=Qgis.Warning,
                duration=15,
            )
            return

        ds.GetRasterBand(1).WriteArray(grid)
        # Cell-centre origin shifted by half a cell. The y pixel size uses
        # spacing_v, consistent with the y-origin offset.
        geot = [
            header["x_origin"] - (header["spacing_e"] / 2),
            header["spacing_e"],
            0,
            header["y_origin"] - (header["spacing_v"] / 2),
            0,
            header["spacing_v"],
        ]
        ds.SetGeoTransform(geot)
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(int(epsg))
        ds.SetProjection(srs.ExportToWkt())
        ds.FlushCache()
        ds = None  # closes the dataset

        self.layer = QgsRasterLayer(self.diskGridPath, filename_without_extension)
        if not self.is_layer_loaded(filename_without_extension):
            QgsProject.instance().addMapLayer(self.layer)

    def select_point_file(self):
        """Let the user pick a point/line data file and prepare the import widgets.

        Files ending in ``.XYZ`` are treated as line data. Their header is read
        with ``get_XYZ_header`` and ``self.pointType`` is set to "line".
        Anything else is treated as point data. Its CSV header columns fill
        the X/Y combo boxes, replacing previous entries, and
        ``self.pointType`` is set to "point".
        """
        start_directory = self.last_directory if self.last_directory else os.getcwd()

        # Qt name filters separate patterns with spaces, not ";".
        self.diskPointsPath, _filter = QFileDialog.getOpenFileName(
            None,
            "Select Data File",
            start_directory,
            "points or lines (*.csv *.txt *.xyz *.CSV *.TXT *.XYZ)",
        )
        if not (self.diskPointsPath and os.path.exists(self.diskPointsPath)):
            return

        self.last_directory = os.path.dirname(self.diskPointsPath)
        extension = os.path.splitext(os.path.basename(self.diskPointsPath))[1]
        self.dlg.lineEdit_loadPointsPath.setText(self.diskPointsPath)

        if extension == ".XYZ":
            self.line_data_cols = self.get_XYZ_header(self.diskPointsPath)
            self.dlg.mQgsProjectionSelectionWidget.setEnabled(True)
            self.dlg.pushButton_load_point_data.setEnabled(True)
            self.pointType = "line"
            return

        columns = self.read_csv_header(self.diskPointsPath)
        self.dlg.comboBox_grid_x.setEnabled(True)
        self.dlg.comboBox_grid_y.setEnabled(True)
        self.dlg.mQgsProjectionSelectionWidget.setEnabled(True)

        # Clear first so picking another file doesn't append duplicate columns.
        self.dlg.comboBox_grid_x.clear()
        self.dlg.comboBox_grid_y.clear()
        self.dlg.comboBox_grid_x.addItems(columns)
        self.dlg.comboBox_grid_y.addItems(columns)
        self.dlg.pushButton_load_point_data.setEnabled(True)
        self.pointType = "point"

    def read_csv_header(self, file_path):
        """
        Return the column names from the first row of a CSV file.

        The row is parsed with the ``csv`` module, so quoted names that contain
        commas (e.g. ``"mag, nT"``) stay whole and lose their quotes.
        Line endings are handled by the csv reader.

        :param file_path: Path to the CSV file
        :return: List of header values (empty for an empty file)
        """
        import csv

        with open(file_path, "r", newline="") as file:
            return next(csv.reader(file), [])

    def numpy_array_to_raster(
        self,
        numpy_array,
        raster_path,
        dx=None,
        xmin=None,
        ymax=None,
        reference_layer=None,
        no_data_value=np.nan,
    ):
        """
        Write a 2-D NumPy array to a single-band Float32 GeoTIFF.

        Parameters:
            numpy_array (numpy.ndarray): 2-D array (rows, cols) to write.
            raster_path (str): Output path. An existing file and its
                ``.aux.xml`` sidecar are deleted first.
            dx, xmin, ymax: Pixel size and top-left corner. Used only when no
                reference layer is given (square, north-up pixels). The CRS
                then comes from the dialog's projection widget.
            reference_layer (QgsRasterLayer, optional): Supplies the extent
                (geotransform) and CRS.
            no_data_value: NoData value set on the band. NaN/inf in the array
                are mapped with ``np.nan_to_num(nan=no_data_value)``. ``None``
                skips both steps.

        Returns:
            0 on success, -1 if the old file could not be removed or the new
            one could not be created.
        """
        if os.path.exists(raster_path):
            try:
                os.remove(raster_path)
                # GDAL's sidecar is "<file>.aux.xml" (with the dot).
                if os.path.exists(raster_path + ".aux.xml"):
                    os.remove(raster_path + ".aux.xml")
            except OSError:
                self.iface.messageBar().pushMessage(
                    "Couldn't delete layer, may be open in another program? On windows files on non-C: drive may be hard to delete",
                    level=Qgis.Warning,
                    duration=15,
                )
                return -1

        rows, cols = numpy_array.shape
        driver = gdal.GetDriverByName("GTiff")
        output_raster = driver.Create(raster_path, cols, rows, 1, gdal.GDT_Float32)
        if output_raster is None:
            self.iface.messageBar().pushMessage(
                "Couldn't create " + raster_path,
                level=Qgis.Warning,
                duration=15,
            )
            return -1

        if reference_layer:
            extent = reference_layer.dataProvider().extent()
            geotransform = [
                extent.xMinimum(),
                extent.width() / cols,  # pixel width
                0,
                extent.yMaximum(),
                0,
                -extent.height() / rows,  # pixel height (negative: north-up)
            ]
            output_raster.SetGeoTransform(geotransform)

            srs = osr.SpatialReference()
            srs.ImportFromWkt(reference_layer.crs().toWkt())
            output_raster.SetProjection(srs.ExportToWkt())
        else:
            crs = self.dlg.mQgsProjectionSelectionWidget.crs().authid()
            srs = osr.SpatialReference()
            srs.ImportFromEPSG(int(crs.split(":")[1]))
            output_raster.SetProjection(srs.ExportToWkt())
            geotransform = [
                xmin,
                dx,  # pixel width
                0,
                ymax,
                0,
                -dx,  # pixel height (negative: north-up)
            ]
            output_raster.SetGeoTransform(geotransform)

        band = output_raster.GetRasterBand(1)
        if no_data_value is not None:
            band.SetNoDataValue(no_data_value)
            # nan_to_num(nan=None) raises, so only substitute when a value is set.
            numpy_array = np.nan_to_num(numpy_array, nan=no_data_value)
        band.WriteArray(numpy_array)
        band.FlushCache()
        output_raster = None  # closes the file
        return 0

    def day_month_to_decimal_year(self, year, month, day):
        """
        Express a calendar date as a fractional year.

        The fraction is whole days elapsed since 1 January divided by the
        number of days in that year, so leap years use 366.

        Parameters:
            year (int): Calendar year.
            month (int): Month number, 1-12.
            day (int): Day of the month.

        Returns:
            float: ``year + elapsed_days / days_in_year``.
        """
        first_day = datetime(year, 1, 1)
        elapsed_days = datetime(year, month, day).timetuple().tm_yday - 1
        year_length = (datetime(year + 1, 1, 1) - first_day).days
        return year + elapsed_days / year_length

    # estimate mag field from centroid of data, date and sensor height
    def update_mag_field(self):
        """Estimate the geomagnetic field at the grid centre and fill the RTP/RTE widgets.

        Uses the layer selected in ``mMapLayerComboBox_selectGrid`` and the
        survey date from ``dateEdit``. The centre of the layer extent is
        transformed to WGS84 longitude/latitude. It is passed to
        :meth:`calcIGRF` with the decimal survey year and a fixed altitude
        argument of 100.0. The returned inclination, declination and
        intensity are stored in ``self.RTE_P_inc/dec/int`` and written to the
        dialog. A warning is shown if the layer has no CRS authority id.
        """
        self.localGridName = self.dlg.mMapLayerComboBox_selectGrid.currentText()
        grid_on_disk = os.path.exists(self.diskGridPath)
        if not (grid_on_disk or self.localGridName):
            return
        if grid_on_disk:
            self.loadGrid()

        layers = QgsProject.instance().mapLayersByName(self.localGridName)
        if not layers:
            # No layer with that name: nothing to take the centre point from.
            return
        self.layer = layers[0]
        self.base_name = self.localGridName

        self.magn_int = self.dlg.lineEdit_6_int.text()
        survey_date = self.dlg.dateEdit.date().toPyDate()
        self.magn_SurveyDay = survey_date.day
        self.magn_SurveyMonth = survey_date.month
        self.magn_SurveyYear = survey_date.year

        extent = self.layer.extent()
        midx = extent.xMinimum() + (extent.xMaximum() - extent.xMinimum()) / 2
        midy = extent.yMinimum() + (extent.yMaximum() - extent.yMinimum()) / 2

        authid = self.layer.crs().authid()
        if not authid:
            self.iface.messageBar().pushMessage(
                "Sorry, I couldn't interpret the projection system of this layer, try either saving out grid or define the Inc/Dec manually.",
                level=Qgis.Warning,
                duration=15,
            )
            return

        from pyproj import CRS, Transformer

        # Pass the full "AUTHORITY:CODE" string so non-EPSG ids (e.g. ESRI:...)
        # are not misread as EPSG codes.
        crs_proj = CRS.from_user_input(authid)
        crs_ll = CRS.from_user_input(4326)
        proj = Transformer.from_crs(crs_proj, crs_ll, always_xy=True)
        long, lat = proj.transform(midx, midy)

        decimal_year = self.day_month_to_decimal_year(
            self.magn_SurveyYear, self.magn_SurveyMonth, self.magn_SurveyDay
        )
        I, D, intensity = self.calcIGRF(decimal_year, float(100.0), lat, long)

        self.RTE_P_inc = I
        self.RTE_P_dec = D
        self.RTE_P_int = intensity

        self.dlg.lineEdit_5_dec.setText(str(round(self.RTE_P_dec, 1)))
        self.dlg.lineEdit_6_inc.setText(str(round(self.RTE_P_inc, 1)))
        self.dlg.lineEdit_6_int.setText(str(int(self.RTE_P_int)))

    def calcIGRF(self, date, alt, lat, lon):
        """Evaluate the IGRF-14 main field at a single location and date.

        Coefficients are loaded from ``igrf/SHC_files/IGRF14.SHC`` next to
        this file. They are interpolated (linearly, with extrapolation) to
        ``date``, and the field components are synthesised and rotated from
        geocentric to geodetic X/Y/Z.

        Parameters:
            date (float): Decimal year.
            alt (float): Altitude handed to the IGRF utilities.
                NOTE(review): the IGRF utilities conventionally work in km.
                The visible caller passes 100.0; confirm the intended unit.
            lat (float): Latitude in degrees (used as ``colat = 90 - lat``).
            lon (float): Longitude in degrees.

        Returns:
            tuple: ``(inc, dec, intensity)``. ``inc`` and ``dec`` come from
            ``iut.xyz2dhif`` (presumably degrees). ``intensity`` is
            ``sqrt(X**2 + Y**2 + Z**2)``.
        """

        igrf_gen = "14"
        # itype == 1 selects the geodetic-altitude branch below.
        itype = 1
        d1 = d2 = d3 = None
        colat = 90 - lat
        iut = IGRF(d1, d2, d3)

        # Load in the file of coefficients
        # IGRF_FILE = r"./SHC_files/IGRF" + igrf_gen + ".SHC"
        IGRF_FILE = (
            os.path.dirname(os.path.realpath(__file__))
            + "/igrf/SHC_files/IGRF"
            + igrf_gen
            + ".SHC"
        )
        from pathlib import Path

        # Normalise the "/"-joined path to the platform's separator style.
        def convert_to_native_path(mixed_path):
            return str(Path(mixed_path))

        IGRF_FILE_norm = convert_to_native_path(IGRF_FILE)
        igrf = iut.load_shcfile(IGRF_FILE_norm, None)

        # Interpolate the geomagnetic coefficients to the desired date(s)
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
        f = interpolate.interp1d(igrf.time, igrf.coeffs, fill_value="extrapolate")
        coeffs = f(date)

        # Compute the main field B_r, B_theta and B_phi value for the location(s)
        # NOTE(review): in the reference IAGA igrf.py script, gg_to_geo() converts
        # (altitude, geodetic colatitude) into (geocentric radius, geocentric
        # colatitude) *before* synth_values(), which expects a radius in km.
        # Here synth_values() receives the raw `alt` and geodetic `colat`, and
        # gg_to_geo() runs only afterwards (only its sd/cd are used). Confirm that
        # this IGRF wrapper expects that; otherwise the field values may be wrong.
        Br, Bt, Bp = iut.synth_values(
            coeffs.T, alt, colat, lon, igrf.parameters["nmax"]
        )

        # For the SV, find the 5 year period in which the date lies and compute
        # the SV within that period. IGRF has constant SV between each 5 year period
        # We don't need to subtract 1900 but it makes it clearer:
        epoch = (date - 1900) // 5
        epoch_start = epoch * 5
        # Add 1900 back on plus 1 year to account for SV in nT per year (nT/yr):
        coeffs_sv = f(1900 + epoch_start + 1) - f(1900 + epoch_start)
        Brs, Bts, Bps = iut.synth_values(
            coeffs_sv.T, alt, colat, lon, igrf.parameters["nmax"]
        )

        # Use the main field coefficients from the start of each five epoch
        # to compute the SV for Dec, Inc, Hor and Total Field (F)
        # [Note: these are non-linear components of X, Y and Z so treat separately]
        # The SV terms (dX/dY/dZ, Xm/Ym/Zm) are computed but not returned.
        coeffsm = f(1900 + epoch_start)
        Brm, Btm, Bpm = iut.synth_values(
            coeffsm.T, alt, colat, lon, igrf.parameters["nmax"]
        )

        # Rearrange to X, Y, Z components
        X = -Bt
        Y = Bp
        Z = -Br
        # For the SV
        dX = -Bts
        dY = Bps
        dZ = -Brs
        Xm = -Btm
        Ym = Bpm
        Zm = -Brm
        if itype == 1:
            # alt = input("Enter altitude in km: ").rstrip()
            # alt = iut.check_float(alt)
            # sd/cd (sine/cosine terms returned by gg_to_geo) drive the rotation below.
            alt, colat, sd, cd = iut.gg_to_geo(alt, colat)

        # Rotate back to geodetic coords if needed
        if itype == 1:
            t = X
            X = X * cd + Z * sd
            Z = Z * cd - t * sd
            t = dX
            dX = dX * cd + dZ * sd
            dZ = dZ * cd - t * sd
            t = Xm
            Xm = Xm * cd + Zm * sd
            Zm = Zm * cd - t * sd

        # Total-field magnitude from the rotated main-field components.
        intensity = np.sqrt(X**2 + Y**2 + Z**2)
        # Compute the four non-linear components
        dec, hoz, inc, eff = iut.xyz2dhif(X, Y, Z)
        return inc, dec, intensity

    def extract_raster_to_numpy(self, raster_layer):
        """
        Extract band 1 of a QgsRasterLayer as a NumPy array.

        The band is read once, over the layer's full extent, at the layer's
        width x height. Cells that the returned raster block reports as
        no-data are set to NaN.

        Parameters:
            raster_layer (QgsRasterLayer): The QGIS raster layer.

        Returns:
            numpy.ndarray: float64 array of shape (rows, cols). The array is
            also stored on ``self.raster_array``.
        """
        provider = raster_layer.dataProvider()
        band = 1  # 1-based band index

        # Single block read (the previous version read the block twice and
        # discarded the first result).
        extent = raster_layer.extent()
        rows, cols = raster_layer.height(), raster_layer.width()
        raster_block = provider.block(band, extent, cols, rows)

        # Start from NaN so no-data cells need no second masking pass.
        self.raster_array = np.full((rows, cols), np.nan)
        for i in range(rows):
            for j in range(cols):
                # Ask the block itself whether the cell is no-data. The old
                # approach compared values against sourceNoDataValue() even
                # when the source declares no no-data value, because the
                # "is not None" test on a float was always true.
                if not raster_block.isNoData(i, j):
                    self.raster_array[i, j] = raster_block.value(i, j)

        return self.raster_array

    def display_rad_power_spectrum(self):
        """Show the selected grid together with its radial power spectrum.

        Nothing happens when no grid is selected or the selected layer is
        not valid. Coordinate vectors are spaced evenly across the layer
        extent, one sample per provider column/row.
        """
        self.localGridName = self.dlg.mMapLayerComboBox_selectGrid.currentText()
        if self.localGridName == "":
            return

        self.pslayer = QgsProject.instance().mapLayersByName(self.localGridName)[0]
        if not self.pslayer.isValid():
            return

        grid = self.extract_raster_to_numpy(self.pslayer)
        cell_x = self.pslayer.rasterUnitsPerPixelX()
        cell_y = self.pslayer.rasterUnitsPerPixelY()

        bounds = self.pslayer.extent()
        data_provider = self.pslayer.dataProvider()
        n_cols = data_provider.xSize()
        n_rows = data_provider.ySize()

        x_coords = np.linspace(bounds.xMinimum(), bounds.xMaximum(), n_cols)
        y_coords = np.linspace(bounds.yMinimum(), bounds.yMaximum(), n_rows)

        spectrum_dock = PowerSpectrumDock(
            grid, self.localGridName, cell_x, cell_y, x_coords, y_coords
        )
        spectrum_dock.plot_grid_and_power_spectrum()

    def update_paths(self):
        """React to a change of the main grid selector.

        Mirrors the chosen layer name into the convolution and worms
        selectors and forgets any on-disk grid path. It also labels the units
        as degrees or metres according to whether the layer CRS is geographic.
        """
        self.localGridName = self.dlg.mMapLayerComboBox_selectGrid.currentText()
        for sibling_combo in (
            self.dlg.mMapLayerComboBox_selectGrid_Conv,
            self.dlg.mMapLayerComboBox_selectGrid_worms,
        ):
            sibling_combo.setCurrentText(self.localGridName)

        self.dlg.lineEdit_2_loadGridPath.setText("")
        self.diskGridPath = ""
        self.base_name = self.localGridName

        if not self.base_name:
            return

        chosen = QgsProject.instance().mapLayersByName(self.localGridName)[0]
        if chosen.isValid():
            self.dlg.label_41_units.setText(
                "Units: deg" if chosen.crs().isGeographic() else "Units: m"
            )

    def update_paths_conv(self):
        """React to a change of the convolution-tab grid selector.

        Mirrors the chosen layer name into the main and worms selectors and
        forgets any on-disk grid path. It also labels the units as degrees or
        metres according to whether the layer CRS is geographic.
        """
        self.localGridName = self.dlg.mMapLayerComboBox_selectGrid_Conv.currentText()
        for sibling_combo in (
            self.dlg.mMapLayerComboBox_selectGrid,
            self.dlg.mMapLayerComboBox_selectGrid_worms,
        ):
            sibling_combo.setCurrentText(self.localGridName)

        self.dlg.lineEdit_2_loadGridPath.setText("")
        self.diskGridPath = ""
        self.base_name = self.localGridName

        if not self.base_name:
            return

        chosen = QgsProject.instance().mapLayersByName(self.localGridName)[0]
        if chosen.isValid():
            self.dlg.label_41_units.setText(
                "Units: deg" if chosen.crs().isGeographic() else "Units: m"
            )

    def update_paths_worms(self):
        """React to a change of the worms-tab grid selector.

        Mirrors the chosen layer name into the main and convolution selectors
        and forgets any on-disk grid path. It also labels the units as degrees
        or metres according to whether the layer CRS is geographic.
        """
        self.localGridName = self.dlg.mMapLayerComboBox_selectGrid_worms.currentText()

        for sibling_combo in (
            self.dlg.mMapLayerComboBox_selectGrid,
            self.dlg.mMapLayerComboBox_selectGrid_Conv,
        ):
            sibling_combo.setCurrentText(self.localGridName)

        self.dlg.lineEdit_2_loadGridPath.setText("")
        self.diskGridPath = ""
        self.base_name = self.localGridName

        if not self.base_name:
            return

        chosen = QgsProject.instance().mapLayersByName(self.localGridName)[0]
        if chosen.isValid():
            self.dlg.label_41_units.setText(
                "Units: deg" if chosen.crs().isGeographic() else "Units: m"
            )

    # --------------------------------------------------------------------------
    def show_version(self):
        """Return the plugin version declared in ``metadata.txt``.

        The file next to this module is read as UTF-8, because QGIS metadata
        may contain non-ASCII text. The first ``version=`` entry is used;
        whitespace around the key and value is ignored.

        Returns:
            str: The version string without its trailing newline, or ``""``
            when no ``version`` entry exists. (The previous implementation
            raised UnboundLocalError in that case.)
        """
        metadata_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "metadata.txt"
        )

        plugin_version = ""
        with open(metadata_path, encoding="utf-8") as plugin_version_file:
            for line in plugin_version_file:
                key, sep, value = line.partition("=")
                if sep and key.strip() == "version":
                    # strip() drops the newline that used to leak into the label.
                    plugin_version = value.strip()
                    break

        return plugin_version

    def run(self):
        """Run method that loads and starts the plugin.

        This runs only when the plugin is not already active. It creates the
        dock widget if needed, docks it on the right (tabified with the first
        other right-hand dock found), and fills the dialog's combo boxes. It
        then connects the dialog's buttons and selector signals to their
        handlers.
        """

        if not self.pluginIsActive:
            self.pluginIsActive = True
            self.initParams()
            # print "** STARTING SGTool"

            # dockwidget may not exist if:
            #    first run of plugin
            #    removed on close (see self.onClosePlugin method)
            # NOTE(review): "is None" is the idiomatic identity test here.
            if self.dlg == None:
                # Create the dockwidget (after translation) and keep reference
                self.dlg = SGToolDockWidget()

            # connect to provide cleanup on closing of dockwidget
            self.dlg.closingPlugin.connect(self.onClosePlugin)

            # show the dockwidget
            # TODO: fix to allow choice of dock location
            self.iface.addDockWidget(Qt.RightDockWidgetArea, self.dlg)

            # Find existing dock widgets in the right area
            right_docks = [
                d
                for d in self.iface.mainWindow().findChildren(QDockWidget)
                if self.iface.mainWindow().dockWidgetArea(d) == Qt.RightDockWidgetArea
            ]
            # If there are other dock widgets, tab this one with the first one found
            if right_docks:
                for dock in right_docks:
                    if dock != self.dlg:
                        self.iface.mainWindow().tabifyDockWidget(dock, self.dlg)
                        # Optionally, bring your plugin tab to the front
                        self.dlg.raise_()
                        break
            # Raise the docked widget above others
            self.dlg.show()
            self.define_tips()

            # Access the QgsMapLayerComboBox by its objectName
            # Raster selectors for the FFT, convolution and worms tabs; point
            # selector for the gridding tab.
            self.dlg.mMapLayerComboBox_selectGrid.setFilters(
                QgsMapLayerProxyModel.RasterLayer
            )
            self.dlg.mMapLayerComboBox_selectGrid_Conv.setFilters(
                QgsMapLayerProxyModel.RasterLayer
            )
            self.dlg.mMapLayerComboBox_selectGrid_worms.setFilters(
                QgsMapLayerProxyModel.RasterLayer
            )
            self.dlg.mMapLayerComboBox_selectGrid_3.setFilters(
                QgsMapLayerProxyModel.PointLayer
            )

            self.dlg.version_label.setText(self.show_version())

            # NOTE(review): of the combo boxes filled below, only
            # comboBox_3_rte_p_list is cleared first. If self.dlg is reused
            # across activations (it is only recreated when None), the others
            # would accumulate duplicate entries. Confirm what onClosePlugin
            # does with self.dlg.
            self.deriv_dir_list = []
            self.deriv_dir_list.append("z")
            self.deriv_dir_list.append("x")
            self.deriv_dir_list.append("y")
            self.dlg.comboBox_derivDirection.addItems(self.deriv_dir_list)

            self.contin_dir_list = []
            self.contin_dir_list.append("up")
            self.contin_dir_list.append("down")
            self.dlg.comboBox_2_continuationDirection.addItems(self.contin_dir_list)

            self.dlg.comboBox_3_rte_p_list.clear()
            self.ret_p_list = []
            self.ret_p_list.append("Pole")
            self.ret_p_list.append("Eqtr")
            self.dlg.comboBox_3_rte_p_list.addItems(self.ret_p_list)

            self.freq_cut_type_list = []
            self.freq_cut_type_list.append("Low")
            self.freq_cut_type_list.append("High")
            self.dlg.comboBox_2_FreqCutType.addItems(self.freq_cut_type_list)

            # Button -> handler wiring.
            # NOTE(review): these connect() calls run on every activation; with
            # a reused dialog, handlers could fire more than once per click.
            self.dlg.pushButton_4_calcIGRF.clicked.connect(self.update_mag_field)
            self.dlg.pushButton_2_selectGrid.clicked.connect(self.select_grid_file)
            self.dlg.pushButton_2_selectGrid_RGB.clicked.connect(
                self.select_RGBgrid_file
            )
            self.dlg.pushButton_3_applyProcessing.clicked.connect(
                self.processGeophysics_fft
            )
            self.dlg.pushButton_3_applyProcessing_Conv.clicked.connect(
                self.processGeophysics_conv
            )
            self.dlg.pushButton_3_applyProcessing_Conv_2.clicked.connect(
                self.processRGB
            )

            self.dlg.pushButton_selectPoints.clicked.connect(self.select_point_file)

            # Keep the three raster selectors in sync with each other.
            self.dlg.mMapLayerComboBox_selectGrid.layerChanged.connect(
                self.update_paths
            )
            self.dlg.mMapLayerComboBox_selectGrid_Conv.layerChanged.connect(
                self.update_paths_conv
            )

            self.dlg.mMapLayerComboBox_selectGrid_worms.layerChanged.connect(
                self.update_paths_worms
            )

            self.dlg.mQgsProjectionSelectionWidget.setCrs(
                QgsCoordinateReferenceSystem("EPSG:4326")
            )
            self.dlg.pushButton_rad_power_spectrum.clicked.connect(
                self.display_rad_power_spectrum
            )
            self.localGridName = self.dlg.mMapLayerComboBox_selectGrid.currentText()

            # Initialise the units label for the grid already selected, if any.
            # NOTE(review): [0] assumes a layer with this name exists.
            if self.localGridName:
                selected_layer = QgsProject.instance().mapLayersByName(
                    self.localGridName
                )[0]
                crs = selected_layer.crs()
                if crs.isGeographic():
                    self.dlg.label_41_units.setText("Units: deg")
                else:
                    self.dlg.label_41_units.setText("Units: m")

            self.dlg.lineEdit_Mean_size.setValidator(OddPositiveIntegerValidator())
            self.dlg.lineEdit_Median_size.setValidator(OddPositiveIntegerValidator())

            self.directional_list = []
            self.directional_list.append("N")
            self.directional_list.append("NE")
            self.directional_list.append("E")
            self.directional_list.append("SE")
            self.directional_list.append("S")
            self.directional_list.append("SW")
            self.directional_list.append("W")
            self.directional_list.append("NW")
            self.dlg.comboBox_Dir_dir.addItems(self.directional_list)

            self.dlg.pushButton_load_point_data.clicked.connect(
                self.import_point_line_data
            )

            # Gridding tab: refresh field list and nx/ny labels.
            self.dlg.mMapLayerComboBox_selectGrid_3.layerChanged.connect(
                self.updateLayertoGrid
            )

            self.dlg.doubleSpinBox_cellsize.valueChanged.connect(
                self.updateLayertoGrid2
            )
            self.dlg.comboBox_select_grid_data_field.currentTextChanged.connect(
                self.updateLayertoGrid2
            )
            # Connect to layer removal signal
            # NOTE(review): this project-level signal is connected on every
            # activation; refreshComboBox currently returns immediately.
            QgsProject.instance().layerRemoved.connect(self.refreshComboBox)

            if self.dlg.mMapLayerComboBox_selectGrid_3.currentText() != "":
                self.updateLayertoGrid()

            self.cell_size = self.dlg.doubleSpinBox_cellsize.value()

            self.gridDirectory = None

            # Connection to the Github site  :
            self.dlg.pushButton_repo.clicked.connect(
                lambda: QDesktopServices.openUrl(
                    QUrl("https://github.com/swaxi/SGTool")
                )
            )
            # Connection to the Help pdf file :
            self.dlg.pushButton_help.clicked.connect(
                lambda: QDesktopServices.openUrl(
                    QUrl(
                        "https://raw.githubusercontent.com/swaxi/SGTool/refs/heads/main/Structural%20Geophysics%20Tools.pdf"
                    )
                )
            )
            # Connection to the CSS Colours site  :
            self.dlg.pushButton_CSSS_Colours.clicked.connect(
                lambda: QDesktopServices.openUrl(
                    QUrl(
                        "https://matplotlib.org/stable/gallery/color/named_colors.html#css-colors"
                    )
                )
            )

            # self.dlg.pushButton_rst.clicked.connect(self.procRSTGridding)
            self.dlg.pushButton_idw_2.clicked.connect(self.procIDWGridding)
            self.dlg.pushButton_bspline_3.clicked.connect(self.procbsplineGridding)

            self.dlg.pushButton_worms.clicked.connect(self.procBSDworms)

            self.dlg.pushButton_normalise.clicked.connect(self.procNormalise)
            self.dlg.pushButton_select_normalise_in.clicked.connect(
                self.set_normalise_in
            )
            self.dlg.pushButton_select_normalise_out.clicked.connect(
                self.set_normalise_out
            )

    # select directory to store grid
    def gridFile(self):
        """Ask the user for an output grid file and show it in the dialog.

        ``QFileDialog.getSaveFileName`` returns ``(file_name, selected_filter)``.
        The second item is the filter string, not the file extension, so the
        extension is taken from the chosen file name instead. ``.tif`` is
        appended only when the name does not already end in ``.tif`` or
        ``.tiff`` (case-insensitive).

        If the dialog is cancelled, ``self.gridFilePath`` and the line edit
        are left unchanged.
        """
        file_name, _selected_filter = QFileDialog.getSaveFileName(
            None, "Save grid file as"
        )
        if not file_name:
            # Cancelled: previously this stored the bogus path ".tif".
            return

        extension = os.path.splitext(file_name)[1]
        if extension.lower() not in (".tif", ".tiff"):
            file_name = file_name + ".tif"

        self.gridFilePath = file_name
        self.dlg.lineEdit_gridOutputDir.setText(self.gridFilePath)

    def get_layer_fields(self, layer):
        """
        List the attribute field names of a vector layer, in field order.

        Parameters:
            layer (QgsVectorLayer): Layer whose fields are listed.

        Returns:
            list: The field names as strings.

        Raises:
            ValueError: If ``layer.isValid()`` is False.
        """
        if layer.isValid():
            return list(map(lambda fld: fld.name(), layer.fields()))
        raise ValueError("Invalid layer provided.")

    # Function to refresh the combo box
    def refreshComboBox(self):
        """Slot connected to ``QgsProject.layerRemoved``; intentionally a no-op.

        The earlier body sat after an unconditional ``return`` and could never
        run. It has been removed. Had it been enabled, it would have cleared
        the point-layer selector and re-added every project layer, ignoring
        the PointLayer filter applied to that selector.

        NOTE(review): this relies on QgsMapLayerComboBox tracking layer
        removal on its own. Confirm this before adding logic here.
        """
        return

    def updateLayertoGrid(self):
        """Refresh the gridding panel for the selected point layer.

        Repopulates the data-field combo box with the layer's field names.
        When the layer has features and the cell size is positive, it also
        shows the grid dimensions ``nx``/``ny`` (extent / cell size,
        truncated to int) in the dialog.

        Returns early when the selector is empty, when no project layer has
        the selected name, or when that layer is not valid.
        """
        if self.dlg.mMapLayerComboBox_selectGrid_3.count() == 0:
            return

        self.selectedPoints = self.dlg.mMapLayerComboBox_selectGrid_3.currentText()
        layers = QgsProject.instance().mapLayersByName(self.selectedPoints)
        if not layers or not layers[0].isValid():
            # Guard: [0] on an empty lookup result would raise IndexError.
            return
        selected_layer = layers[0]

        field_names = self.get_layer_fields(selected_layer)
        self.dlg.comboBox_select_grid_data_field.clear()
        self.dlg.comboBox_select_grid_data_field.addItems(field_names)

        if selected_layer.featureCount() > 0:
            extent = selected_layer.extent()

            self.cell_size = self.dlg.doubleSpinBox_cellsize.value()
            if self.cell_size <= 0:
                # A zero cell size would raise ZeroDivisionError; a negative
                # one would give negative grid dimensions.
                return

            self.nx_label = int(
                (extent.xMaximum() - extent.xMinimum()) / self.cell_size
            )
            self.ny_label = int(
                (extent.yMaximum() - extent.yMinimum()) / self.cell_size
            )
            self.dlg.nx_label.setText(str(self.nx_label))
            self.dlg.ny_label.setText(str(self.ny_label))

    def updateLayertoGrid2(self):
        """Recompute the grid-size labels after the cell size or data field changes.

        Unlike ``updateLayertoGrid``, this leaves the field list untouched.
        ``self.cell_size`` is always refreshed from the spin box. ``nx``/``ny``
        are updated only when the layer has features and the cell size is
        positive.

        Returns early when the selector is empty, when no project layer has
        the selected name, or when that layer is not valid.
        """
        if self.dlg.mMapLayerComboBox_selectGrid_3.count() == 0:
            return

        self.selectedPoints = self.dlg.mMapLayerComboBox_selectGrid_3.currentText()
        layers = QgsProject.instance().mapLayersByName(self.selectedPoints)
        if not layers or not layers[0].isValid():
            # Guard: [0] on an empty lookup result would raise IndexError.
            return
        selected_layer = layers[0]

        extent = selected_layer.extent()
        self.cell_size = self.dlg.doubleSpinBox_cellsize.value()

        # A zero cell size (e.g. the spin box set to 0) would raise
        # ZeroDivisionError; a negative one would give negative dimensions.
        if selected_layer.featureCount() > 0 and self.cell_size > 0:
            self.nx_label = int(
                (extent.xMaximum() - extent.xMinimum()) / self.cell_size
            )
            self.ny_label = int(
                (extent.yMaximum() - extent.yMinimum()) / self.cell_size
            )
            self.dlg.nx_label.setText(str(self.nx_label))
            self.dlg.ny_label.setText(str(self.ny_label))

    def import_point_line_data(self):
        """Import the file at ``self.diskPointsPath`` into the project.

        When ``self.pointType`` is ``"line"``, the file goes through
        ``import_XYZ``. That call receives the text after the first ``:`` of
        the projection widget's authority id and the tie-line checkbox state.
        Otherwise the file is loaded as a delimited-text point layer via
        ``import_CSV``, using the selected X/Y columns and the full authority
        id. The layer is named after the file without its extension.
        """
        layer_name = os.path.splitext(os.path.basename(self.diskPointsPath))[0]
        target_authid = self.dlg.mQgsProjectionSelectionWidget.crs().authid()

        if self.pointType != "line":
            self.import_CSV(
                self.diskPointsPath,
                self.dlg.comboBox_grid_x.currentText(),
                self.dlg.comboBox_grid_y.currentText(),
                layer_name=layer_name,
                crs=target_authid,
            )
            return

        code_part = target_authid.split(":")[1]
        with_ties = self.dlg.checkBox_load_tie_lines.isChecked()
        self.import_XYZ(self.diskPointsPath, code_part, layer_name, load_ties=with_ties)

    def import_CSV(
        self, file_path, x_field, y_field, layer_name="points", crs="EPSG:4326"
    ):
        """
        Load a comma-delimited text file as a point layer and add it to the project.

        All columns are kept as attributes, and field types are auto-detected
        by the delimited-text provider.

        Parameters:
            file_path (str): Path to the CSV file.
            x_field (str): Column holding the X coordinate.
            y_field (str): Column holding the Y coordinate.
            layer_name (str): Name given to the new layer.
            crs (str): Authority id of the layer CRS (defaults to 'EPSG:4326').

        Returns:
            QgsVectorLayer: The layer that was added to the project.

        Raises:
            ValueError: If the provider reports the layer as invalid.
        """
        query_items = [
            ("type", "csv"),
            ("xField", x_field),
            ("yField", y_field),
            ("crs", crs),
            ("detectTypes", "yes"),
            ("delimiter", ","),
            ("quote", ""),
        ]
        query = "&".join(f"{key}={value}" for key, value in query_items)
        uri = f"file:///{file_path}?{query}"

        csv_layer = QgsVectorLayer(uri, layer_name, "delimitedtext")
        if csv_layer.isValid():
            QgsProject.instance().addMapLayer(csv_layer)
            return csv_layer
        raise ValueError(f"Failed to load layer: {file_path}")

    def get_XYZ_header(self, csv_file):
        """Return how many data columns follow X and Y in an XYZ line file.

        The file is scanned line by line, and anything before the first
        ``LINE:`` marker is ignored. After the marker, the first
        whitespace-separated row that parses entirely as floats and has at
        least three values (x, y and one or more data values) decides the
        result.

        Parameters:
            csv_file (str): Path to the XYZ text file.

        Returns:
            int | None: ``len(row) - 2`` for the first qualifying row, or
            ``None`` if the file has no such row.
        """
        seen_line_marker = False
        with open(csv_file, "r") as file:
            for line in file:
                line = line.strip()
                if line.startswith("LINE:"):
                    # Only the presence of a marker matters for counting
                    # columns. The line number itself is not parsed, so a
                    # marker without digits no longer raises AttributeError.
                    seen_line_marker = True
                    continue
                if not seen_line_marker:
                    continue
                try:
                    parts = [float(token) for token in line.split()]
                except ValueError:
                    continue  # non-numeric text inside a line block
                if len(parts) >= 3:  # x, y and at least one data value
                    return len(parts) - 2
        return None

    def import_XYZ(self, XYZ_file, crs, layer_name="line", load_ties=True):
        """Load an XYZ line-data file into a line layer and a point layer.

        The file is read line by line:
          * ``LINE:`` markers start a group whose id is the first integer
            found in the marker line;
          * ``TIE:`` markers do the same when ``load_ties`` is True. When it
            is False, the rows after them are skipped until the next marker;
          * every other row after a marker that parses as at least two floats
            is taken as ``x y [value ...]``.

        Two memory layers are added to the project:
          * ``layer_name``: one LineString per group id, with LINE_ID and the
            data values of that group's first row;
          * ``f"{layer_name}_points"``: one Point per row, with LINE_ID and
            its data values, unchecked (hidden) in the layer tree.

        The attribute schema (``data_0`` .. ``data_k``) comes from the first
        data row. If the file contains no data rows, a warning is shown and
        no layers are created.

        Parameters:
            XYZ_file (str): Path to the XYZ text file.
            crs (str): Code appended to ``"EPSG:"`` for both layers.
            layer_name (str): Name of the line layer.
            load_ties (bool): Whether ``TIE:`` groups are imported.
        """
        # Each stored row is [x, y, value..., line_id].
        data_list = []
        current_line_number = None

        # Read the file line-by-line
        with open(XYZ_file, "r") as file:
            for line in file:
                line = line.strip()
                if line.startswith("LINE:"):  # Check for 'LINE:' markers
                    current_line_number = int(re.search(r"\d+", line).group())
                elif line.startswith("TIE:"):  # Check for 'TIE:' markers
                    # NOTE(review): LINE and TIE ids share one namespace, so
                    # a tie numbered like a line is merged into that line.
                    if load_ties:
                        current_line_number = int(re.search(r"\d+", line).group())
                    else:
                        current_line_number = None
                elif current_line_number is not None:
                    try:
                        parts = list(map(float, line.split()))
                        if len(parts) >= 2:  # Ensure at least x and y are present
                            data_list.append(parts + [current_line_number])
                    except ValueError:
                        pass  # header/comment text inside a block

        if not data_list:
            # Nothing to build; without this guard data_list[0] would raise
            # IndexError.
            self.iface.messageBar().pushMessage(
                "No XYZ data rows were found in the selected file.",
                level=Qgis.Warning,
                duration=15,
            )
            return

        # Number of data values per row: total minus x, y and line_id.
        n_values = len(data_list[0]) - 3

        def build_fields():
            # LINE_ID plus one double field per data value; shared by both layers.
            fields = QgsFields()
            fields.append(QgsField("LINE_ID", QVariant.Int))
            for i in range(n_values):
                fields.append(QgsField(f"data_{i}", QVariant.Double))
            return fields

        # Process and create the line layer
        line_layer = QgsVectorLayer("LineString?crs=EPSG:" + crs, layer_name, "memory")
        line_provider = line_layer.dataProvider()
        line_provider.addAttributes(build_fields())
        line_layer.updateFields()

        # Group rows by line id, keeping file order within each group.
        line_features = {}
        for x, y, *values, line_id in data_list:
            line_features.setdefault(line_id, []).append((x, y, values))

        for line_id, points in line_features.items():
            coords = [QgsPointXY(x, y) for x, y, _ in points]
            feature = QgsFeature()
            feature.setGeometry(QgsGeometry.fromPolylineXY(coords))
            # Line features carry the attribute values of their first row.
            feature.setAttributes([line_id] + points[0][2])
            line_provider.addFeature(feature)

        QgsProject.instance().addMapLayer(line_layer)

        # Create the point layer
        point_layer = QgsVectorLayer(
            "Point?crs=EPSG:" + crs, f"{layer_name}_points", "memory"
        )
        point_provider = point_layer.dataProvider()
        point_provider.addAttributes(build_fields())
        point_layer.updateFields()

        for x, y, *values, line_id in data_list:
            feature = QgsFeature()
            feature.setGeometry(QgsGeometry.fromPointXY(QgsPointXY(x, y)))
            feature.setAttributes([line_id] + values)
            point_provider.addFeature(feature)

        QgsProject.instance().addMapLayer(point_layer)
        layer_tree = QgsProject.instance().layerTreeRoot()
        layer_tree.findLayer(point_layer.id()).setItemVisibilityChecked(False)

    def convert_RGB_to_grey(self, RGBGridPath, LUT):
        """
        Convert a colour-mapped RGB raster back into a single-band scalar grid.

        Each pixel's colour is matched to a colour look-up table built from
        ``LUT``, using nearest neighbour (cKDTree) in RGB space normalised
        by 255. The matching LUT scalar is then rescaled linearly with the
        dialog's LUT min/max spin boxes. Pure white and pure black pixels
        become NaN. The result is written as a Float32 GeoTIFF at
        ``self.insert_text_before_extension(RGBGridPath, "_gray")``, with the
        input's geotransform and projection.

        Parameters:
            RGBGridPath (str): Raster with at least 3 bands (R, G, B).
                NOTE(review): values are divided by 255, so 8-bit bands are
                assumed.
            LUT (str): Comma-separated colour specifications. Spaces are
                removed and the order is reversed before they are passed to
                ``generate_rgb_lut``.

        Returns:
            tuple: ``(True, output_path)`` on success. ``(False, -3)`` when the
            input cannot be opened, has fewer than 3 bands, or the output
            cannot be created. ``(False, False)`` when the LUT cannot be built.
        """
        # Open the 3-band TIF using GDAL
        dataset = gdal.Open(RGBGridPath, gdal.GA_ReadOnly)
        if not dataset:
            self.iface.messageBar().pushMessage(
                "Unable to open the dataset.", level=Qgis.Warning, duration=15
            )
            return False, -3

        if dataset.RasterCount < 3:
            self.iface.messageBar().pushMessage(
                "Data file must have at least 3 layers", level=Qgis.Warning, duration=15
            )
            return False, -3

        red = dataset.GetRasterBand(1).ReadAsArray().astype(float)
        green = dataset.GetRasterBand(2).ReadAsArray().astype(float)
        blue = dataset.GetRasterBand(3).ReadAsArray().astype(float)

        transform = dataset.GetGeoTransform()
        projection = dataset.GetProjection()

        # Stack bands into an RGB array of shape (rows, cols, 3)
        rgb_raster = np.dstack((red, green, blue))

        # Parse the LUT
        LUT = LUT.replace(" ", "")
        css_color_list = LUT.split(",")
        css_color_list.reverse()

        lut = self.generate_rgb_lut(css_color_list, num_entries=1024)
        if not lut:
            # Report through the message bar like every other failure here,
            # instead of printing to the Python console.
            self.iface.messageBar().pushMessage(
                "Couldn't generate LUT", level=Qgis.Warning, duration=15
            )
            return False, False

        scalar_values, lut_colors = zip(*lut)
        lut_colors = np.array(lut_colors) / 255.0  # Normalize LUT colors

        # Identify white and black pixels (treated as background below)
        white_mask = (rgb_raster == [255, 255, 255]).all(axis=2)
        black_mask = (rgb_raster == [0, 0, 0]).all(axis=2)

        # Normalize raster RGB values to [0, 1]
        normalized_rgb = rgb_raster / 255.0

        # Flatten RGB raster for KDTree query
        reshaped_rgb = normalized_rgb.reshape(-1, 3)

        # Build a KDTree for nearest neighbor lookup
        lut_tree = cKDTree(lut_colors)
        _distances, indices = lut_tree.query(reshaped_rgb)

        # Map nearest LUT color to scalar values
        scalar_grid = np.array(scalar_values)[indices].reshape(rgb_raster.shape[:2])

        # Set white and black areas to NaN
        scalar_grid[white_mask] = np.nan
        scalar_grid[black_mask] = np.nan

        # Scale data into the user-selected [LUT_min, LUT_max] range
        LUT_min = self.dlg.mQgsDoubleSpinBox_LUT_min.value()
        LUT_max = self.dlg.mQgsDoubleSpinBox_LUT_max.value()
        scalar_grid = (scalar_grid * (LUT_max - LUT_min)) + LUT_min

        # Prepare output file
        driver = gdal.GetDriverByName("GTiff")
        RGBGridPath_gray = self.insert_text_before_extension(RGBGridPath, "_gray")
        output_dataset = driver.Create(
            RGBGridPath_gray,
            dataset.RasterXSize,
            dataset.RasterYSize,
            1,
            gdal.GDT_Float32,
        )

        if not output_dataset:
            self.iface.messageBar().pushMessage(
                "Unable to create the output dataset.", level=Qgis.Warning, duration=15
            )
            return False, -3

        output_dataset.SetGeoTransform(transform)
        output_dataset.SetProjection(projection)

        # Write the scaled scalar grid to the output file
        output_band = output_dataset.GetRasterBand(1)
        output_band.WriteArray(scalar_grid)
        output_band.SetNoDataValue(np.nan)

        # Dropping the references lets GDAL flush and close the files.
        output_band = None
        output_dataset = None
        dataset = None

        return True, RGBGridPath_gray

    def convert_RGB_to_grey_rasterio(self, RGBGridPath, LUT):
        """
        Convert a colour (RGB) raster into a single-band scalar raster via a colour LUT.

        Every pixel colour is matched to the nearest colour of a LUT built from
        ``LUT``. The pixel then takes that entry's position in [0, 1], rescaled
        to the range set by the dialog's LUT min/max spin boxes. Pure white
        (255, 255, 255) and pure black (0, 0, 0) pixels are written as NaN,
        which is also declared as the output nodata value.

        Parameters:
            RGBGridPath (str): Path to a raster with at least 3 bands. Bands 1-3
                are read as red, green and blue.
            LUT (str): Comma-separated colour names (spaces ignored), listed
                from the highest to the lowest scalar value.

        Returns:
            tuple: ``(True, output_path)`` on success. ``output_path`` is
            ``RGBGridPath`` with ``_gray`` inserted before the extension.
            ``(False, -3)`` if the input has fewer than 3 bands.
            ``(False, False)`` if no LUT could be built from ``LUT``.
        """
        with rasterio.open(RGBGridPath) as src:
            if src.count < 3:
                self.iface.messageBar().pushMessage(
                    "Data file must have at least 3 layers",
                    level=Qgis.Warning,
                    duration=15,
                )
                return False, -3
            # Read only the three colour bands, in a single call.
            red, green, blue = src.read([1, 2, 3])
            profile = src.profile
            transform = src.transform
            # NOTE(review): a source without a CRS is written without one (crs may be None).
            crs = src.crs

        # (rows, cols, 3) float array of raw 0-255 colour values.
        rgb_raster = np.dstack((red, green, blue)).astype(float)

        # Colours are listed high -> low; reverse so LUT position 0 is the lowest value.
        css_color_list = LUT.replace(" ", "").split(",")
        css_color_list.reverse()
        lut = self.generate_rgb_lut(css_color_list, num_entries=1024)
        if not lut:
            return False, False

        scalar_values, lut_colors = zip(*lut)
        lut_colors = np.array(lut_colors) / 255.0  # same [0, 1] scale as the pixels

        # Pure white and pure black pixels are treated as background (nodata).
        white_mask = (rgb_raster == [255, 255, 255]).all(axis=2)
        black_mask = (rgb_raster == [0, 0, 0]).all(axis=2)

        # Nearest LUT colour for every pixel, in normalised RGB space.
        _, indices = cKDTree(lut_colors).query((rgb_raster / 255.0).reshape(-1, 3))
        scalar_grid = np.asarray(scalar_values)[indices].reshape(rgb_raster.shape[:2])
        scalar_grid[white_mask | black_mask] = np.nan

        # Map LUT positions [0, 1] onto the user-selected data range.
        LUT_min = self.dlg.mQgsDoubleSpinBox_LUT_min.value()
        LUT_max = self.dlg.mQgsDoubleSpinBox_LUT_max.value()
        scalar_grid = scalar_grid * (LUT_max - LUT_min) + LUT_min

        # Save the floating-point raster with the source georeferencing.
        profile.update(
            count=1, dtype="float32", transform=transform, crs=crs, nodata=np.nan
        )
        RGBGridPath_gray = self.insert_text_before_extension(RGBGridPath, "_gray")
        with rasterio.open(RGBGridPath_gray, "w", **profile) as dst:
            # The grid is float64; cast explicitly to the file's float32 dtype.
            dst.write(scalar_grid.astype(np.float32), 1)

        return True, RGBGridPath_gray

    def generate_rgb_lut(self, css_color_list, num_entries=1024):
        """
        Generate an RGB lookup table from a list of colour names.

        The colours are spread evenly along a Matplotlib
        ``LinearSegmentedColormap``. The colormap is then sampled at
        ``num_entries`` evenly spaced positions in [0, 1].

        Parameters:
            css_color_list (list): Colour specifications accepted by Matplotlib
                (e.g. CSS colour names). The first maps to position 0 and the
                last to position 1.
            num_entries (int): Number of LUT entries to generate.

        Returns:
            list | bool: ``[[position, (R, G, B)], ...]``, where ``position`` is
            rounded to 6 decimals and R, G, B are 0-255 ints (truncated, not
            rounded). Returns ``False`` if Matplotlib cannot build a colormap
            from ``css_color_list`` (e.g. an unknown colour name).
        """
        import matplotlib.colors as mcolors

        positions = np.linspace(0, 1, num_entries)

        try:
            cmap = mcolors.LinearSegmentedColormap.from_list(
                "custom_cmap", css_color_list
            )
        except Exception:
            # Best effort: callers treat a falsy LUT as "could not build LUT".
            return False

        # Sample all positions in one call and drop the alpha channel.
        rgb = cmap(positions)[:, :3]
        # astype(int) truncates toward zero, like int() on these non-negative values.
        rgb_255 = (rgb * 255).astype(int).tolist()

        return [
            [round(position, 6), tuple(color)]
            for position, color in zip(positions, rgb_255)
        ]

    def vector_layer_to_dataframe(self, layer, attribute_name):
        """
        Extract point coordinates and one attribute from a vector layer, with its SRID.

        Despite the name, no pandas DataFrame is built; the data is a NumPy array.

        Parameters:
        - layer: The QGIS vector layer object (not its name) to read features from.
        - attribute_name (str): The attribute whose values are extracted.

        Returns:
        - tuple: ``(data, epsg_code)``.
          ``data`` is the array returned by
          ``self.extract_features_to_array(layer, attribute_name)``.
          ``epsg_code`` is ``layer.crs().postgisSrid()``.

        Raises:
        - ValueError: If ``attribute_name`` is not a field of ``layer``.
        """
        field_names = {field.name() for field in layer.fields()}
        if attribute_name not in field_names:
            raise ValueError(f"Attribute '{attribute_name}' not found in layer")

        data = self.extract_features_to_array(layer, attribute_name)

        # PostGIS SRID of the layer CRS; for EPSG-defined CRSs this is normally the EPSG code.
        epsg_code = layer.crs().postgisSrid()

        return data, epsg_code

    def extract_features_to_array(self, layer, attribute_name):
        """
        Extract x, y, value rows from the features of a QGIS point layer.

        Multipart geometries yield one row per part (via ``asMultiPoint()``).
        Single-part geometries yield one row (via ``asPoint()``). Features whose
        geometry is empty/null are skipped.

        :param layer: QGIS vector layer to extract features from. Point or
            multipoint geometries are expected by the calls above.
        :param attribute_name: Name of the attribute used as the value column.
        :return: NumPy array of shape (n, 3) whose rows are [x, y, value].
            The shape is (0, 3) when no point is found.
        """
        rows = []

        for feature in layer.getFeatures():
            geom = feature.geometry()
            if not geom:
                continue
            value = feature[attribute_name]
            points = geom.asMultiPoint() if geom.isMultipart() else [geom.asPoint()]
            rows.extend([point.x(), point.y(), value] for point in points)

        # reshape keeps the (n, 3) layout even when rows is empty (np.array([]) is 1-D).
        return np.array(rows).reshape(-1, 3)

    def gridPointData(self):
        """
        Interpolate the selected point layer's attribute onto a regular grid.

        Inputs are read from the dialog:
        - the output path from ``lineEdit_gridOutputDir``;
        - the point layer from ``mMapLayerComboBox_selectGrid_3``;
        - the value field from ``comboBox_select_grid_data_field``;
        - the cell size from ``doubleSpinBox_cellsize``.

        The points are interpolated with ``GridData``, and the grid is handed to
        ``addGridded`` to be written and loaded.

        Side effects: sets ``self.gridFilePath``, ``self.selectedPoints``,
        ``self.nx_label`` and ``self.ny_label``. Pushes a warning to the QGIS
        message bar when the inputs cannot be used.
        """

        def warn(message):
            self.iface.messageBar().pushMessage(
                message, level=Qgis.Warning, duration=15
            )

        self.gridFilePath = self.dlg.lineEdit_gridOutputDir.text()
        if not self.gridFilePath:
            warn("Please specify an output grid file")
            return
        if os.path.exists(self.gridFilePath):
            warn("Cannot overwrite existing file")
            return

        combo = self.dlg.mMapLayerComboBox_selectGrid_3
        if combo.count() == 0:
            return
        self.selectedPoints = combo.currentText()
        matches = QgsProject.instance().mapLayersByName(self.selectedPoints)
        if not matches or not matches[0].isValid():
            return
        selected_layer = matches[0]

        # Accept exactly 3 points, as the message promises (the check used to be "> 3").
        if selected_layer.featureCount() < 3:
            warn("Data file must have at least 3 points")
            return

        cell_size = self.dlg.doubleSpinBox_cellsize.value()
        if cell_size <= 0:
            warn("Cell size must be greater than zero")
            return

        attribute_name = self.dlg.comboBox_select_grid_data_field.currentText()
        data, epsg = self.vector_layer_to_dataframe(selected_layer, attribute_name)

        # The grid size must be computed for this layer BEFORE GridData receives it.
        extent = selected_layer.extent()
        self.nx_label = int((extent.xMaximum() - extent.xMinimum()) / cell_size)
        self.ny_label = int((extent.yMaximum() - extent.yMinimum()) / cell_size)

        gridder = GridData(data, self.nx_label, self.ny_label, grid_bounds=None)
        # NOTE(review): the method selection (Clough-Tocher / IDW / v.surf.rst) was
        # disabled, leaving the grid undefined. Clough-Tocher, the first of the
        # disabled options, is used here. Confirm this is the intended default.
        new_grid = gridder.interpolate(method="clough_tocher")

        file_name = os.path.splitext(os.path.basename(self.gridFilePath))[0]
        self.addGridded(
            new_grid,
            file_name,
            self.gridFilePath,
            epsg,
            extent,
            cell_size,
        )

    # save new gridded data as geotiff
    def addGridded(
        self, grid, filename_without_extension, filepath, epsg, extent, cell_size
    ):
        """
        Write a gridded array to a single-band Float32 GeoTIFF and load it into QGIS.

        Parameters:
            grid: 2-D array. ``grid.shape`` is read as (rows, columns).
            filename_without_extension (str): Layer name for the project.
                - If a layer with this name is already loaded, "1" is appended.
                - Any loaded layer that has the resulting name is replaced.
            filepath (str): Output GeoTIFF path; also stored in ``self.diskGridPath``.
            epsg (int | str): EPSG code of the output projection.
            extent: Layer extent; ``xMinimum()``/``yMinimum()`` position the raster.
            cell_size (float): Pixel size in map units on both axes.

        On success the new layer is stored in ``self.layer``. On any failure a
        warning with the error is pushed to the QGIS message bar instead.
        """
        if self.is_layer_loaded(
            filename_without_extension
        ) and QgsProject.instance().mapLayersByName(filename_without_extension):
            # Keep the existing layer; load the new grid under a different name.
            filename_without_extension = filename_without_extension + "1"

        self.diskGridPath = filepath

        try:
            driver = gdal.GetDriverByName("GTiff")
            ds = driver.Create(
                filepath,
                xsize=grid.shape[1],
                ysize=grid.shape[0],
                bands=1,
                eType=gdal.GDT_Float32,
            )
            if ds is None:
                raise RuntimeError(f"GDAL could not create '{filepath}'")
            try:
                ds.GetRasterBand(1).WriteArray(grid)
                # The origin is half a cell below/left of the extent's lower-left
                # corner, and the pixel height is positive. So the centre of
                # row 0 / column 0 sits at (xMinimum, yMinimum), and rows run northwards.
                # NOTE(review): assumes GridData returns rows ordered from ymin
                # upward; north-up GeoTIFFs usually use a top-left origin with a
                # negative pixel height.
                geot = [
                    extent.xMinimum() - (cell_size / 2),
                    cell_size,
                    0,
                    extent.yMinimum() - (cell_size / 2),
                    0,
                    cell_size,
                ]
                ds.SetGeoTransform(geot)
                srs = osr.SpatialReference()
                srs.ImportFromEPSG(int(epsg))
                ds.SetProjection(srs.ExportToWkt())
                ds.FlushCache()
            finally:
                ds = None  # release (close) the dataset even if writing failed

            self.layer = QgsRasterLayer(self.diskGridPath, filename_without_extension)
            project = QgsProject.instance()
            if self.is_layer_loaded(filename_without_extension):
                old_layer = project.mapLayersByName(filename_without_extension)[0]
                project.removeMapLayer(old_layer.id())
            project.addMapLayer(self.layer)

            # Stretch the grey renderer to the band's actual value range.
            stats = self.layer.dataProvider().bandStatistics(1)
            renderer = self.layer.renderer()
            if isinstance(renderer, QgsSingleBandGrayRenderer):
                contrast_enhancement = renderer.contrastEnhancement()
                contrast_enhancement.setMinimumValue(stats.minimumValue)
                contrast_enhancement.setMaximumValue(stats.maximumValue)
                self.layer.triggerRepaint()
            else:
                print("Renderer is not a QgsSingleBandGrayRenderer.")
        except Exception as e:
            # UI boundary: report the actual failure instead of crashing the plugin.
            self.iface.messageBar().pushMessage(
                f"Unable to create gridded raster '{filepath}': {e}",
                level=Qgis.Warning,
                duration=15,
            )

    def create_temp_raster_mask_from_convex_hull(
        self, vector_layer, extent, cell_size=10
    ):
        """
        Create a temporary raster mask covering the convex hull of a point layer.

        Runs ``qgis:minimumboundinggeometry`` with TYPE 3 (convex hull). The
        resulting polygon is then burned into a GeoTIFF with ``gdal:rasterize``:
        the value is 1 inside the hull, and 0 is declared as nodata.

        :param vector_layer: The input vector layer (points).
        :param extent: Output raster extent, passed as the rasterize ``EXTENT`` parameter.
        :param cell_size: Output cell width and height in map units (default: 10).
        :return: Path to the temporary raster mask, or None if either step fails.
        """
        feedback = QgsProcessingFeedback()
        try:
            # Step 1: Generate convex hull
            convex_hull_output = processing.run(
                "qgis:minimumboundinggeometry",
                {"INPUT": vector_layer, "TYPE": 3, "OUTPUT": "TEMPORARY_OUTPUT"},
                feedback=feedback,
            )
            convex_hull_layer = convex_hull_output["OUTPUT"]

            # Step 2: Rasterize the convex hull into a temporary file.
            # mkstemp returns an open descriptor; close it at once so no handle
            # stays open on the path GDAL is about to write.
            fd, temp_raster_path = tempfile.mkstemp(suffix=".tif")
            os.close(fd)
            rasterize_output = processing.run(
                "gdal:rasterize",
                {
                    "INPUT": convex_hull_layer,
                    "FIELD": None,  # Use the entire polygon
                    "BURN": 1,  # Value to burn into the raster
                    "UNITS": 1,  # Cell size in map units
                    "WIDTH": cell_size,
                    "HEIGHT": cell_size,
                    "EXTENT": extent,
                    "NODATA": 0,  # Value for no data cells
                    "OUTPUT": temp_raster_path,
                },
                feedback=feedback,
            )

            print(f"Temporary raster mask created: {temp_raster_path}")
            return rasterize_output["OUTPUT"]

        except Exception as e:
            # Best effort: callers receive None when the mask cannot be built.
            print(f"Error creating raster mask: {e}")
            return None
        # TODO: replace with specific processor calls so raster clipping can be done easily...
