from domain.CommonHelper import type_lookup_dict, get_keys_from_value
from domain.plan.infrastructure.registered_layer_processor.helper import fetch_point_id, check_key_exists, \
    fetch_point_service_token
from domain.plan.model.point_detail import PointDetail
from domain.plan.model.parcel import Parcel
from domain.plan.model.parcel_line import ParcelLine
from domain.plan.model.admin import Admin as PlanAdmin
from domain.plan.model.line_detail import LineDetail
from shapely.geometry.multipolygon import MultiPolygon
from domain.plan.model.point import Point
from domain.plan.model.point_history import Point_History
from shapely import wkt
from domain.plan.model.created_spi import Created_Spi
from datetime import datetime
from utilities.dcmgeometrysdk.dcmgeometry.arcs import ArcGeom
import logging


class LoadGeometryInfoToPlanDB:
    """
    Load the geometry info of one plan into the plan DB tables: admin, line_detail, point,
    point_detail, parcel and parcel_line.

    ``load_all`` is the entry point. It maps the parcels, point details, points and lines found on
    ``file_geometry`` and, when the geometry carries admin info, persists them under a new admin row.

    NOTE(review): several methods use ``with self.plan_db as open_session``. This assumes ``plan_db``
    is a SQLAlchemy ``Session``, whose context manager closes the session on exit; anything not
    committed inside such a block is discarded at that point.
    """

    def __init__(self, plan_db_session, event_id, file_geometry, json_object):
        """
        :param plan_db_session: Session used for every plan DB query and write.
        :param event_id: Event id stamped on the rows this loader creates or updates.
        :param file_geometry: Parsed geometry exposing ``polygons``, ``points``, ``lines``, ``admin``
            and ``survey_number``.
        :param json_object: Request payload; this class reads ``"document_version_id"``,
            ``"spear_id"`` and ``"point_api_url"`` from it (and passes it to
            ``fetch_point_service_token``).
        """
        self.plan_db = plan_db_session
        self.event_id = event_id
        self.geometry = file_geometry
        self.json_object = json_object
        # Point name -> point id, filled by create_point_ids_dict and read when resolving line end points.
        self.point_oid_dict = {}
        self.type_lookup_dict = type_lookup_dict(self.plan_db)

    def load_all(self):
        """
        Map every geometry section and persist it under a new admin row.

        Parcels are mapped before lines because the lines are linked to the parcel_line objects built
        from the parcels; point details are mapped before lines because they fill ``point_oid_dict``,
        which is used to resolve line end points.

        :return: The id of the created admin row, or None when the geometry has no admin info. In that
            case the mapped lines, parcels and point details are not added to the session here.
        """
        lines = []
        point_details = []
        parcels = []
        parcel_line_dict = {}
        if self.geometry.polygons:
            parcels, parcel_line_dict = self.load_parcel(self.geometry.polygons)
            logging.info("Mapped parcel data.")
        if self.geometry.points:
            point_details = self.load_point_detail(self.geometry.points)
            logging.info("Mapped point details data.")
            self.load_point(self.geometry.points)
            logging.info("Mapped point data.")
        if self.geometry.lines:
            lines = self.fetch_line_detail(self.geometry.lines, self.geometry.survey_number, parcel_line_dict)
            logging.info("Mapped lines data.")
        if self.geometry.admin:
            admin_id = self.load_plan_admin(self.geometry.admin, lines, parcels, point_details)
            logging.info("Mapped Admin.")
            self.update_line_detail_id(admin_id)
            logging.info("Loaded geometry details to plan db.")
            return admin_id
        return None

    # geometries.admin
    def load_plan_admin(self, admin_info, lines, parcels, point_details):
        """
        Create and commit the admin row that owns the mapped lines, parcels and point details.

        The child objects are passed to the ``PlanAdmin`` constructor (presumably ORM relationships, so
        they are saved together with the admin — confirm against the model).

        :return: The id of the committed admin row.
        """
        admin = PlanAdmin(
            line_detail=lines,
            point_detail=point_details,
            parcel=parcels,
            survey_number=admin_info.plan_number,
            survey_date=admin_info.date_of_survey,
            reg_date=self.get_registered_date(admin_info.registration_date),
            eplan=None,
            lga_code=admin_info.lga_code,
            create_event_id=self.event_id,
            excluded=False,
            excluded_reason=None,
            xml_working_folder=admin_info.plan_number + "/" + self.json_object["document_version_id"],
            pseudo=False,
            ignore_event_id=None
        )
        with self.plan_db as open_session:
            open_session.add(admin)
            open_session.commit()
            return admin.id

    def get_registered_date(self, registration_date):
        """
        Return the plan's registration date, falling back to the earliest Created_Spi reg_date.

        :param registration_date: Date from the admin info; returned unchanged when not None.
        :return: ``registration_date``, or the earliest ``reg_date`` (as a ``date``) of the Created_Spi
            rows with this request's ``spear_id``, or None when no such row has a reg_date.
        :raises ValueError: If a stored reg_date is not in ``DD/MM/YYYY`` format.
        """
        if registration_date is not None:
            return registration_date

        created_spi_rows = self.plan_db.query(Created_Spi.id, Created_Spi.reg_date).filter(
            Created_Spi.spear_id == self.json_object["spear_id"]).all()
        # Query.all() returns a (possibly empty) list, never None, so test for emptiness. Rows without
        # a reg_date are skipped instead of crashing strptime.
        # NOTE(review): assumes reg_date is stored as 'DD/MM/YYYY' text — confirm against the model.
        spi_dates = [datetime.strptime(row[1], '%d/%m/%Y') for row in created_spi_rows if row[1]]
        if not spi_dates:
            return None
        return min(spi_dates).date()

    # geometries.lines
    def fetch_line_detail(self, lines_info, survey_number, parcel_line_dict):
        """
        Build (without persisting) one LineDetail per line of the geometry.

        :param lines_info: Mapping of line key -> line geometry object. The key is indexed as
            ``key[0]`` / ``key[1]`` to get the setup / target point names (used with ``point_oid_dict``).
        :param survey_number: Survey number stored as ``ref_survey_number`` on every line.
        :param parcel_line_dict: Line name -> ParcelLine objects, as built by ``load_parcel``.
        :return: List of unsaved LineDetail objects.
        """
        lines = []
        for line_key, line_obj in lines_info.items():
            line_detail_type_id = get_keys_from_value(self.type_lookup_dict, ['purpose_type', line_obj.line_type])
            # Only arcs carry radius / arc length / rotation; straight lines store NULL for them.
            if isinstance(line_obj, ArcGeom):
                line_detail_radius = line_obj.radius
                line_detail_arc_length = line_obj.arc_length
                line_detail_rotation = line_obj.rot
            else:
                line_detail_radius = line_detail_arc_length = line_detail_rotation = None
            line_detail_azimuth_accuracy = check_key_exists(line_obj, 'azimuthAccuracy')
            bearing_type_id = get_keys_from_value(self.type_lookup_dict, ['observation_type', line_obj.azimuth_type])
            distance_type_id = get_keys_from_value(self.type_lookup_dict, ['observation_type', line_obj.distance_type])
            # Lines that do not bound any parcel get an empty parcel_line list.
            parcel_line_value = parcel_line_dict.get(line_obj.name, [])

            line_detail = LineDetail(
                name=line_obj.name,
                line_type=line_detail_type_id,
                setup_point_id=self.get_setup_point_id(line_obj, line_key[0]),
                target_point_id=self.get_target_point_id(line_obj, line_key[1]),
                bearing_dms=line_obj.hp_bearing,
                bearing_type=bearing_type_id,
                distance=line_obj.distance,
                distance_type=distance_type_id,
                radius=line_detail_radius,
                ref_survey_number=survey_number,
                description=line_obj.desc,
                adjust_type=None,
                distance_std=line_obj.distance_std,
                direction_std=line_detail_azimuth_accuracy,
                original_distance=line_obj.orig_distance_std,
                original_bearing_dms=None,
                az_adopt_factor_dms=line_obj.az_adopt_fact,
                unswung_bearing_dms=None,
                line_geometry=line_obj.geometry.wkt,
                rotation=line_detail_rotation,
                arc_length=line_detail_arc_length,
                create_event_id=self.event_id,
                ignore_event_id=None,
                pseudo=False,
                parcel_line=parcel_line_value
            )
            lines.append(line_detail)
        return lines

    def _resolve_point_id(self, end_point, point_name):
        """Return ``end_point.associated_point_oid`` if set, else the id resolved for ``point_name``."""
        if end_point.associated_point_oid is None:
            return self.point_oid_dict[point_name]
        return end_point.associated_point_oid

    def get_setup_point_id(self, line_object, point_name):
        """Return the point id of the line's setup point (see ``_resolve_point_id``)."""
        return self._resolve_point_id(line_object.setup_point, point_name)

    def get_target_point_id(self, line_object, point_name):
        """Return the point id of the line's target point (see ``_resolve_point_id``)."""
        return self._resolve_point_id(line_object.target_point, point_name)

    def create_point_ids_dict(self, points):
        """
        Resolve a point id for every point and store it in ``point_oid_dict`` under the point name.

        Control points with a ``point_oid`` keep it. Points with an ``associated_point_oid`` use that.
        All other points are looked up through the point service. If the resolved id is already used by
        another point, the service is queried again with a tighter tolerance (0.05).
        """
        bearer_token = fetch_point_service_token(self.json_object)
        for value in points:
            if value.point_oid is not None and value.point_type == 'control':
                point_id_up = value.point_oid

            elif value.associated_point_oid is None:
                point_id_up = fetch_point_id(value.geometry.wkt, bearer_token, self.json_object["point_api_url"])

            else:
                point_id_up = value.associated_point_oid

            if point_id_up in self.point_oid_dict.values():
                logging.info("Call point service with low tolerance.")
                point_id_up = fetch_point_id(value.geometry.wkt, bearer_token, self.json_object["point_api_url"],
                                             tolerance=0.05)

            self.point_oid_dict[value.name] = point_id_up

    # geometries.point_detail
    def load_point_detail(self, point_info):
        """
        Resolve point ids (filling ``point_oid_dict``) and build one unsaved PointDetail per point.

        :param point_info: Mapping whose values are the point geometry objects.
        :return: List of PointDetail objects.
        """
        point_detail_list = []
        self.create_point_ids_dict(point_info.values())
        for value in point_info.values():
            point_type_id = get_keys_from_value(self.type_lookup_dict, ['point_type', value.point_type])
            point_state_id = get_keys_from_value(self.type_lookup_dict, ['point_state', value.point_state])
            monument_state_id = get_keys_from_value(self.type_lookup_dict, ['monument_state', value.mon_state])
            monument_condition_id = get_keys_from_value(self.type_lookup_dict,
                                                        ['monument_condition', value.mon_condition])
            mark_type_id = get_keys_from_value(self.type_lookup_dict, ['mark_type', value.mon_type])

            point_detail = PointDetail(point_id=self.point_oid_dict[value.name],
                                       point_type=point_type_id,
                                       point_state=point_state_id,
                                       point_name=value.name,
                                       mark_type=mark_type_id,
                                       mark_state=monument_state_id,
                                       mark_condition=monument_condition_id,
                                       ref_survey_number=self.geometry.survey_number,
                                       description=value.mon_desc,
                                       point_geometry=value.geometry.wkt,
                                       create_event_id=self.event_id,
                                       pseudo=False)
            point_detail_list.append(point_detail)
        return point_detail_list

    def load_parcel(self, polygons):
        """
        Build (without persisting) one Parcel per polygon, with its ordered ParcelLine children.

        Parcel geometry is normalised to MultiPolygon WKT. If the WKT cannot be parsed or wrapped, the
        failure is logged and the parcel is stored with a NULL geometry.

        :param polygons: Mapping of parcel key -> parcel geometry object.
        :return: ``(parcels, parcel_line_dict)`` where ``parcel_line_dict`` maps line name -> the
            ParcelLine objects referencing that line.
        """
        parcels = []
        parcel_line_dict = {}
        for parcel_value in polygons.values():
            # Get parcel lines list (also fills parcel_line_dict).
            parcel_line_list = self.load_parcel_line(parcel_value.line_order, parcel_line_dict)
            parcel_type_id = get_keys_from_value(self.type_lookup_dict, ['parcel_type', parcel_value.parcel_type])
            parcel_state_id = get_keys_from_value(self.type_lookup_dict, ['parcel_state', parcel_value.parcel_state])
            parcel_class_id = get_keys_from_value(self.type_lookup_dict, ['parcel_class', parcel_value.parcel_class])
            parcel_format_id = get_keys_from_value(self.type_lookup_dict, ['parcel_format', parcel_value.parcel_format])

            try:
                polygon_geom = wkt.loads(parcel_value.geometry.wkt)
                if isinstance(polygon_geom, MultiPolygon):
                    parcel_geom_val = parcel_value.geometry.wkt
                else:
                    parcel_geom_val = MultiPolygon([polygon_geom]).wkt
            except Exception:
                # Best effort: keep loading the plan, but do not hide why the geometry was dropped.
                logging.warning("Could not build MultiPolygon geometry for parcel %s; storing NULL geometry.",
                                parcel_value.name, exc_info=True)
                parcel_geom_val = None
            parcel = Parcel(spi=parcel_value.name,
                            parcel_type=parcel_type_id,
                            parcel_state=parcel_state_id,
                            parcel_class=parcel_class_id,
                            parcel_format=parcel_format_id,
                            parcel_line=parcel_line_list,
                            description=parcel_value.desc,
                            name=parcel_value.name,
                            parcel_geometry=parcel_geom_val,
                            create_event_id=self.event_id,
                            parcel_use=parcel_value.parcel_use)
            parcels.append(parcel)
        return parcels, parcel_line_dict

    def load_parcel_line(self, line_order, parcel_line_dict):
        """
        Build the ordered ParcelLine objects of one parcel.

        :param line_order: Mapping of ring key -> ordered list of line refs (each with ``name`` and
            ``reversed``). The ring value stored is the part of the key after the first "-" (the whole
            key if it has no "-").
        :param parcel_line_dict: Line name -> ParcelLine list; extended in place. It is used later to
            link line_detail rows to their parcel_line rows.
        :return: The ParcelLine objects of this parcel, in ring then sequence order.
        """
        parcel_lines = []
        for parcel_line_key, ordered_lines in line_order.items():
            ring = parcel_line_key[parcel_line_key.find("-") + 1:]
            for seq, value in enumerate(ordered_lines):
                parcel_line = ParcelLine(seq=seq,
                                         reverse=value.reversed,
                                         ring=ring,
                                         parcel_index=0,
                                         create_event_id=self.event_id
                                         )
                parcel_lines.append(parcel_line)
                parcel_line_dict.setdefault(value.name, []).append(parcel_line)
        return parcel_lines

    # geometries.point
    def load_point(self, point_info):
        """
        Upsert every point (see ``fetch_point``) and insert the ones that did not exist yet.

        :param point_info: Mapping whose values are the point geometry objects; ids must already be in
            ``point_oid_dict``.
        :return: The newly inserted Point objects.
        """
        point_list = []
        for value in point_info.values():
            point = self.fetch_point(value, self.point_oid_dict[value.name])
            if point is not None:
                point_list.append(point)
        with self.plan_db as open_session:
            if point_list:
                open_session.add_all(point_list)
                open_session.commit()
        return point_list

    def _add_point_history(self, session, point_id, draw_position, dataset_type_id):
        """Stage on ``session`` a new Point_History row created and updated by this event."""
        session.add(Point_History(
            point_id=point_id,
            draw_position=draw_position,
            dataset_type_id=dataset_type_id,
            create_event_id=self.event_id,
            update_event_id=self.event_id
        ))

    def fetch_point(self, value, point_id_up):
        """
        Gets point and makes necessary updates to point and point history tables.

        * Point row exists: its draw_position / update_event_id are updated, existing history rows are
          retired (retire_event_id = this event) and a new history row is added.
        * Point row missing: a new, unsaved Point is built and returned for the caller to insert.
          Existing history rows get update_event_id = this event; if there are none, a first history
          row is added.

        All changes are committed before returning.

        :param value: The geometry object of the point.
        :param point_id_up: The point id for the point.
        :return: The new Point to insert, or None when the point already existed.
        """
        with self.plan_db as open_session:
            dataset_type_id = get_keys_from_value(self.type_lookup_dict, ['dataset_type', 'DCAS_Plan_data_load'])
            draw_position = value.geometry.wkt

            point_query = open_session.query(Point).filter_by(point_id=point_id_up,
                                                              dataset_type_id=dataset_type_id)
            history_query = open_session.query(Point_History).filter_by(point_id=point_id_up,
                                                                        dataset_type_id=dataset_type_id)
            point_exists = point_query.first() is not None
            history_exists = history_query.first() is not None

            point = None
            if point_exists:
                point_query.update({Point.draw_position: draw_position, Point.update_event_id: self.event_id},
                                   synchronize_session=False)
                if history_exists:
                    history_query.update({Point_History.retire_event_id: self.event_id},
                                         synchronize_session=False)
                self._add_point_history(open_session, point_id_up, draw_position, dataset_type_id)
            else:
                point = Point(
                    point_id=point_id_up,
                    draw_position=draw_position,
                    dataset_type_id=dataset_type_id,
                    create_event_id=self.event_id
                )
                if history_exists:
                    history_query.update({Point_History.update_event_id: self.event_id},
                                         synchronize_session=False)
                else:
                    self._add_point_history(open_session, point_id_up, draw_position, dataset_type_id)
            # Commit on every path. The "point missing, history exists" update used to be left
            # uncommitted and was discarded when this with-block closed the session.
            open_session.commit()
            return point

    def update_line_detail_id(self, admin_id):
        """
        Copy each line_detail row's ``id`` into its ``line_id`` for the given admin, then commit.
        """
        filtered_rows = self.plan_db.query(LineDetail).filter(LineDetail.admin_id == admin_id).all()
        for row in filtered_rows:
            row.line_id = row.id
        self.plan_db.commit()