# Import basic libs
import xml.etree.ElementTree as ET

# Import Qgis libs
from qgis.core import QgsFeature

# Import PyQt libs
from qgis.PyQt.QtCore import QObject, QTimer, QUrl, pyqtSignal
from qgis.PyQt.QtNetwork import QNetworkReply, QNetworkRequest


class GetTaxrefFromWiki(QObject):
    """
    Query Wikidata to find the TAXREF equivalent of every occurrence of a layer.

    Occurrences are grouped by the value of ``field``; one SPARQL request is
    sent per distinct value, one after the other. Depending on the number of
    distinct TAXREF ids found, the occurrences are copied to
    ``one_occurence_layer``, ``multiple_occurence_layer`` or
    ``no_occurence_layer``. ``finished_dl`` is emitted once every request has
    been answered.

    network_manager : object whose ``get(request)`` returns the network reply
    project : qgis project (stored, not used by this class)
    layer : initial layer
    no_occurence_layer : layer receiving occurrences with no matching taxon
    one_occurence_layer : layer receiving occurrences with one matching taxon
    multiple_occurence_layer : layer receiving occurrences with several
        matching taxons
    dlg : main dialog widget; must expose ``thread`` (``set_max``, ``add_one``,
        ``value``) and ``select_progress_bar_label`` (``setText``)
    field : name of the field containing the taxon id
    selected_only : if truthy, only the selected occurrences are processed
    api_road : kind of id stored in ``field``: "GBIF_ID" or "iNaturalist_ID"

    Raises ValueError when ``api_road`` is not one of the supported values.
    """

    finished_dl = pyqtSignal()

    # Wikidata property used to look an item up, keyed by the `api_road` value.
    _WIKIDATA_PROPERTIES = {"GBIF_ID": "P846", "iNaturalist_ID": "P3151"}

    # XML namespace of a SPARQL query-results document.
    _SPARQL_NS = "{http://www.w3.org/2005/sparql-results#}"

    def __init__(
        self,
        network_manager=None,
        project=None,
        layer=None,
        no_occurence_layer=None,
        one_occurence_layer=None,
        multiple_occurence_layer=None,
        dlg=None,
        field=None,
        selected_only=None,
        api_road=None,
    ):
        super().__init__()
        # Fail early and explicitly instead of hitting a missing `wiki_id`
        # attribute when the first request is built.
        if api_road not in self._WIKIDATA_PROPERTIES:
            raise ValueError(
                f"Unsupported api_road {api_road!r}; expected one of "
                f"{sorted(self._WIKIDATA_PROPERTIES)}"
            )
        self.wiki_id = self._WIKIDATA_PROPERTIES[api_road]

        self.network_manager = network_manager
        self.project = project
        self.layer = layer
        self.no_occurence_layer = no_occurence_layer
        self.one_occurence_layer = one_occurence_layer
        self.multiple_occurence_layer = multiple_occurence_layer
        self.field = field
        self.thread = dlg.thread
        self.progress_bar = dlg.select_progress_bar_label

        # taxon id (value of `field`) -> ids of the features carrying it
        self.ids = {}

        # results list
        self.no_occurence = []
        self.multi_occurence = []
        self.multiples_occurences_values = {}

        # When the whole layer is processed, temporarily select everything and
        # restore the user's selection afterwards.
        if not selected_only:
            selected_features = self.layer.selectedFeatureIds()
            self.layer.selectAll()
        for obs in self.layer.getSelectedFeatures():
            self.ids.setdefault(obs[self.field], []).append(obs.id())
        if not selected_only:
            self.layer.removeSelection()
            self.layer.selectByIds(selected_features)

        # Fixed download order, built once instead of on every request.
        self._id_queue = list(self.ids)
        self._pending_downloads = len(self._id_queue)
        self._iterate_ids = 0

        self.thread.set_max(len(self.ids))
        self.thread.add_one(0)
        self._show_progress(0)

        if self._id_queue:
            self.download(self._id_queue[self._iterate_ids])
        else:
            # Nothing to download: report completion on the next event-loop
            # turn so that a caller connecting to `finished_dl` right after
            # construction still receives the signal.
            QTimer.singleShot(0, self.finished_dl.emit)

    @property
    def pending_downloads(self):
        """Number of requests that have not been answered yet."""
        return self._pending_downloads

    @property
    def iterate_ids(self):
        """Index of the next taxon id to request."""
        return self._iterate_ids

    def _show_progress(self, done):
        """Write "<done>/<total>" after the translated label in the progress text."""
        # Only the constant part goes through tr(): a string containing the
        # counters could never match a translation entry.
        self.progress_bar.setText(
            self.tr("Downloaded data : ") + f"{done}/{len(self.ids)}"
        )

    def download(self, taxon_id):
        """Send the SPARQL request for `taxon_id`; the answer is handled by `handle_finished`."""
        self._iterate_ids += 1
        url = (
            "https://query.wikidata.org/sparql?query=SELECT ?GBIF_Taxon_ID ?iNat_Taxon_ID ?TAXREF_ID ?Taxon_Name WHERE { ?item wdt:"
            + self.wiki_id
            + ' "'
            + str(taxon_id)
            + '". OPTIONAL { ?item wdt:P846 ?GBIF_Taxon_ID. } OPTIONAL { ?item wdt:P3151 ?iNat_Taxon_ID. } OPTIONAL { ?item wdt:P3186 ?TAXREF_ID. } OPTIONAL { ?item wdt:P225 ?Taxon_Name. }} '
        )
        request = QNetworkRequest(QUrl(url))
        # The answer is parsed as XML, so ask for XML explicitly (a
        # Content-Type header has no meaning on a body-less GET).
        request.setRawHeader(b"Accept", b"application/sparql-results+xml")
        reply = self.network_manager.get(request)
        reply.finished.connect(lambda: self.handle_finished(reply, self.ids[taxon_id]))

    def handle_finished(self, reply, features_id):
        """
        Process one answered request, then start the next one or emit `finished_dl`.

        reply : the finished network reply
        features_id : ids of the features of `layer` sharing the requested taxon id
        """
        self._pending_downloads -= 1
        if reply.error() != QNetworkReply.NetworkError.NoError:
            # NOTE: on a network error the features are not copied to any layer.
            print(f"code: {reply.error()} message: {reply.errorString()}")
            if reply.error() == QNetworkReply.NetworkError.ServiceUnavailableError:
                print("Service down")
        else:
            data_request = reply.readAll().data().decode()
            if data_request != "":
                try:
                    taxref_list = self._parse_taxref(data_request)
                except ET.ParseError as err:
                    # An unparsable body must not break the download chain,
                    # otherwise `finished_dl` would never be emitted.
                    print(f"Invalid SPARQL XML response: {err}")
                else:
                    self._dispatch_features(features_id, taxref_list)
        # The reply is owned by the caller of get(); release it once consumed.
        reply.deleteLater()

        if self.pending_downloads == 0:
            self.finished_dl.emit()
        else:
            self.thread.set_max(len(self.ids))
            self.thread.add_one(1)
            self._show_progress(self.thread.value)
            self.download(self._id_queue[self._iterate_ids])

    def _parse_taxref(self, xml_text):
        """
        Return {taxref_id: taxon_name} read from a SPARQL XML result document.

        Only result rows providing both a TAXREF_ID and a Taxon_Name literal are
        kept; the first name met for a given TAXREF id wins.
        Raises xml.etree.ElementTree.ParseError when `xml_text` is not valid XML.
        """
        ns = self._SPARQL_NS
        taxref_list = {}
        results = ET.fromstring(xml_text).find(ns + "results")
        if results is None:
            return taxref_list
        for result in results.findall(ns + "result"):
            # Collect the bindings of this row only, so that values never leak
            # from one row to the next and binding order does not matter.
            values = {}
            for bind in result.findall(ns + "binding"):
                literal = bind.find(ns + "literal")
                if literal is not None:
                    values[bind.get("name")] = literal.text
            taxref_id = values.get("TAXREF_ID")
            if taxref_id and "Taxon_Name" in values:
                taxref_list.setdefault(taxref_id, values["Taxon_Name"])
        return taxref_list

    def _dispatch_features(self, features_id, taxref_list):
        """Copy the features to the layer matching the number of TAXREF ids found."""
        if len(taxref_list) == 1:
            self._add_matched_features(features_id, taxref_list)
        elif len(taxref_list) > 1:
            self._copy_features(
                self.multiple_occurence_layer, features_id, self.multi_occurence
            )
            # Every feature of the group carries the same taxon id, so the
            # first one is enough to build the key.
            if features_id:
                taxon_key = str(self.layer.getFeature(features_id[0])[self.field])
                self.multiples_occurences_values[taxon_key] = {
                    str(taxref_id): name for taxref_id, name in taxref_list.items()
                }
        else:
            self._copy_features(
                self.no_occurence_layer, features_id, self.no_occurence
            )

    def _add_matched_features(self, features_id, taxref_list):
        """Add the features to `one_occurence_layer` with TAXREF id, name and INPN url appended."""
        taxref_id, taxon_name = next(iter(taxref_list.items()))
        inpn_url = "https://inpn.mnhn.fr/espece/cd_nom/{cd_nom}".format(
            cd_nom=taxref_id
        )
        target = self.one_occurence_layer
        target.startEditing()
        for feature_id in features_id:
            source = self.layer.getFeature(feature_id)
            feature = QgsFeature(target.fields())
            feature.setGeometry(source.geometry())
            feature.setAttributes(
                source.attributes() + [taxref_id, taxon_name, inpn_url]
            )
            target.dataProvider().addFeature(feature)
        target.updateExtents()
        target.commitChanges()
        target.triggerRepaint()

    def _copy_features(self, target_layer, features_id, registry):
        """Copy the features unchanged to `target_layer` and record their ids in `registry`."""
        target_layer.startEditing()
        for feature_id in features_id:
            registry.append(feature_id)
            target_layer.dataProvider().addFeature(self.layer.getFeature(feature_id))
        target_layer.updateExtents()
        target_layer.commitChanges()
        target_layer.triggerRepaint()
