#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
#   Copyright (C) 2018 Oslandia <infos@oslandia.com>
#
#   This file is a piece of free software; you can redistribute it and/or
#   modify it under the terms of the GNU Library General Public
#   License as published by the Free Software Foundation; either
#   version 2 of the License, or (at your option) any later version.
#
#   This library is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#   Library General Public License for more details.
#   You should have received a copy of the GNU Library General Public
#   License along with this library; if not, see <http://www.gnu.org/licenses/>.
#

from qgis.PyQt.QtCore import pyqtSignal, QObject, QVariant
from qgis.core import QgsFeatureRequest
from PyQt5.QtXml import QDomDocument

import numpy as np


class DataInterface(QObject):
    """Abstract interface describing how a series of (X, Y) data is exposed.

    Concrete classes are expected to override the accessors below.
    ``data_modified`` is the signal they emit once their data have been (re)built.
    """

    data_modified = pyqtSignal()

    def __init__(self):
        QObject.__init__(self)

    def get_x_values(self):
        """Return the X values. Must be overridden.

        Raises
        ------
        NotImplementedError
          Always, in this abstract implementation
        """
        raise NotImplementedError(
            self.tr("DataInterface is an abstract class, get_x_values() must be defined")
        )

    def get_y_values(self):
        """Return the Y values. Must be overridden.

        Raises
        ------
        NotImplementedError
          Always, in this abstract implementation
        """
        raise NotImplementedError(
            self.tr("DataInterface is an abstract class, get_y_values() must be defined")
        )

    def get_x_min(self):
        """Return the minimum X value. Must be overridden.

        Raises
        ------
        NotImplementedError
          Always, in this abstract implementation
        """
        raise NotImplementedError(
            self.tr("DataInterface is an abstract class, get_x_min() must be defined")
        )

    def get_x_max(self):
        """Return the maximum X value. Must be overridden.

        Raises
        ------
        NotImplementedError
          Always, in this abstract implementation
        """
        raise NotImplementedError(
            self.tr("DataInterface is an abstract class, get_x_max() must be defined")
        )

    def get_y_min(self):
        """Return the minimum Y value. Must be overridden.

        Raises
        ------
        NotImplementedError
          Always, in this abstract implementation
        """
        raise NotImplementedError(
            self.tr("DataInterface is an abstract class, get_y_min() must be defined")
        )

    def get_y_max(self):
        """Return the maximum Y value. Must be overridden.

        Raises
        ------
        NotImplementedError
          Always, in this abstract implementation
        """
        raise NotImplementedError(
            self.tr("DataInterface is an abstract class, get_y_max() must be defined")
        )

    def get_ids_values(self):
        """Return the ids of the data. Concrete classes are expected to override this.

        NOTE: unlike the accessors above, this base implementation returns the
        (translated) message instead of raising. This is kept unchanged so that
        existing callers are not broken.
        """
        return self.tr("DataInterface is an abstract class, get_ids_values() must be defined")

    def get_uncertainty_values(self):
        """Return the uncertainty values. Concrete classes are expected to override this.

        NOTE: unlike the accessors above, this base implementation returns the
        (translated) message instead of raising. This is kept unchanged so that
        existing callers are not broken.
        """
        return self.tr(
            "DataInterface is an abstract class, get_uncertainty_values() must be defined"
        )

    # Kept here only for compatibility, but it shouldn't exist:
    # plot_item doesn't need the layer object.
    # FIXME a layer seems to be needed for symbology.
    # TODO: try to make it internal to plotter UI
    def get_layer(self):
        """Return the underlying layer. Must be overridden.

        Raises
        ------
        NotImplementedError
          Always, in this abstract implementation
        """
        raise NotImplementedError(
            self.tr("DataInterface is an abstract class, get_layer() must be defined")
        )


