# -*- coding: utf-8 -*-
"""
/***************************************************************************
 MoniQueDialog
                                 A QGIS plugin
 Monoplotting oblique images.
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                             -------------------
        begin                : 2024-02-07
        git sha              : $Format:%H$
        copyright            : (C) 2024 by Sebastian Mikolka-Flöry
        email                : s.floery@gmx.at
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

import os
import open3d as o3d
import numpy as np

from qgis.PyQt import QtWidgets, QtCore, QtGui
from qgis.gui import QgsProjectionSelectionWidget
from qgis.core import (
    Qgis,
    QgsProject,
    QgsVectorLayer,
    QgsFeature,
    QgsField,
    QgsGeometry,
    QgsPointXY,
    QgsVectorFileWriter,
)

import operator
from ..lsq import srs_lm
from ..helpers import calc_hfov, calc_vfov

class OrientDialog(QtWidgets.QDialog):
    """Dialog for estimating the orientation of the active image from GCPs.

    The left part holds a table of ground control points (GCPs) with their
    object coordinates (X, Y, Z), image coordinates (x, y) and the residuals
    (dx, dy) of the last adjustment. The right part shows the camera
    parameters and their standard deviations. The adjustment itself is done
    by ``srs_lm``; this dialog collects the usable GCPs, displays the results
    and talks to the plugin through the Qt signals below.
    """

    # Emitted with {"gid": <gid>} when a GCP row gets selected.
    gcp_selected_signal = QtCore.pyqtSignal(object)
    # Emitted when the selected GCP row is clicked again (deselected).
    gcp_deselected_signal = QtCore.pyqtSignal()
    # Emitted with {"gid": <gid>} right before the selected row is removed.
    gcp_delete_signal = QtCore.pyqtSignal(object)
    # Emitted when the user asks for initial parameters from the camera view.
    get_camera_signal = QtCore.pyqtSignal()
    # Emitted with the estimated parameters, their std. deviations and the
    # per-GCP residuals after a successful adjustment.
    camera_estimated_signal = QtCore.pyqtSignal(object)
    # Emitted when the user clicks "Save".
    save_orientation_signal = QtCore.pyqtSignal()

    # Minimum number of usable, checked GCPs required to enable "Calculate".
    MIN_ACTIVE_GCPS = 4

    # Columns that must hold a value before a GCP can enter the adjustment.
    _REQUIRED_COLUMNS = ("X", "Y", "Z", "x", "y")

    # (data key, table column) pairs written when GCPs are loaded from a layer.
    _LYR_COLUMNS = (
        ("obj_x", "X"),
        ("obj_y", "Y"),
        ("obj_z", "Z"),
        ("img_x", "x"),
        ("img_y", "y"),
        ("img_dx", "dx"),
        ("img_dy", "dy"),
    )

    def __init__(self, parent=None, icon_dir=None, active_iid=None):
        """Build the dialog and switch the parent into GCP-picking mode.

        :param parent: plugin widget owning this dialog; it must provide
            ``img_list``, ``btn_ori_tool``, ``activate_gcp_picking()``,
            ``deactivate_gcp_picking()`` and ``discard_changes()``.
        :param icon_dir: directory containing the toolbar icons.
        :param active_iid: id of the image being oriented (window title).
        """
        # NOTE(review): ``parent`` is not forwarded to QDialog, so the dialog
        # is a top-level window; kept as in the original implementation.
        super(OrientDialog, self).__init__()

        # Table column of each GCP attribute; column 0 is the "use" checkbox.
        self.name2ix = {"gid": 1,
                        "X": 2,
                        "Y": 3,
                        "Z": 4,
                        "x": 5,
                        "y": 6,
                        "dx": 7,
                        "dy": 8}

        # Row selected by the last click (-1 = none) and the gid of that row.
        self.prev_row = -1
        self.sel_gid = None
        # Start values for the adjustment; None until set_init_params().
        self.init_params = None

        # NOTE(review): this attribute shadows QObject.parent(); kept because
        # closeEvent() and possibly external code rely on it.
        self.parent = parent
        self.parent.img_list.setEnabled(False)
        self.parent.activate_gcp_picking()

        self.icon_dir = icon_dir

        self.setWindowTitle("%s - Camera parameter estimation" % (active_iid))
        self.resize(800, 400)
        self.setMinimumSize(QtCore.QSize(800, 400))
        self.setMaximumSize(QtCore.QSize(800, 400))

        params_toolbar = QtWidgets.QToolBar("")
        params_toolbar.setIconSize(QtCore.QSize(20, 20))

        table_toolbar = QtWidgets.QToolBar("")
        table_toolbar.setIconSize(QtCore.QSize(20, 20))

        self.btn_init_ori = QtWidgets.QAction("Set initial orientation from camera view.", self)
        self.btn_init_ori.setIcon(QtGui.QIcon(os.path.join(self.icon_dir, "mActionMeasureBearing.png")))
        self.btn_init_ori.triggered.connect(self.get_camera_signal.emit)
        params_toolbar.addAction(self.btn_init_ori)

        self.btn_delete_gcp = QtWidgets.QAction("Delete selected GCP.", self)
        self.btn_delete_gcp.setIcon(QtGui.QIcon(os.path.join(self.icon_dir, "mActionDeleteSelectedFeatures.png")))
        self.btn_delete_gcp.triggered.connect(self.delete_selected_gcp)
        # Only enabled while a GCP row is selected.
        self.btn_delete_gcp.setEnabled(False)
        table_toolbar.addAction(self.btn_delete_gcp)

        self.table_gcps = QtWidgets.QTableWidget()
        self.table_gcps.setSizeAdjustPolicy(QtWidgets.QAbstractScrollArea.AdjustToContents)
        self.table_gcps.setEditTriggers(QtWidgets.QAbstractItemView.NoEditTriggers)
        self.table_gcps.setAlternatingRowColors(True)
        self.table_gcps.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection)
        self.table_gcps.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows)
        self.table_gcps.setObjectName("table_gcps")
        self.table_gcps.setColumnCount(9)
        self.table_gcps.setRowCount(0)

        self.table_gcps.cellClicked.connect(self.gcp_selected)
        self.table_gcps.itemChanged.connect(self.gcp_status_changed)

        header = self.table_gcps.horizontalHeader()
        header.setHighlightSections(True)
        header.resizeSection(0, 10)
        header.resizeSection(1, 30)
        # The gid column takes up the remaining width.
        header.setSectionResizeMode(1, QtWidgets.QHeaderView.Stretch)
        for cix, width in enumerate((80, 80, 80, 70, 70, 60, 60), start=2):
            header.resizeSection(cix, width)

        self.table_gcps.verticalHeader().setVisible(False)
        self.table_gcps.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
        self.table_gcps.setHorizontalHeaderLabels(["use", "gid", "X", "Y", "Z", "x", "y", "dx", "dy"])

        params_layout = QtWidgets.QVBoxLayout()
        params_layout.setSpacing(3)

        def create_cam_param_layout(param=None, label_size=25, line_size=100, unit=None):
            """Create one parameter row: label, value field, std. field, unit.

            :param param: label text (may contain markup such as ``<sub>``).
            :param label_size: fixed width of the parameter label in px.
            :param line_size: minimum width of the value field in px.
            :param unit: text of the unit label.
            :returns: (row layout, read-only value line edit, std. line edit)
            """
            layout = QtWidgets.QHBoxLayout()

            param_label = QtWidgets.QLabel(param)
            param_label.setFixedWidth(label_size)
            param_label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)

            param_line = QtWidgets.QLineEdit()
            param_line.setMinimumWidth(line_size)
            param_line.setReadOnly(True)
            param_line.setToolTip("Estimated parameter")

            param_std_line = QtWidgets.QLineEdit()
            param_std_line.setFixedWidth(50)
            param_std_line.setToolTip("Standard deviation")

            unit_label = QtWidgets.QLabel(unit)
            unit_label.setFixedWidth(25)

            layout.addWidget(param_label)
            layout.addWidget(param_line)
            layout.addWidget(param_std_line)
            layout.addWidget(unit_label)

            return layout, param_line, param_std_line

        # Projection center.
        obj_x0_layout, self.obj_x0_line, self.obj_x0_std_line = create_cam_param_layout(
            param="X<sub>0</sub>: ", unit=" [m]")
        obj_y0_layout, self.obj_y0_line, self.obj_y0_std_line = create_cam_param_layout(
            param="Y<sub>0</sub>: ", unit=" [m]")
        obj_z0_layout, self.obj_z0_line, self.obj_z0_std_line = create_cam_param_layout(
            param="Z<sub>0</sub>: ", unit=" [m]")

        # Rotation angles (shown in degrees).
        alpha_layout, self.alpha_line, self.alpha_std_line = create_cam_param_layout(
            param="\u03B1: ", unit=" [°]")
        zeta_layout, self.zeta_line, self.zeta_std_line = create_cam_param_layout(
            param="\u03B6: ", unit=" [°]")
        kappa_layout, self.kappa_line, self.kappa_std_line = create_cam_param_layout(
            param="\u03BA: ", unit=" [°]")

        # Interior orientation.
        focal_layout, self.focal_line, self.focal_std_line = create_cam_param_layout(
            param="f: ", unit=" [px]")
        img_x0_layout, self.img_x0_line, self.img_x0_std_line = create_cam_param_layout(
            param="x<sub>0</sub>: ", unit=" [px]")
        img_y0_layout, self.img_y0_line, self.img_y0_std_line = create_cam_param_layout(
            param="y<sub>0</sub>: ", unit=" [px]")

        params_layout.addWidget(params_toolbar)
        for row_layout in (obj_x0_layout, obj_y0_layout, obj_z0_layout,
                           alpha_layout, zeta_layout, kappa_layout,
                           focal_layout, img_x0_layout, img_y0_layout):
            params_layout.addLayout(row_layout)
        params_layout.addStretch()

        self.btn_calc_ori = QtWidgets.QPushButton("Calculate")
        self.btn_calc_ori.setEnabled(False)
        self.btn_calc_ori.clicked.connect(self.calc_orientation)

        self.btn_save_ori = QtWidgets.QPushButton("Save")
        self.btn_save_ori.setEnabled(False)
        self.btn_save_ori.clicked.connect(self.save_orientation)

        btn_layout = QtWidgets.QHBoxLayout()
        btn_layout.setContentsMargins(5, 0, 0, 0)
        btn_layout.addWidget(self.btn_calc_ori)
        btn_layout.addWidget(self.btn_save_ori)

        params_layout.addLayout(btn_layout)

        self.table_layout = QtWidgets.QVBoxLayout()
        self.table_layout.addWidget(table_toolbar)
        self.table_layout.addWidget(self.table_gcps)

        main_layout = QtWidgets.QHBoxLayout()
        main_layout.setSpacing(0)
        main_layout.setContentsMargins(5, 0, 5, 0)
        main_layout.addLayout(self.table_layout)
        main_layout.addLayout(params_layout)

        layout = QtWidgets.QVBoxLayout()
        layout.setSpacing(0)
        layout.setContentsMargins(0, 0, 0, 5)
        layout.addLayout(main_layout)
        self.setLayout(layout)

        self.error_dialog = QtWidgets.QErrorMessage(parent=self)

    # ------------------------------------------------------------------
    # Table helpers
    # ------------------------------------------------------------------

    def _cell_text(self, rix, name):
        """Return the text of column ``name`` in row ``rix`` ("" if the cell has no item)."""
        item = self.table_gcps.item(rix, self.name2ix[name])
        return "" if item is None else item.text()

    def _set_cell(self, rix, name, value, fmt="%.1f"):
        """Write ``value`` formatted with ``fmt`` into column ``name`` of row ``rix``."""
        self.table_gcps.setItem(rix, self.name2ix[name], QtWidgets.QTableWidgetItem(fmt % value))

    def _write_coords(self, rix, data, gcp_type):
        """Write the object- ("obj_space") or image-space ("img_space") coordinates of ``data`` into row ``rix``.

        Any other ``gcp_type`` writes nothing.
        """
        if gcp_type == "obj_space":
            self._set_cell(rix, "X", data["obj_x"])
            self._set_cell(rix, "Y", data["obj_y"])
            self._set_cell(rix, "Z", data["obj_z"])
        elif gcp_type == "img_space":
            self._set_cell(rix, "x", data["img_x"])
            self._set_cell(rix, "y", data["img_y"])

    def _find_row(self, gid):
        """Return the row index holding GCP ``gid`` (a string), or None."""
        for rix in range(self.table_gcps.rowCount()):
            if self._cell_text(rix, "gid") == gid:
                return rix
        return None

    def _row_is_complete(self, rix):
        """True if row ``rix`` has all object and image coordinates filled in."""
        return all(self._cell_text(rix, name) for name in self._REQUIRED_COLUMNS)

    def _is_active_row(self, rix):
        """True if row ``rix`` is checked for use AND has complete coordinates."""
        chk_item = self.table_gcps.item(rix, 0)
        return (chk_item is not None
                and chk_item.checkState() == QtCore.Qt.Checked
                and self._row_is_complete(rix))

    def _count_active_gcps(self):
        """Return the number of rows that would enter the adjustment."""
        return sum(1 for rix in range(self.table_gcps.rowCount()) if self._is_active_row(rix))

    def _update_calc_button(self):
        """Enable "Calculate" only when at least MIN_ACTIVE_GCPS GCPs are usable."""
        self.btn_calc_ori.setEnabled(self._count_active_gcps() >= self.MIN_ACTIVE_GCPS)

    @staticmethod
    def _has_value(value):
        """True unless ``value`` is None or compares equal to None.

        Unlike a truthiness test this keeps legitimate zeros (e.g. Z = 0.0).
        The equality check mirrors ``operator.countOf`` used for the
        missing-attribute count, which also matches by equality.
        """
        return not (value is None or value == None)  # noqa: E711

    # ------------------------------------------------------------------
    # Slots / public API
    # ------------------------------------------------------------------

    def gcp_selected(self, rix, cix):
        """Toggle the selection of GCP row ``rix`` after a click in column ``cix``.

        Clicks in the checkbox column (0) are ignored here; check-state
        changes arrive through ``gcp_status_changed``.
        """
        if cix == 0:
            return

        if rix == self.prev_row:
            # Row was already selected: a second click deselects it.
            self.btn_delete_gcp.setEnabled(False)
            self.table_gcps.clearSelection()
            self.prev_row = -1
            self.gcp_deselected_signal.emit()
        else:
            self.btn_delete_gcp.setEnabled(True)
            self.table_gcps.setCurrentCell(rix, cix)
            self.prev_row = rix

            self.sel_gid = self._cell_text(rix, "gid")
            self.gcp_selected_signal.emit({"gid": self.sel_gid})

    def gcp_status_changed(self, item):
        """Re-evaluate the "Calculate" button when a "use" checkbox changes."""
        if item.column() == 0:
            self._update_calc_button()

    def delete_selected_gcp(self):
        """Emit ``gcp_delete_signal`` for the selected GCP and remove its row."""
        if self.prev_row < 0:
            # Nothing selected (the action should be disabled in that case).
            return

        self.gcp_delete_signal.emit({"gid": self.sel_gid})

        self.table_gcps.removeRow(self.prev_row)
        self.table_gcps.clearSelection()
        self.prev_row = -1
        self.sel_gid = None
        self.btn_delete_gcp.setEnabled(False)

        # Deleting a row may change the number of usable GCPs.
        self._update_calc_button()

    def add_gcp_to_table(self, data, gcp_type=None):
        """Add or update the GCP ``data["gid"]`` with newly picked coordinates.

        :param data: dict with "gid" and either "obj_x"/"obj_y"/"obj_z"
            (``gcp_type == "obj_space"``) or "img_x"/"img_y"
            (``gcp_type == "img_space"``).
        :param gcp_type: "obj_space" or "img_space".

        A new GCP gets a disabled, unchecked "use" checkbox. Once an existing
        row has both object and image coordinates it is checked and enabled.
        """
        # Rows store the gid as text, so compare as text.
        gid = str(data["gid"])
        rix = self._find_row(gid)

        if rix is None:
            rix = self.table_gcps.rowCount()
            self.table_gcps.insertRow(rix)
            self.table_gcps.setRowHeight(rix, 25)

            chk_item = QtWidgets.QTableWidgetItem()
            chk_item.setFlags(QtCore.Qt.ItemIsUserCheckable)
            chk_item.setCheckState(QtCore.Qt.Unchecked)

            self.table_gcps.setItem(rix, 0, chk_item)
            self.table_gcps.setItem(rix, self.name2ix["gid"], QtWidgets.QTableWidgetItem(gid))
            self._write_coords(rix, data, gcp_type)
            return

        self._write_coords(rix, data, gcp_type)

        # Only a GCP with both object and image coordinates can be used;
        # enabling an incomplete row would make calc_orientation fail.
        if self._row_is_complete(rix):
            chk_item = self.table_gcps.item(rix, 0)
            chk_item.setCheckState(QtCore.Qt.Checked)
            chk_item.setFlags(QtCore.Qt.ItemIsUserCheckable | QtCore.Qt.ItemIsEnabled)
            # The check state may already have been Checked (no itemChanged).
            self._update_calc_button()

    def add_gcps_from_lyr(self, gcps):
        """Append one table row per GCP loaded from the GCP layer.

        :param gcps: mapping gid -> attribute dict. The keys read are "active"
            ("1" means checked) and the coordinate/residual keys listed in
            ``_LYR_COLUMNS``; None marks a missing value.
        """
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.active_gcps = 0

        # Append after existing rows (identical to the old behaviour when the
        # table is empty).
        first_row = self.table_gcps.rowCount()

        for offset, (gid, data) in enumerate(gcps.items()):
            rix = first_row + offset

            # Rows with more than two missing attributes get a disabled checkbox.
            nr_none_attr = operator.countOf(data.values(), None)

            self.table_gcps.insertRow(rix)
            self.table_gcps.setRowHeight(rix, 25)

            chk_item = QtWidgets.QTableWidgetItem()
            flags = chk_item.flags() | QtCore.Qt.ItemIsUserCheckable
            if nr_none_attr > 2:
                flags &= ~QtCore.Qt.ItemIsEnabled
            chk_item.setFlags(flags)
            chk_item.setCheckState(QtCore.Qt.Checked if data["active"] == "1" else QtCore.Qt.Unchecked)

            self.table_gcps.setItem(rix, 0, chk_item)
            self.table_gcps.setItem(rix, self.name2ix["gid"], QtWidgets.QTableWidgetItem(str(gid)))

            for key, column in self._LYR_COLUMNS:
                value = data[key]
                if self._has_value(value):
                    self._set_cell(rix, column, value)

        # Loaded rows may already be checked; no itemChanged fires for them.
        self._update_calc_button()

    def update_selected_gcp(self, data, gcp_type=None):
        """Overwrite the coordinates of the current row with ``data``.

        See ``add_gcp_to_table`` for the expected keys per ``gcp_type``.
        """
        rix = self.table_gcps.currentRow()
        if rix < 0:
            return
        self._write_coords(rix, data, gcp_type)

    def set_init_params(self, data):
        """Display camera parameters and keep a copy as adjustment start values.

        Angles ("alpha", "zeta", "kappa") are stored in radians and shown in
        degrees. Standard deviations ("<name>_std") are shown when present.
        Std. fields without a value in ``data`` are cleared, so values from an
        earlier adjustment do not remain next to new parameters.
        """
        self.obj_x0_line.setText("%.1f" % data["obj_x0"])
        self.obj_y0_line.setText("%.1f" % data["obj_y0"])
        self.obj_z0_line.setText("%.1f" % data["obj_z0"])

        self.alpha_line.setText("%.3f" % np.rad2deg(data["alpha"]))
        self.zeta_line.setText("%.3f" % np.rad2deg(data["zeta"]))
        self.kappa_line.setText("%.3f" % np.rad2deg(data["kappa"]))

        self.img_x0_line.setText("%.1f" % data["img_x0"])
        self.img_y0_line.setText("%.1f" % data["img_y0"])
        self.focal_line.setText("%.1f" % data["f"])

        # (line edit, data key, value is an angle in radians)
        std_fields = (
            (self.obj_x0_std_line, "obj_x0_std", False),
            (self.obj_y0_std_line, "obj_y0_std", False),
            (self.obj_z0_std_line, "obj_z0_std", False),
            (self.alpha_std_line, "alpha_std", True),
            (self.zeta_std_line, "zeta_std", True),
            (self.kappa_std_line, "kappa_std", True),
            (self.focal_std_line, "f_std", False),
        )
        for line, key, is_angle in std_fields:
            if key not in data:
                line.setText("")
            elif is_angle:
                line.setText("%.3f" % np.rad2deg(data[key]))
            else:
                line.setText("%.1f" % data[key])

        self.init_params = data.copy()

    def set_residuals(self, data):
        """Show ``data["residuals"]`` (gid -> [dx, dy]) in the dx/dy columns.

        Rows whose gid has no residual get empty dx/dy cells.
        """
        residuals = data["residuals"]

        for rix in range(self.table_gcps.rowCount()):
            gid = self._cell_text(rix, "gid")
            if gid in residuals:
                self._set_cell(rix, "dx", residuals[gid][0])
                self._set_cell(rix, "dy", residuals[gid][1])
            else:
                self.table_gcps.setItem(rix, self.name2ix["dx"], QtWidgets.QTableWidgetItem(""))
                self.table_gcps.setItem(rix, self.name2ix["dy"], QtWidgets.QTableWidgetItem(""))

    def calc_orientation(self):
        """Run the least-squares adjustment with all usable, checked GCPs.

        Only rows that are checked and have complete coordinates are used.
        On success, the estimated parameters, their standard deviations (when
        the solver provides a covariance matrix) and the per-GCP residuals are
        emitted via ``camera_estimated_signal`` and shown in the dialog.
        """
        if not self.init_params:
            self.error_dialog.showMessage('Set initial camera parameters first!')
            return

        ori_data = {"gid": [], "img": [], "obj": [], "init_params": self.init_params}

        for rix in range(self.table_gcps.rowCount()):
            if not self._is_active_row(rix):
                continue
            ori_data["gid"].append(self._cell_text(rix, "gid"))
            ori_data["img"].append([float(self._cell_text(rix, name)) for name in ("x", "y")])
            ori_data["obj"].append([float(self._cell_text(rix, name)) for name in ("X", "Y", "Z")])

        res = srs_lm(ori_data)

        if not res.success:
            self.error_dialog.showMessage('LSQ did not converge: %s' % (res.message))
            return

        # Interior orientation x0/y0 is not estimated; it is carried over
        # from the start values.
        est_data = {"obj_x0": res.params["obj_x0"].value,
                    "obj_y0": res.params["obj_y0"].value,
                    "obj_z0": res.params["obj_z0"].value,
                    "alpha": res.params["alpha"].value,
                    "zeta": res.params["zeta"].value,
                    "kappa": res.params["kappa"].value,
                    "img_x0": self.init_params["img_x0"],
                    "img_y0": self.init_params["img_y0"],
                    "f": res.params["f"].value}

        # NOTE(review): assumes an lmfit-style result whose ``covar`` rows and
        # columns follow ``var_names``. ``covar`` may be missing or None when
        # uncertainties could not be estimated; then no *_std values are reported.
        covar = getattr(res, "covar", None)
        if covar is not None:
            stds = np.sqrt(np.diag(covar))
            est_data.update({"%s_std" % name: std for name, std in zip(res.var_names, stds.tolist())})

        # Assumes the residual vector is ordered (dx, dy) per GCP, in the same
        # order as ori_data["gid"].
        gcp_residuals = res.residual.reshape(-1, 2)
        est_data["residuals"] = dict(zip(ori_data["gid"], gcp_residuals.tolist()))

        self.camera_estimated_signal.emit(est_data)
        self.set_init_params(est_data)
        self.set_residuals(est_data)
        self.btn_save_ori.setEnabled(True)

    def save_orientation(self):
        """Ask the plugin to save the current orientation."""
        self.save_orientation_signal.emit()

    def reject(self):
        """Close the dialog when it is dismissed via reject() (e.g. the Esc key).

        QDialog.reject() only hides the dialog and does not run this class's
        closeEvent(). Without this override the parent would stay in
        GCP-picking mode with its image list disabled. Routing through close()
        restores the parent exactly as closing the window does.
        """
        self.close()
        super(OrientDialog, self).reject()

    def closeEvent(self, event):
        """Hand control back to the parent when the dialog window is closed.

        Re-enables the parent's image list, unchecks its orientation tool
        button, stops GCP picking and calls the parent's discard_changes().
        """
        self.parent.img_list.setEnabled(True)
        self.parent.btn_ori_tool.setChecked(False)

        self.parent.deactivate_gcp_picking()
        self.parent.discard_changes()