import numpy as np

from openlog.datamodel.assay.generic_assay import AssayDataExtent


class GeoExtractor:
    """
    Abstract class designed to extract x, y, z geographical coordinates of the
    measures of an assay for a specific hole.

    Subclasses should be used as an attribute of a GenericAssay instance and
    called during item construction (plot_item_factory.py). They must implement
    ``get_planned_eoh``, ``get_discrete_coordinates`` and
    ``get_extended_coordinates``.

    For depth domain only.
    """

    def __init__(self, assay) -> None:
        # Assay whose measures are located.
        self.assay = assay
        # Name of the collar geometry column used by subclass queries; switched
        # between "effective_geom" and "planned_geom" by get_altitude() and
        # get_coordinates().
        self.geom = "effective_geom"

    def _select_geom(self, planned: bool) -> None:
        """Select the planned or the effective collar geometry for later queries."""
        self.geom = "planned_geom" if planned else "effective_geom"

    def _is_discrete(self) -> bool:
        """Return True when the assay's data extent is DISCRETE."""
        return self.assay.assay_definition.data_extent == AssayDataExtent.DISCRETE

    def get_planned_eoh(self):
        """
        Return altitude of planned EOH.

        :raises NotImplementedError: must be implemented by subclasses.
        """
        raise NotImplementedError

    def get_altitude(self, planned: bool = False):
        """
        Return the altitude (z) of the assay measures.

        Side effect: sets ``self.geom`` according to ``planned``.

        :param planned: use the planned collar geometry instead of the effective one.
        :return: ``get_discrete_altitude()`` for DISCRETE assays, otherwise
            ``get_extended_altitude()``.
        """
        self._select_geom(planned)
        if self._is_discrete():
            return self.get_discrete_altitude()
        return self.get_extended_altitude()

    def get_discrete_altitude(self) -> np.ndarray:
        """
        Return the z column (index 2) of ``get_discrete_coordinates()``,
        or None when no coordinates are available.
        """
        coords = self.get_discrete_coordinates()
        if coords is None:
            return None
        return np.asarray(coords)[:, 2]

    def get_extended_altitude(self):
        """
        Return ``[start_z, end_z]`` for the intervals of ``get_extended_coordinates()``
        (its columns 2 and 5), or None when no coordinates are available.
        """
        coords = self.get_extended_coordinates()
        if coords is None:
            return None
        coords = np.asarray(coords)
        return [coords[:, 2], coords[:, 5]]

    def get_coordinates(self, planned: bool = False):
        """
        Return the coordinates of the assay measures.

        Side effect: sets ``self.geom`` according to ``planned``.

        :param planned: use the planned collar geometry instead of the effective one.
        :return: ``get_discrete_coordinates()`` for DISCRETE assays, otherwise
            ``get_extended_coordinates()``.
        """
        self._select_geom(planned)
        if self._is_discrete():
            return self.get_discrete_coordinates()
        return self.get_extended_coordinates()

    def get_discrete_coordinates(self):
        """
        Return one (x, y, z) row per measure, or None.

        :raises NotImplementedError: must be implemented by subclasses.
        """
        raise NotImplementedError

    def get_extended_coordinates(self):
        """
        Return one (start_x, start_y, start_z, end_x, end_y, end_z) row per
        interval, or None.

        :raises NotImplementedError: must be implemented by subclasses.
        """
        raise NotImplementedError