class LayerData(DataInterface):
    """LayerData models data that are spread over several features (resp. rows)
    of a layer (resp. table).

    Each feature (resp. row) of the layer (resp. table) carries one (X, Y) pair.
    Features are requested ordered by descending value of the X field.
    """

    # BREAKING CHANGE FOR NEXT RELEASE
    # nodata_value should be None by default (i.e. remove null values)
    def __init__(
        self,
        layer,
        ground_elevation,
        x_fieldname,
        y_fieldname,
        uncertainty_fieldname=None,
        filter_expression=None,
        nodata_value=0.0,
        uom=None,
        symbology=None,
        symbology_type=None,
    ):
        """
        Parameters
        ----------
        layer: QgsVectorLayer
          Vector layer that holds data
        ground_elevation : float or False
          If False, X values are used as read from the layer.
          Otherwise each X value is replaced by ground_elevation - X
        x_fieldname: str
          Name of the field that holds X values
        y_fieldname: str
          Name of the field that holds Y values
        uncertainty_fieldname: Optional[str]
          Name of the field that holds uncertainty values
        filter_expression: Optional[str]
          Filter expression
        nodata_value: Optional[float]
          If None, null values will be removed
          Otherwise, they will be replaced by nodata_value
        uom: Optional[str]
          Unit of measure
          If uom starts with "@" it means the unit of measure is carried by a field name
          e.g. @unit means the field "unit" carries the unit of measure
        symbology: Optional
          If not None, content loaded into a QDomDocument with setContent()
        symbology_type: Optional
          Stored as is and returned by get_symbology_type()
        """

        DataInterface.__init__(self)

        self.__ground_elevation = ground_elevation
        self.__y_fieldname = y_fieldname
        self.__x_fieldname = x_fieldname
        self.__uncertainty_fieldname = uncertainty_fieldname
        self.__layer = layer
        # attribute read by get_ids_values(), filled by __build_data()
        self.__ids_values = None
        self.__y_values = None
        self.__x_values = None
        self.__uncertainty_values = None
        self.__x_min = None
        self.__x_max = None
        self.__y_min = None
        self.__y_max = None
        self.__filter_expression = filter_expression
        self.__nodata_value = nodata_value
        self.__uom = uom
        self.__symbology = None
        if symbology is not None:
            self.__symbology = QDomDocument()
            self.__symbology.setContent(symbology)
        self.__symbology_type = symbology_type

        # Rebuild everything whenever the layer content changes
        layer.attributeValueChanged.connect(self.__build_data)
        layer.featureAdded.connect(self.__build_data)
        layer.featureDeleted.connect(self.__build_data)

        self.__build_data()

    def get_ids_values(self):
        """Return the list of ids of the features that were read"""
        return self.__ids_values

    def get_y_values(self):
        """Return the list of Y values"""
        return self.__y_values

    def set_y_values(self, y_values):
        """Used in stacked data to fit to main layer uom"""
        self.__y_values = y_values

    def get_layer(self):
        """Return the layer passed at construction"""
        return self.__layer

    def get_x_values(self):
        """Return the list of X values"""
        return self.__x_values

    def get_uncertainty_values(self):
        """Return the list of uncertainty values (one item per feature read)"""
        return self.__uncertainty_values

    def get_x_min(self):
        """Return the minimum X value, or None if there is no data"""
        return self.__x_min

    def get_x_max(self):
        """Return the maximum X value, or None if there is no data"""
        return self.__x_max

    def get_y_min(self):
        """Return the minimum Y value, or None if there is no data"""
        return self.__y_min

    def get_y_max(self):
        """Return the maximum Y value, or None if there is no data"""
        return self.__y_max

    def get_symbology(self):
        """Return the symbology as a QDomDocument, or None if none was given"""
        return self.__symbology

    def get_symbology_type(self):
        """Return the symbology type passed at construction"""
        return self.__symbology_type

    def get_uom(self):
        """Return the unit of measure"""
        return self.__uom

    def set_uom(self, uom):
        """Set the unit of measure"""
        self.__uom = uom

    def __build_data(self):
        """Read the layer and rebuild ids, X, Y and uncertainty values and their extents.

        Called at construction and connected to the layer's attributeValueChanged,
        featureAdded and featureDeleted signals. Emits data_modified when done.
        """

        req = QgsFeatureRequest()
        if self.__filter_expression is not None:
            req.setFilterExpression(self.__filter_expression)

        # Get unit of the first feature if needed.
        # The isinstance() test also covers a unit previously resolved to a
        # non-string value (e.g. a null field), which has no startswith().
        if isinstance(self.__uom, str) and self.__uom.startswith("@"):
            request_unit = QgsFeatureRequest(req)
            request_unit.setLimit(1)
            for f in self.__layer.getFeatures(request_unit):
                self.__uom = f[self.__uom[1:]]
                break

        subset_attr = [self.__x_fieldname, self.__y_fieldname]
        if self.__uncertainty_fieldname is not None:
            subset_attr.append(self.__uncertainty_fieldname)
        req.setSubsetOfAttributes(subset_attr, self.__layer.fields())
        # Do not forget to add an index on this field to speed up ordering
        req.addOrderBy(self.__x_fieldname, ascending=False)

        def is_null(v):
            return isinstance(v, QVariant) and v.isNull()

        replace_nulls = self.__nodata_value is not None
        has_uncertainty = self.__uncertainty_fieldname is not None
        # Placeholder stored when no uncertainty field is configured (kept as is for
        # backward compatibility): None when nulls are replaced, 0 when they are skipped
        missing_uncertainty = None if replace_nulls else 0
        use_raw_x = self.__ground_elevation is False

        ids = []
        x_values = []
        y_values = []
        uncertainty_values = []
        for f in self.__layer.getFeatures(req):
            y = f[self.__y_fieldname]
            if is_null(y):
                if not replace_nulls:
                    # do not include null values
                    continue
                # replace null values by a 'Nodata' value
                y = self.__nodata_value

            x = f[self.__x_fieldname]
            if not use_raw_x:
                x = self.__ground_elevation - x

            ids.append(f.id())
            x_values.append(x)
            y_values.append(y)
            uncertainty_values.append(
                f[self.__uncertainty_fieldname] if has_uncertainty else missing_uncertainty
            )

        self.__ids_values = ids
        self.__x_values = x_values
        self.__y_values = y_values
        self.__uncertainty_values = uncertainty_values

        self.__x_min, self.__x_max = (
            (min(self.__x_values), max(self.__x_values)) if self.__x_values else (None, None)
        )
        self.__y_min, self.__y_max = (
            (min(self.__y_values), max(self.__y_values)) if self.__y_values else (None, None)
        )

        self.data_modified.emit()


