import networkx as nx
import shapely.ops as so
import shapely.geometry as sg
import math
from utilities.dcmgeometrysdk.dcmgeometry.lines import LineGeom
from utilities.dcmgeometrysdk.geometryfunctions.bearingdistancefunctions import calc_distance, \
    calc_bearing, calc_inside_360

class SurveyGraph:
    """Graph of the lines (edges) and points (nodes) of a survey.

    Used for recalculation via shortest path and for misclosures via cycle basis; relies on networkx.
    Every edge created by this class carries the attributes ``distance``, ``bearing``, ``st`` (the
    (setup, target) pair), ``distance_type``, ``az_type`` and ``geom`` (a WKT string or None)."""

    def __init__(self, lines=None, points=None, add_unconnected=True,
                 add_gen_lines=False, remove=tuple(), join_branches=True, graph=None):
        """Build the survey graph from ``lines``/``points``, or wrap an existing graph.

             :param lines: dictionary of LineGeom and ArcGeom keyed by (setup, target); a value may also
                           be a list of lines sharing that key
             :param points: dictionary of PointGeom
             :param add_unconnected: add in points that aren't connected to the survey by
                                    calculating the bearing and distances from coordinates
             :param add_gen_lines: add the generated lines to the lines dictionary
             :param remove: lines whose ``distance_type`` is in this tuple are left out of the graph
             :param join_branches: if the graph has unconnected branches of survey, join them with
                                    calculated bearing and distances
             :param graph: existing graph to use when ``lines`` is None
             :type lines: dict or None
             :type points: dict or None
             :type add_unconnected: bool
             :type add_gen_lines: bool
             :type remove: tuple
             :type join_branches: bool
             :type graph: networkx.Graph or None"""

        self.unconnected_nodes = set()
        self.unconnected_branches = set()
        self.lines = lines
        if lines is not None:
            self.graph = self.generate_graph(lines, points, add_unconnected, add_gen_lines, remove, join_branches)
        else:
            # may still be None, giving an empty survey object
            self.graph = graph

        if self.graph is None:
            # without a graph there is nothing to test; an empty survey is reported as not connected
            self.connected = False
        else:
            # connectivity is judged with 'Generated' and 'BranchConnection' edges removed.
            # NOTE(review): 'GraphGenerated' edges still count here - confirm that is intended.
            self.connected = nx.is_k_edge_connected(
                self.ignore_line_type(line_types=('Generated', 'BranchConnection')), 1)

    def return_branches(self, graph=None):
        """Return the branches of a graph as a list of node sets.

        A branch is a part of the survey that is not joined to the rest by a measurement
        (a 1-edge-connected component).

        :param graph: networkx graph to inspect, defaults to ``self.graph``
        :type graph: networkx.Graph or None
        :rtype: list[set]"""

        if graph is None:
            graph = self.graph
        return list(nx.k_edge_components(graph, 1))

    def generate_graph(self, lines, points=None, add_unconnected=True, add_gen_lines=False, remove=tuple(),
                       join_branches=False):
        """Generate a networkx graph from survey lines and points.

        Resets and fills ``self.unconnected_nodes`` and ``self.unconnected_branches``.
        ``lines`` is mutated in place: when it is empty a line between the first two points is added,
        and when ``add_gen_lines`` is True every generated line is added to it.

             :param lines: dictionary of LineGeom and ArcGeom keyed by (setup, target); a value may also
                           be a list of lines sharing that key
             :param points: dictionary of PointGeom
             :param add_unconnected: add in points that aren't connected to the survey by
                                    calculating the bearing and distances from coordinates
             :param add_gen_lines: add the generated lines to the lines dictionary
             :param remove: lines whose ``distance_type`` is in this tuple are left out of the graph
             :param join_branches: if the graph has unconnected branches of survey, join each branch to the
                                    nearest point of the other branches with calculated bearing and distance
             :type lines: dict
             :type points: dict or None
             :type add_unconnected: bool
             :type add_gen_lines: bool
             :type remove: tuple
             :type join_branches: bool
             :return: the generated graph
             :rtype: networkx.Graph"""

        g = nx.Graph()
        nodes = set()
        if points is not None:
            # reverse lookup: first coordinate of each point's geometry -> point key
            rev_look = {v.geometry.coords[:][0]: k for k, v in points.items()}
        else:
            rev_look = {}
            # every generated line is calculated from point coordinates, so none can be made without points
            if add_unconnected is True:
                print('No points provided, setting add unconnected to False')
                add_unconnected = False
            if add_gen_lines is True:
                print('No points provided, setting add gen lines to False')
                add_gen_lines = False
            if join_branches is True:
                print('No points provided, setting join branches to False')
                join_branches = False

        if len(lines) == 0 and points is not None and len(points) >= 2:
            # no measurements at all: seed the graph with a line between the first two points
            pts = list(points.values())[:2]
            line = sg.LineString([pts[0].geometry, pts[1].geometry])
            gen_line = LineGeom()
            gen_line.create_line_from_coords(setup_point=pts[0], target_point=pts[1], coords=line, name='OBS-1')
            lines[(pts[0].name, pts[1].name)] = gen_line

        for k, val in lines.items():
            # a key maps to a single line or to a list of lines; in a simple Graph the last
            # line added for a (setup, target) pair supplies the edge attributes
            for v in (val if isinstance(val, list) else [val]):
                if v.distance_type in remove:
                    continue
                setup, target = k
                geom = v.geometry.wkt if v.geometry is not None else None
                g.add_edge(setup, target, distance=v.distance, bearing=v.dd_bearing, st=(setup, target),
                           distance_type=v.distance_type, az_type=v.azimuth_type, geom=geom)
                nodes.add(setup)
                nodes.add(target)

        # find points that no line touches
        self.unconnected_nodes = set()
        self.unconnected_branches = set()
        if points is not None:
            self.unconnected_nodes.update(point for point in points if point not in nodes)

        counter = 0  # shared numbering for the names of generated lines
        if nx.is_k_edge_connected(g, 1):
            ref_point = min(nodes)
        else:
            branches = list(nx.k_edge_components(g, 1))
            if branches:
                # the largest branch supplies the fallback reference point
                max_b = max(branches, key=len)
                ref_point = max(max_b)
            else:
                # empty graph (no usable lines): there is nothing to reference
                max_b = None
                ref_point = None

            if join_branches is False:
                for item in branches:
                    if item != max_b:
                        # one node per minor branch is later tied in by the add_unconnected step
                        self.unconnected_nodes.add(next(iter(item)))
                        self.unconnected_branches.update(item)
            elif join_branches is True and len(branches) > 1:
                # NOTE: each branch is joined to its nearest neighbouring branch, which may still leave
                # the survey in several pieces (e.g. two pairs of branches that are nearest to each other)
                for branch in branches:
                    rest = set().union(*(b for b in branches if b != branch))
                    branch_geom = sg.MultiPoint([points.get(point).geometry for point in branch])
                    rest_geom = sg.MultiPoint([points.get(point).geometry for point in rest])
                    # nearest_points returns (point on branch_geom, point on rest_geom)
                    on_branch, on_rest = so.nearest_points(branch_geom, rest_geom)
                    branch_point = points.get(rev_look.get(on_branch.coords[0]))
                    rest_point = points.get(rev_look.get(on_rest.coords[0]))
                    if branch_point.name != rest_point.name:
                        line = LineGeom()
                        line.create_line_from_coords(branch_point, rest_point, name=f'generated-{str(counter)}',
                                                     line_type='BranchConnection', crs=branch_point.crs)
                        counter += 1

                        g.add_edge(branch_point.name, rest_point.name, distance=line.distance, bearing=line.dd_bearing,
                                   st=(branch_point.name, rest_point.name), distance_type='BranchConnection',
                                   az_type='BranchConnection', geom=line.geometry.wkt)
                        if add_gen_lines is True:
                            lines[(branch_point.name, rest_point.name)] = line

        # add in generated edges for unconnected points, each from the nearest connected point
        if add_unconnected is True:
            excluded = self.unconnected_nodes.union(self.unconnected_branches)
            con_multi = sg.MultiPoint([v.geometry for k, v in points.items() if k not in excluded])

            # nearest_points raises on an empty geometry, so only proceed when there is something to connect to
            if not con_multi.is_empty:
                for point in self.unconnected_nodes:
                    tg = points.get(point)
                    target = tg.geometry
                    if target is None:
                        continue
                    close_point = so.nearest_points(target, con_multi)[1].coords[0]
                    ref = rev_look.get(close_point, ref_point)
                    sp = points.get(ref)
                    if sp is None or sp.name == tg.name:
                        continue
                    line = LineGeom()
                    line.distance_std = .01
                    line.bearing_std = 1
                    line.create_line_from_coords(sp, tg, name=f'generated-{str(counter)}', line_type='GraphGenerated',
                                                 crs=sp.crs)
                    g.add_edge(ref, point, distance=line.distance, bearing=line.dd_bearing, st=(ref, point),
                               distance_type='GraphGenerated', az_type='GraphGenerated', geom=line.geometry.wkt)
                    counter += 1
                    if add_gen_lines is True:
                        lines[(sp.name, tg.name)] = line

        return g

    def ignore_line_type(self, line_types=('Ignored', 'Generated', 'GraphGenerated', 'BranchConnection'), g=None):
        """Return a new graph without the edges whose distance type OR azimuth type is in ``line_types``.

            The new graph is built from the remaining edges only, so nodes left without any edge are dropped.
            :param line_types: tuple of line types that are ignored when returning the graph
            :param g: graph to filter, defaults to ``self.graph``
            :type line_types: tuple
            :type g: networkx.Graph or None
            :rtype: networkx.Graph"""
        source = self.graph if g is None else g
        return nx.Graph(((u, v, e) for u, v, e in source.edges(data=True)
                         if e['distance_type'] not in line_types and e['az_type'] not in line_types))

    def ignore_line_type_branches(self, line_types=('Ignored', 'Generated', 'GraphGenerated', 'BranchConnection')):
        """Return the branches of the graph after removing edges of the given line types.

        :param line_types: tuple of line types that are ignored when returning the branches
        :type line_types: tuple
        :rtype: list[set]"""
        return self.return_branches(self.ignore_line_type(line_types))

    def graph_from_branches(self, line_types=('Ignored', 'Generated', 'GraphGenerated', 'BranchConnection'), k=1):
        """Return subgraphs of ``self.graph``, one per maximal k-edge-connected subgraph of the filtered graph.

                The k-edge-connected node sets are found on the graph without ``line_types`` edges, but the
                returned subgraphs are taken from ``self.graph`` and so still contain edges of those types
                between the selected nodes. See the networkx ``k_edge_subgraphs`` documentation.
                :param line_types: tuple of line types that are ignored when finding the subgraphs
                :param k: edge connectivity required within each subgraph
                :type line_types: tuple
                :type k: int
                :rtype: list[networkx.Graph]"""
        filtered = self.ignore_line_type(line_types)
        return [self.graph.subgraph(n) for n in nx.k_edge_subgraphs(filtered, k)]

    def most_connected_node(self):
        """Return ``(node, degree)`` for the node with the highest degree in ``self.graph``.

        Ties go to the first such node in graph iteration order."""
        return max(self.graph.degree(), key=lambda node_degree: node_degree[1])

    def remove_radiations(self, points_to_keep=None, pqs=None, ys=None):
        """Remove radiations from the graph in a single pass; a radiation is a node with fewer than two edges.

        Nodes exposed as new radiations by this removal are not removed.
        :param points_to_keep: keep these points even if they are only connected to one edge
        :param pqs: set of points that have set positions via pqs in dynadjust (always kept)
        :param ys: set of points that have set positions via ys in dynadjust (always kept)
        :type points_to_keep: set, list or None
        :type pqs: set, list or None
        :type ys: set, list or None"""
        keep = set(points_to_keep or ()) | set(pqs or ()) | set(ys or ())
        # materialise before removing so the degree view is not mutated while iterated
        remove = [n for n, d in self.graph.degree() if d < 2 and n not in keep]
        self.graph.remove_nodes_from(remove)

    def remove_unconstrained_measures(self, pqs, ys):
        """Remove edges whose lines lack either a distance or a direction, unless an end point is fixed.

        A (setup, target) key is unconstrained when neither end is in ``pqs`` or ``ys`` and none of
        its lines has a distance, or none has a bearing. Keys missing from the graph are ignored by
        networkx. Returns an empty list when the object holds no ``lines`` dictionary.
        :param pqs: set of points that have set positions via pqs in dynadjust
        :param ys: set of points that have set positions via ys in dynadjust
        :type pqs: set, list or None
        :type ys: set, list or None
        :return: the removed (setup, target) keys
        :rtype: list"""
        fixed = set(pqs or ()) | set(ys or ())
        unconstrained = []
        for key, val in (self.lines or {}).items():
            if key[0] in fixed or key[1] in fixed:
                continue
            measures = val if isinstance(val, list) else [val]
            has_distance = any(m.distance is not None for m in measures)
            has_direction = any(m.dd_bearing is not None for m in measures)
            if not (has_distance and has_direction):
                unconstrained.append(key)
        self.graph.remove_edges_from(unconstrained)
        return unconstrained