class SpatialiteGeoExtractor(GeoExtractor):
    """
    GeoExtractor implementation for a SpatiaLite (SQLite) database.

    Each depth is turned into a fraction of the hole length (the larger of
    ``eoh`` and ``planned_eoh`` from the ``collar`` table, NULLs counted as 0)
    and interpolated in SQL along the collar geometry column named by
    ``self.geom``.
    """

    # Denominator turning a depth into a line fraction. ``c`` is the alias of
    # the collar table in every query built by this class.
    _LENGTH_SQL = "max(ifnull(c.eoh, 0), ifnull(c.planned_eoh, 0))"

    def __init__(self, assay) -> None:
        super().__init__(assay)

    @staticmethod
    def _sql_literal(value) -> str:
        """
        Return ``str(value)`` as a single-quoted SQL string literal.

        Embedded single quotes are doubled so an id such as ``O'Neil-01``
        neither breaks nor alters the query.
        """
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _point_sql(geom: str, fraction: str) -> str:
        """
        Return a SQL expression for the point at ``fraction`` of line ``geom``.

        The line goes through ATM_transform(..., ATM_CreateXRoll(90)) before
        ST_Line_Interpolate_Point and the point through
        ATM_transform(..., ATM_CreateXRoll(-90)) afterwards.
        """
        rolled = f"ATM_transform({geom}, ATM_CreateXRoll(90))"
        point = f"ST_Line_Interpolate_Point({rolled}, {fraction})"
        return f"ATM_transform({point}, ATM_CreateXRoll(-90))"

    def _coordinates_query(self, depth_columns) -> str:
        """
        Build the query locating every measure of the current hole.

        :param depth_columns: sequence of ``(prefix, column_name)``; each entry
            yields ``<prefix>_x, <prefix>_y, <prefix>_z`` output columns.
        :return: SQL text; rows are ordered by the assay's ``x_col``.
        """
        assay = self.assay
        hole_col = assay.hole_col.name
        selected = []
        for prefix, column in depth_columns:
            point = self._point_sql(f"c.{self.geom}", f"o.{column}/{self._LENGTH_SQL}")
            selected.extend(f"st_{axis}({point}) AS {prefix}_{axis}" for axis in "xyz")
        columns = ",\n       ".join(selected)
        # ORDER BY is on the outer query: ordering inside a derived table is not
        # guaranteed to survive to the final result.
        return (
            f"SELECT {columns}\n"
            f"FROM {assay._assay_table.name} AS o\n"
            f"LEFT JOIN collar AS c ON o.{hole_col} = c.hole_id\n"
            f"WHERE o.{hole_col} = {self._sql_literal(assay.hole_id)}\n"
            f"ORDER BY o.{assay.x_col.name}"
        )

    def _fetch_coordinates(self, query: str):
        """Run ``query`` and return its rows as a numpy array, or None (see callers)."""
        rows = self.assay._session.execute(query).fetchall()
        if not rows:
            return None
        coords = np.array(rows)
        # If the geometry doesn't exist (no collar row / NULL geometry) the
        # interpolated coordinates are NULL.
        if coords[0, 0] is None:
            return None
        return coords

    def get_planned_eoh(self):
        """
        Return the altitude (z) of the planned end of hole, or None when the
        hole has no collar row.
        """
        point = self._point_sql("c.planned_geom", f"c.planned_eoh/{self._LENGTH_SQL}")
        q = (
            f"SELECT st_z({point}) AS c_z\n"
            "FROM collar AS c\n"
            f"WHERE c.hole_id = {self._sql_literal(self.assay.hole_id)}"
        )
        rows = self.assay._session.execute(q).fetchall()
        return rows[0][0] if rows else None

    def get_discrete_coordinates(self) -> np.ndarray:
        """
        Return an (n, 3) array of x, y, z for each measure ordered by ``x_col``,
        or None when the hole has no measure or no geometry.
        """
        q = self._coordinates_query([("c", self.assay.x_col.name)])
        return self._fetch_coordinates(q)

    def get_extended_coordinates(self):
        """
        Return an (n, 6) array of start x, y, z and end x, y, z for each
        interval ordered by ``x_col``, or None when the hole has no measure or
        no geometry.
        """
        q = self._coordinates_query(
            [("f", self.assay.x_col.name), ("t", self.assay.x_end_col.name)]
        )
        return self._fetch_coordinates(q)