class IntervalData(DataInterface):
    """IntervalData models data that are spread over several features (resp. rows)
    of a layer (resp. table), where each Y value applies to an X interval.

    Each feature (resp. row) carries one Y value and two X bounds, read from the
    first two fields of x_fieldnames.
    Features are requested ordered by descending value of the first X field.
    """

    # BREAKING CHANGE FOR NEXT RELEASE
    # nodata_value should be None by default (i.e. remove null values)
    def __init__(
        self,
        layer,
        ground_elevation,
        x_fieldnames,
        y_fieldname,
        uncertainty_fieldname=None,
        filter_expression=None,
        nodata_value=0.0,
        uom=None,
        symbology=None,
        symbology_type=None,
    ):
        """
        Parameters
        ----------
        layer: QgsVectorLayer
          Vector layer that holds data
        ground_elevation : float or False
          If False, X bounds are used as read from the layer.
          Otherwise each X bound is replaced by ground_elevation - X
        x_fieldnames: Sequence[str]
          Names of the two fields that hold the X bounds of each interval
          (items 0 and 1 are used)
        y_fieldname: str
          Name of the field that holds Y values
        uncertainty_fieldname: Optional[str]
          Name of the field that holds uncertainty values
        filter_expression: Optional[str]
          Filter expression
        nodata_value: Optional[float]
          If None, null values will be removed
          Otherwise, they will be replaced by nodata_value
        uom: Optional[str]
          Unit of measure
          If uom starts with "@" it means the unit of measure is carried by a field name
          e.g. @unit means the field "unit" carries the unit of measure
        symbology: Optional
          If not None, content loaded into a QDomDocument with setContent()
        symbology_type: Optional
          Stored as is and returned by get_symbology_type()
        """

        DataInterface.__init__(self)

        self.__ground_elevation = ground_elevation
        self.__y_fieldname = y_fieldname
        self.__x_fieldnames = x_fieldnames
        self.__uncertainty_fieldname = uncertainty_fieldname
        self.__layer = layer
        # attribute read by get_ids_values(), filled by __build_data()
        self.__ids_values = None
        self.__y_values = None
        self.__min_x_values = None
        self.__max_x_values = None
        self.__uncertainty_values = None
        self.__x_min = None
        self.__x_max = None
        self.__y_min = None
        self.__y_max = None
        self.__filter_expression = filter_expression
        self.__nodata_value = nodata_value
        self.__uom = uom
        self.__symbology = None
        if symbology is not None:
            self.__symbology = QDomDocument()
            self.__symbology.setContent(symbology)
        self.__symbology_type = symbology_type

        # Rebuild everything whenever the layer content changes
        layer.attributeValueChanged.connect(self.__build_data)
        layer.featureAdded.connect(self.__build_data)
        layer.featureDeleted.connect(self.__build_data)

        self.__build_data()

    def get_ids_values(self):
        """Return the list of ids of the features that were read"""
        return self.__ids_values

    def get_y_values(self):
        """Return the list of Y values"""
        return self.__y_values

    def set_y_values(self, y_values):
        """Used in stacked data to fit to main layer uom"""
        self.__y_values = y_values

    def get_layer(self):
        """Return the layer passed at construction"""
        return self.__layer

    def get_x_values(self):
        """Return a tuple (values of the first X bound, values of the second X bound)"""
        return (self.__min_x_values, self.__max_x_values)

    def get_min_x_values(self):
        """Return the list of values read from the first X field"""
        return self.__min_x_values

    def get_max_x_values(self):
        """Return the list of values read from the second X field"""
        return self.__max_x_values

    def get_uncertainty_values(self):
        """Return the list of uncertainty values (one item per feature read)"""
        return self.__uncertainty_values

    def get_x_min(self):
        """Return the smallest X bound, or None if there is no data"""
        return self.__x_min

    def get_x_max(self):
        """Return the largest X bound, or None if there is no data"""
        return self.__x_max

    def get_y_min(self):
        """Return the minimum Y value, or None if there is no data"""
        return self.__y_min

    def get_y_max(self):
        """Return the maximum Y value, or None if there is no data"""
        return self.__y_max

    def get_symbology(self):
        """Return the symbology as a QDomDocument, or None if none was given"""
        return self.__symbology

    def get_symbology_type(self):
        """Return the symbology type passed at construction"""
        return self.__symbology_type

    def get_uom(self):
        """Return the unit of measure"""
        return self.__uom

    def set_uom(self, uom):
        """Set the unit of measure"""
        self.__uom = uom

    def __build_data(self):
        """Read the layer and rebuild ids, X bounds, Y and uncertainty values and extents.

        Called at construction and connected to the layer's attributeValueChanged,
        featureAdded and featureDeleted signals. Emits data_modified when done.
        """

        req = QgsFeatureRequest()
        if self.__filter_expression is not None:
            req.setFilterExpression(self.__filter_expression)

        # Get unit of the first feature if needed.
        # The isinstance() test also covers a unit previously resolved to a
        # non-string value (e.g. a null field), which has no startswith().
        if isinstance(self.__uom, str) and self.__uom.startswith("@"):
            request_unit = QgsFeatureRequest(req)
            request_unit.setLimit(1)
            for f in self.__layer.getFeatures(request_unit):
                self.__uom = f[self.__uom[1:]]
                break

        # list() so that x_fieldnames may be any sequence (e.g. a tuple)
        subset_attr = list(self.__x_fieldnames) + [self.__y_fieldname]
        if self.__uncertainty_fieldname is not None:
            subset_attr.append(self.__uncertainty_fieldname)
        req.setSubsetOfAttributes(subset_attr, self.__layer.fields())
        # Do not forget to add an index on this field to speed up ordering
        req.addOrderBy(self.__x_fieldnames[0], ascending=False)

        def is_null(v):
            return isinstance(v, QVariant) and v.isNull()

        first_x_fieldname = self.__x_fieldnames[0]
        second_x_fieldname = self.__x_fieldnames[1]
        replace_nulls = self.__nodata_value is not None
        has_uncertainty = self.__uncertainty_fieldname is not None
        # Placeholder stored when no uncertainty field is configured (kept as is for
        # backward compatibility): None when nulls are replaced, 0 when they are skipped
        missing_uncertainty = None if replace_nulls else 0
        use_raw_x = self.__ground_elevation is False

        ids = []
        first_x_values = []
        second_x_values = []
        y_values = []
        uncertainty_values = []
        for f in self.__layer.getFeatures(req):
            y = f[self.__y_fieldname]
            if is_null(y):
                if not replace_nulls:
                    # do not include null values
                    continue
                # replace null values by a 'Nodata' value
                y = self.__nodata_value

            first_x = f[first_x_fieldname]
            second_x = f[second_x_fieldname]
            if not use_raw_x:
                first_x = self.__ground_elevation - first_x
                second_x = self.__ground_elevation - second_x

            ids.append(f.id())
            first_x_values.append(first_x)
            second_x_values.append(second_x)
            y_values.append(y)
            uncertainty_values.append(
                f[self.__uncertainty_fieldname] if has_uncertainty else missing_uncertainty
            )

        self.__ids_values = ids
        self.__min_x_values = first_x_values
        self.__max_x_values = second_x_values
        self.__y_values = y_values
        self.__uncertainty_values = uncertainty_values

        # The extent is computed over both bounds: the ground elevation conversion
        # (ground_elevation - X) reverses the order of the two bounds of an interval,
        # so the first bound is not guaranteed to be the smallest one.
        all_bounds = self.__min_x_values + self.__max_x_values
        self.__x_min, self.__x_max = (
            (min(all_bounds), max(all_bounds)) if all_bounds else (None, None)
        )
        self.__y_min, self.__y_max = (
            (min(self.__y_values), max(self.__y_values)) if self.__y_values else (None, None)
        )

        self.data_modified.emit()