class XplordbGeoExtractor(GeoExtractor):
    """
    GeoExtractor implementation for an xplordb (PostgreSQL/PostGIS) database.

    Each depth is turned into a fraction of the hole length (the greatest of
    ``eoh`` and ``planned_eoh`` from ``display.display_collar``) and
    interpolated in SQL along the collar geometry column named by
    ``self.geom``.
    """

    # Denominator turning a depth into a line fraction. ``c`` is the alias of
    # display.display_collar in every query built by this class.
    _LENGTH_SQL = "greatest(c.eoh, c.planned_eoh)"

    def __init__(self, assay) -> None:
        super().__init__(assay)

    @staticmethod
    def _sql_literal(value) -> str:
        """
        Return ``str(value)`` as a single-quoted SQL string literal.

        Embedded single quotes are doubled so an id such as ``O'Neil-01``
        neither breaks nor alters the query (standard SQL quoting; assumes
        standard_conforming_strings is on, the PostgreSQL default).
        """
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _point_sql(geom: str, fraction: str) -> str:
        """
        Return a SQL expression for the point at ``fraction`` of line ``geom``.

        The (curve-to-line converted) geometry has its ordinates swapped 'zx'
        before ST_LineInterpolatePoints (repeat = false) and the resulting
        point is swapped back with 'xz'.
        """
        swapped = f"ST_SwapOrdinates(ST_CurveToLine({geom}), 'zx')"
        point = f"ST_LineInterpolatePoints({swapped}, {fraction}, false)"
        return f"ST_SwapOrdinates({point}, 'xz')"

    def _coordinates_query(self, depth_columns) -> str:
        """
        Build the query locating every measure of the current hole.

        :param depth_columns: sequence of ``(prefix, column_name)``; each entry
            yields ``<prefix>_x, <prefix>_y, <prefix>_z`` output columns.
        :return: SQL text; rows are ordered by the assay's ``x_col``.
        """
        assay = self.assay
        hole_col = assay.hole_col.name
        selected = []
        for prefix, column in depth_columns:
            # Fraction clamped to 1.0 so depths beyond the hole length map to
            # the end of the trace.
            fraction = f"least(o.{column}/{self._LENGTH_SQL}, 1.0)"
            point = self._point_sql(f"c.{self.geom}", fraction)
            selected.extend(f"st_{axis}({point}) AS {prefix}_{axis}" for axis in "xyz")
        columns = ",\n       ".join(selected)
        # ORDER BY is on the outer query: ordering inside a derived table is not
        # guaranteed to survive to the final result.
        return (
            f"SELECT {columns}\n"
            f"FROM assay.{assay._assay_table.name} AS o\n"
            f"LEFT JOIN display.display_collar AS c ON o.{hole_col} = c.hole_id\n"
            f"WHERE o.{hole_col} = {self._sql_literal(assay.hole_id)}\n"
            f"ORDER BY o.{assay.x_col.name}"
        )

    def _fetch_coordinates(self, query: str):
        """Run ``query`` and return its rows as a numpy array, or None (see callers)."""
        rows = self.assay._session.execute(query).fetchall()
        if not rows:
            return None
        coords = np.array(rows)
        # If the geometry doesn't exist (no collar row / NULL geometry) the
        # interpolated coordinates are NULL.
        if coords[0, 0] is None:
            return None
        return coords

    def get_planned_eoh(self):
        """
        Return the altitude (z) of the planned end of hole, or None when the
        hole has no collar row.
        """
        point = self._point_sql("c.planned_geom", f"c.planned_eoh/{self._LENGTH_SQL}")
        q = (
            f"SELECT st_z({point}) AS c_z\n"
            "FROM display.display_collar AS c\n"
            f"WHERE c.hole_id = {self._sql_literal(self.assay.hole_id)}"
        )
        rows = self.assay._session.execute(q).fetchall()
        return rows[0][0] if rows else None

    def get_discrete_coordinates(self) -> np.ndarray:
        """
        Return an (n, 3) array of x, y, z for each measure ordered by ``x_col``,
        or None when the hole has no measure or no geometry.
        """
        q = self._coordinates_query([("c", self.assay.x_col.name)])
        return self._fetch_coordinates(q)

    def get_extended_coordinates(self):
        """
        Return an (n, 6) array of start x, y, z and end x, y, z for each
        interval ordered by ``x_col``, or None when the hole has no measure or
        no geometry.
        """
        q = self._coordinates_query(
            [("f", self.assay.x_col.name), ("t", self.assay.x_end_col.name)]
        )
        return self._fetch_coordinates(q)