class FeatureData(DataInterface):
    """FeatureData model data that are stored on one feature (resp. row) in a layer (resp. table).

    This usually means data are the result of a sampling with a regular sampling interval for X.
    The feature has one array attribute that stores all values and the X values are given during
    construction.
    """

    def __init__(
        self,
        layer,
        feature_elevation,
        y_fieldname,
        x_values=None,
        feature_ids=None,
        uncertainty_fieldname=None,
        x_start=None,
        x_delta=None,
        x_start_fieldname=None,
        x_delta_fieldname=None,
        uom=None,
        symbology=None,
        symbology_type=None,
    ):
        """
        layer: input QgsVectorLayer
        feature_elevation: if it evaluates to true, every X value is replaced by
                           feature_elevation - X (so X values decrease when x_delta is positive)
        y_fieldname: name of the field in the input layer that carries data
        x_values: sequence of X values, that should be of the same length as data values.
                  Only used when no x_start / x_delta is available;
                  otherwise X values are built based on x_start and x_delta
        feature_ids: IDs of the features read. If set to None, the input data are assumed to represent one feature with ID=0
                     If more than one feature id is passed, their data will be merged.
                     In case of overlap between features, one will be arbitrarily chosen, and a warning will be printed.
        uncertainty_fieldname: name of the field that carries uncertainty values,
                               as a string of comma-separated numbers
        x_start: starting X value. Should be used with x_delta
        x_delta: interval between two X values.
        x_start_fieldname: name of the field in the input layer that carries the starting X value
        x_delta_fieldname: name of the field in the input layer that carries the interval between two X values
        uom: unit of measure
        symbology: if not None, content loaded into a QDomDocument with setContent()
        symbology_type: stored as is and returned by get_symbology_type()

        Raises ValueError if the X definition is missing or incomplete, or if several
        features are requested without x_start_fieldname / x_delta_fieldname.
        """
        # The QObject part must be initialised first: self.tr() is used by the
        # argument checks below.
        DataInterface.__init__(self)

        x_start_defined = x_start is not None or x_start_fieldname is not None
        x_delta_defined = x_delta is not None or x_delta_fieldname is not None

        if x_values is None:
            if not x_start_defined and not x_delta_defined:
                raise ValueError(self.tr("Define either x_values or x_start / x_delta"))
            if x_start_defined != x_delta_defined:
                raise ValueError(self.tr("Both x_start and x_delta must be defined"))

        if feature_ids is None:
            feature_ids = [0]

        if x_start_fieldname is None and len(feature_ids) > 1:
            raise ValueError(
                self.tr(
                    "More than one feature, but only one starting value, define x_start_fieldname"
                )
            )
        if x_delta_fieldname is None and len(feature_ids) > 1:
            raise ValueError(
                self.tr("More than one feature, but only one delta value, define x_delta_fieldname")
            )

        self.__feature_elevation = feature_elevation
        self.__y_fieldname = y_fieldname
        self.__layer = layer
        self.__ids_values = feature_ids
        # X values given by the caller; __x_values holds the values actually built
        self.__given_x_values = x_values
        self.__x_values = x_values
        self.__uncertainty_fieldname = uncertainty_fieldname
        self.__x_start = x_start
        self.__x_start_fieldname = x_start_fieldname
        self.__x_delta = x_delta
        self.__x_delta_fieldname = x_delta_fieldname
        self.__uom = uom
        self.__symbology = None
        if symbology is not None:
            self.__symbology = QDomDocument()
            self.__symbology.setContent(symbology)
        self.__symbology_type = symbology_type

        # TODO connect on feature modification

        self.__build_data()

    def get_ids_values(self):
        """Return the feature ids passed at construction ([0] if none was given)"""
        return self.__ids_values

    def get_y_values(self):
        """Return the list of Y values (None for null samples)"""
        return self.__y_values

    def set_y_values(self, y_values):
        """Replace the list of Y values"""
        self.__y_values = y_values

    def get_layer(self):
        """Return the layer passed at construction"""
        return self.__layer

    def get_x_values(self):
        """Return the list of X values"""
        return self.__x_values

    def get_uncertainty_values(self):
        """Return the list of uncertainty values (empty if no uncertainty field is set)"""
        return self.__uncertainty_values

    def get_x_min(self):
        """Return the minimum X value, or None if there is no data"""
        return self.__x_min

    def get_x_max(self):
        """Return the maximum X value, or None if there is no data"""
        return self.__x_max

    def get_y_min(self):
        """Return the minimum non-null Y value, or None if there is none"""
        return self.__y_min

    def get_y_max(self):
        """Return the maximum non-null Y value, or None if there is none"""
        return self.__y_max

    def get_symbology(self):
        """Return the symbology as a QDomDocument, or None if none was given"""
        return self.__symbology

    def get_symbology_type(self):
        """Return the symbology type passed at construction"""
        return self.__symbology_type

    def get_uom(self):
        """Return the unit of measure"""
        return self.__uom

    def set_uom(self, uom):
        """Set the unit of measure"""
        self.__uom = uom

    def __read_y_values(self, feature):
        """Return the Y values of a feature as a list, or None if the format is unsupported"""
        raw_data = feature[self.__y_fieldname]
        if isinstance(raw_data, list):
            # QGIS 3 natively reads array values
            # Null values still have to be filtered out
            # WARNING: extracting list from PostgreSQL's arrays seem very sloowww
            return [None if isinstance(value, QVariant) else value for value in raw_data]
        if isinstance(raw_data, str):
            # We assume values are separated by a ','
            return [None if value == "NULL" else float(value) for value in raw_data.split(",")]
        print(self.tr("Unsupported data format:") + " {}".format(raw_data.__class__))
        return None

    def __read_uncertainty_values(self, feature):
        """Return the uncertainty values of a feature, or [] if no uncertainty field is set"""
        if not self.__uncertainty_fieldname:
            return []
        return [float(u) for u in feature[self.__uncertainty_fieldname].split(",")]

    def __compute_x_values(self, feature, count):
        """Return the list of the `count` X values that go with the Y values of a feature"""
        # Each part of the sampling definition comes from its field if one is
        # set, and from the constant given at construction otherwise
        if self.__x_start_fieldname is not None:
            x_start = feature[self.__x_start_fieldname]
        else:
            x_start = self.__x_start
        if self.__x_delta_fieldname is not None:
            x_delta = feature[self.__x_delta_fieldname]
        else:
            x_delta = self.__x_delta

        if x_start is None or x_delta is None:
            # No regular sampling available: the constructor checks guarantee that
            # explicit X values were given in that case
            x_values = list(self.__given_x_values)
            if len(x_values) != count:
                raise ValueError(self.tr("x_values must be of the same length as data values"))
            if self.__feature_elevation:
                # same conversion as the one applied to a regular sampling below
                x_values = [self.__feature_elevation - x for x in x_values]
            return x_values

        # NOTE: an elevation of 0 evaluates to false and is then not applied
        if self.__feature_elevation:
            x_start = self.__feature_elevation - x_start
            x_delta = -x_delta

        return np.linspace(x_start, x_start + x_delta * (count - 1), count).tolist()

    @staticmethod
    def __is_descending(*series):
        """Return True if the first series with two different end values is decreasing"""
        for values in series:
            if values and values[0] != values[-1]:
                return values[0] > values[-1]
        return False

    def __build_data(self):
        """Read the requested features and rebuild X, Y and uncertainty values and extents.

        Data of several features are merged so that X values stay monotonic.
        A feature whose X range overlaps the data already merged is skipped, as well
        as a feature with an unsupported or empty data field.
        Emits data_modified when done.
        """

        req = QgsFeatureRequest()
        req.setFilterFids(self.__ids_values)

        self.__x_values = []
        self.__y_values = []
        self.__ids = []
        self.__uncertainty_values = []

        # (lowest X, highest X) of the data merged so far
        current_data_range = None
        for f in self.__layer.getFeatures(req):
            y_values = self.__read_y_values(f)
            if not y_values:
                # unsupported format (already reported) or no sample at all
                continue

            x_values = self.__compute_x_values(f, len(y_values))
            # X values may be decreasing: normalize the range before comparing
            new_low = min(x_values[0], x_values[-1])
            new_high = max(x_values[0], x_values[-1])

            if current_data_range is None:
                self.__ids.append(f.id())
                self.__x_values = x_values
                self.__y_values = y_values
                self.__uncertainty_values = self.__read_uncertainty_values(f)
            else:
                current_low, current_high = current_data_range
                # look for overlap (ranges that only share an end value do not overlap)
                if new_low < current_high and current_low < new_high:
                    print(self.tr("Overlap in data around feature #") + " {}".format(f.id()))
                    continue

                # Without overlap, new data are either entirely below or entirely
                # above the current data. They go "on the left" when they are below
                # and X values are increasing, or above and X values are decreasing.
                new_is_below = new_high <= current_low
                descending = self.__is_descending(self.__x_values, x_values)
                uncertainty_values = self.__read_uncertainty_values(f)
                self.__ids.append(f.id())
                if new_is_below != descending:
                    # new data are "on the left"
                    self.__x_values = x_values + self.__x_values
                    self.__y_values = y_values + self.__y_values
                    self.__uncertainty_values = uncertainty_values + self.__uncertainty_values
                else:
                    # new data are "on the right"
                    self.__x_values = self.__x_values + x_values
                    self.__y_values = self.__y_values + y_values
                    self.__uncertainty_values = self.__uncertainty_values + uncertainty_values

            current_data_range = (
                min(self.__x_values[0], self.__x_values[-1]),
                max(self.__x_values[0], self.__x_values[-1]),
            )

        self.__x_min, self.__x_max = (
            (min(self.__x_values), max(self.__x_values)) if self.__x_values else (None, None)
        )
        # null samples are ignored; there may be nothing left
        defined_y_values = [y for y in self.__y_values if y is not None]
        self.__y_min, self.__y_max = (
            (min(defined_y_values), max(defined_y_values)) if defined_y_values else (None, None)
        )

        self.data_modified.emit()
