# -*- coding: utf-8 -*-

"""
/***************************************************************************
 ClusterPoints
                                 A QGIS plugin
 Cluster Points conducts spatial clustering of points based on their mutual distance to each other. The user can select between the K-Means algorithm and (agglomerative) hierarchical clustering with several different link functions.
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2020-03-30
        copyright            : (C) 2020 by Johannes Jenkner
        email                : jjenkner@web.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

__author__ = 'Johannes Jenkner'
__date__ = '2021-12-28'
__copyright__ = '(C) 2021 by Johannes Jenkner'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'



# Maximum number of randomly drawn points whose pairwise distances are used
# to estimate the cluster feature radius (see CFTask.derive_cf_radius).
number_sample_points = 250
# Maximum number of correction sweeps that move points to a closer cluster
# feature centre (see CFTask.create_blobs).
number_correction_trials = 25

import random

from qgis.core import (QgsPoint,QgsPointXY,Qgis,QgsTask,QgsMessageLog)

from math import floor,ceil,sqrt
from sys import float_info

# Category (tab) under which this module's messages appear in the QGIS message log.
MESSAGE_CATEGORY = 'ClusterPoints: Preparation'


class CFTask(QgsTask):
    """
    QGIS background task that condenses the input points into cluster
    features ("blobs") prior to the actual clustering.

    Every point is attached to the blob with the closest centre if that
    centre lies within self.radius; otherwise it opens a new blob. The radius
    is a percentile of sampled pairwise point distances. Afterwards, up to
    <number_correction_trials> sweeps move points that have become closer to
    another blob's centre.
    """

    def __init__(self, description, data, agglomeration_percentile=0,
                 d = None, pa = 0, manhattan = False):
        """
        @param description (str): Task description shown by QGIS.
        @param data (dict): Point key -> point object exposing x(), y(),
               attr_size and attributes (presumably Cluster_point instances).
        @param agglomeration_percentile (float): Percentile (0-100) of the
               sampled pairwise distances used as cluster feature radius.
        @param d (QgsDistanceArea): Measurement object used for distances.
        @param pa (uint): Percentage contribution of attribute values.
        @param manhattan (bool): Use Manhattan instead of Euclidean distance.
        """
        super().__init__(description, QgsTask.CanCancel)
        self.__data = data
        self.__agglomeration_percentile = agglomeration_percentile

        self.d = d
        self.pa = pa
        self.manhattan = manhattan
        self.size = 0        # number of cluster features created

        self.radius = 0.0    # set by derive_cf_radius()
        self.blobs = []      # cf_blob instances, filled by create_blobs()
        self.result = None

    def cancel(self):
        QgsMessageLog.logMessage("Preparation task cancelled",
            MESSAGE_CATEGORY, Qgis.Critical)
        super().cancel()

    def run(self):
        """
        Execution of task: estimate the radius, then build the blobs.

        @return True on success, False if the task was cancelled.
        """
        self.derive_cf_radius()
        self.result = self.create_blobs()
        return self.result

    def finished(self,result):
        """
        Called upon finish of execution
        """
        if result:
            QgsMessageLog.logMessage(self.tr("Successful execution of preparation task"),
                                     MESSAGE_CATEGORY, Qgis.Success)
        else:
            QgsMessageLog.logMessage(self.tr("Execution of preparation task failed"),
                                     MESSAGE_CATEGORY, Qgis.Critical)

    def derive_cf_radius(self):
        """
        Estimate the cluster feature radius and store it in self.radius.

        At most <number_sample_points> points are drawn at random. The radius
        is the <agglomeration_percentile> percentile (linear interpolation
        between closest ranks) of all pairwise distances between the drawn
        points. With fewer than two points there is no pair to measure and
        the radius is 0.
        """
        keys = list(self.__data.keys())
        if len(keys) > number_sample_points:
            keys = random.sample(keys, number_sample_points)
        points = [self.__data[key] for key in keys]

        # every unordered pair (i < j) exactly once; self-distances excluded
        sample_dist = sorted(
            self.getDistance(points[i], points[j])
            for i in range(len(points))
            for j in range(i+1, len(points)))

        if not sample_dist:
            self.radius = 0.0
        else:
            # clamp so that out-of-range input cannot index past the array
            percentile = min(max(float(self.__agglomeration_percentile), 0.0), 100.0)
            pos = (len(sample_dist)-1)*percentile/100.0
            lower, upper = floor(pos), ceil(pos)
            frac = pos-lower
            self.radius = (1-frac)*sample_dist[lower]+frac*sample_dist[upper]
        QgsMessageLog.logMessage(self.tr("Radius for cluster features: {:.5E}").format(
                                 self.radius), MESSAGE_CATEGORY, Qgis.Info)

    def create_blobs(self):
        """
        Build the cluster features (blobs) from all input points.

        Uses self.radius, so derive_cf_radius() must have run before
        (run() takes care of that).

        @return True on success, False if the task was cancelled.
        """
        self.blobs = []
        self.size = 0

        if not self._assign_points():
            return False
        return self._correct_misplaced_members()

    def _closest_blob(self, point, candidates):
        """
        Return (index, distance) of the candidate blob whose centre is
        closest to point; index is None if candidates is empty. Ties go to
        the first candidate encountered.
        """
        min_j = None
        dist = float_info.max
        for jj in candidates:
            dist_j = self.blobs[jj].distance2center(point)
            if dist_j < dist:
                min_j = jj
                dist = dist_j
        return min_j, dist

    def _new_blob(self, key, point):
        """Open a new blob holding only key, centred on point; return its index."""
        self.blobs.append(cf_blob(self.d, self.pa, self.manhattan, [key], point))
        self.size += 1
        return len(self.blobs)-1

    def _assign_points(self):
        """
        Greedy first pass over all points: join the closest blob within the
        radius or open a new blob. Returns False if the task was cancelled.
        """
        for key in self.__data.keys():
            if self.isCanceled():
                return False
            point = self.__data[key]
            min_j, dist = self._closest_blob(point, range(len(self.blobs)))
            if min_j is not None and dist < self.radius:
                self.blobs[min_j].add_point(key, point)
            else:
                self._new_blob(key, point)
        return True

    def _correct_misplaced_members(self):
        """
        Iteration for misplaced members:
        This is an own idea to prevent points outside the limits. All points
        are checked against the closest centres until no point changes its
        blob, for at most <number_correction_trials> iterations.

        Returns False if the task was cancelled, True otherwise.
        """
        blobs2consider = list(range(len(self.blobs)))
        converged = False
        iteration = 0
        for iteration in range(1, number_correction_trials+1):

            blobs_changed = set()
            consider_lookup = set(blobs2consider)

            for j in range(len(self.blobs)):

                if self.isCanceled():
                    return False

                blob = self.blobs[j]
                if blob.size <= 1:
                    continue

                if j in consider_lookup:
                    # own blob was altered: re-check against all others
                    blobs2loop = list(range(len(self.blobs)))
                else:
                    # own blob unchanged: only the altered ones can be closer now
                    blobs2loop = blobs2consider+[j]

                # iterate over a snapshot: remove_point() mutates blob.members,
                # which would otherwise make the loop skip members
                for key in list(blob.members):
                    point = self.__data[key]
                    min_j, dist = self._closest_blob(point, blobs2loop)
                    if min_j is None or min_j == j:
                        continue
                    blobs_changed.add(j)
                    blob.remove_point(key, point)
                    if dist < self.radius:
                        blobs_changed.add(min_j)
                        self.blobs[min_j].add_point(key, point)
                    else:
                        blobs_changed.add(self._new_blob(key, point))

            if not blobs_changed:
                converged = True
                break
            blobs2consider = list(blobs_changed)

        if converged:
            QgsMessageLog.logMessage(self.tr("Optimization of cluster features "
                                             "succeeded after {} iterations").format(iteration),
                                     MESSAGE_CATEGORY, Qgis.Success)
        else:
            QgsMessageLog.logMessage(self.tr("Optimization of cluster features "
                                             "failed after {} iterations").format(iteration),
                                     MESSAGE_CATEGORY, Qgis.Warning)
        return True

    def return_centroids(self):
        """Return a dictionary blob index -> cluster feature centroid."""
        return {i: blob.centroid for i, blob in enumerate(self.blobs)}

    def return_members(self,keys):
        """Return the member keys of the blobs with the given indices, concatenated in order."""
        return [member for key in keys for member in self.blobs[key].members]

    def getDistance(self, point1, point2):
        '''
        2-dimensional Euclidean distance or Manhattan distance between points 1 and 2
        plus percentage contribution (pa) of attribute values
        '''
        dist = 0
        if self.manhattan:
            if self.pa < 100:
                x1, y1 = point1.x(), point1.y()
                x2, y2 = point2.x(), point2.y()
                # sum of the four edges of the bounding rectangle, i.e. roughly
                # twice the Manhattan distance on the measurement surface
                dist += (1-0.01*self.pa)* \
                    (self.d.measureLine(QgsPointXY(x1,y1), QgsPointXY(x2,y1))+ \
                     self.d.measureLine(QgsPointXY(x1,y1), QgsPointXY(x1,y2))+ \
                     self.d.measureLine(QgsPointXY(x2,y2), QgsPointXY(x2,y1))+ \
                     self.d.measureLine(QgsPointXY(x2,y2), QgsPointXY(x1,y2)))
            if self.pa > 0:
                # doubled, presumably to match the doubled geometric term above
                dist += 2*0.01*self.pa*self.getAttrDistance(point1,point2)
        else:
            if self.pa < 100:
                dist += (1-0.01*self.pa)* \
                    self.d.measureLine(QgsPointXY(point1.x(),point1.y()), \
                                       QgsPointXY(point2.x(),point2.y()))
            if self.pa > 0:
                dist += 0.01*self.pa*self.getAttrDistance(point1,point2)
        return dist

    def getAttrDistance(self, point1, point2):
        '''
        Euclidean or Manhattan distance between the attribute vectors of
        points 1 and 2, using only the attributes both points have.
        '''
        attr_size = min(point1.attr_size,point2.attr_size)
        if attr_size == 0:
            return 0
        pairs = zip(point1.attributes[:attr_size], point2.attributes[:attr_size])
        if self.manhattan:
            return sum(abs(a-b) for a, b in pairs)
        return sqrt(sum((a-b)*(a-b) for a, b in pairs))


class cf_blob:
    """
    Single cluster feature ("blob"): a list of member keys plus the running
    centroid (location and attribute values) of the member points.
    """

    def __init__(self, d, pa, manhattan, members, centroid):
        """!
        @brief Constructor of single cluster feature (blob).
        
        @param[in] d (QgsDistanceArea): Qgs Measurement object.
        @param[in] pa (uint): Percentage contribution of attribute values.
        @param[in] manhattan (bool): Bool for use of Manhattan distance.
        @param[in] members (list): List of member keys.
        @param[in] centroid (QgsPoint): Qgs Point with initial centroid
        """

        self.d = d
        self.pa = pa
        self.manhattan = manhattan
        self.members = members
        self.size = len(members)
        self.centroid = centroid

    def update_centroid(self,point,remove=False):
        '''
        Update the centroid position after one point was added or removed.

        self.size must already hold the member count AFTER the change. With
        n = self.size this evaluates c + (p-c)/n when adding and
        c - (p-c)/n when removing, i.e. the exact mean of the members.
        An empty blob (size 0) keeps its last centroid.
        '''
        if self.size <= 0:
            return
        weight = (-1.0 if remove else 1.0)/self.size

        cx, cy = self.centroid.x(), self.centroid.y()
        centroid = Cluster_point(QgsPointXY(cx+weight*(point.x()-cx),
                                            cy+weight*(point.y()-cy)))

        old_attrs = self.centroid.attributes
        centroid.replaceAttributes([old_attrs[j]+weight*(point.attributes[j]-old_attrs[j])
                                    for j in range(point.attr_size)])
        self.centroid = centroid

    def add_point(self,index,point):
        '''
        Add member index (located at point) and update the centroid.
        '''
        self.members.append(index)
        self.size+=1
        self.update_centroid(point)

    def remove_point(self,index,point):
        '''
        Remove member index (located at point) and update the centroid.

        The size is decremented BEFORE the centroid update so that
        update_centroid() divides by the remaining member count; dividing by
        the old count would leave the centroid biased towards the removed point.
        '''
        self.members.remove(index)
        self.size-=1
        self.update_centroid(point,remove=True)

    def distance2center(self, point):
        '''
        2-dimensional Euclidean distance or Manhattan distance to centerpoint
        plus percentage contribution (pa) of attribute values
        '''
        dist = 0
        if self.manhattan:
            if self.pa < 100:
                cx, cy = self.centroid.x(), self.centroid.y()
                px, py = point.x(), point.y()
                # sum of the four edges of the bounding rectangle, i.e. roughly
                # twice the Manhattan distance on the measurement surface
                dist += (1-0.01*self.pa)* \
                    (self.d.measureLine(QgsPointXY(cx,cy), QgsPointXY(px,cy))+ \
                     self.d.measureLine(QgsPointXY(cx,cy), QgsPointXY(cx,py))+ \
                     self.d.measureLine(QgsPointXY(px,py), QgsPointXY(px,cy))+ \
                     self.d.measureLine(QgsPointXY(px,py), QgsPointXY(cx,py)))
            if self.pa > 0:
                # doubled, presumably to match the doubled geometric term above
                dist += 2*0.01*self.pa*self.attrDistance2center(point)
        else:
            if self.pa < 100:
                dist += (1-0.01*self.pa)* \
                    self.d.measureLine(QgsPointXY(self.centroid.x(),self.centroid.y()), \
                                       QgsPointXY(point.x(),point.y()))
            if self.pa > 0:
                dist += 0.01*self.pa*self.attrDistance2center(point)
        return dist

    def attrDistance2center(self, point):
        '''
        Euclidean or Manhattan distance between the attribute vectors of
        point and the centroid, using only the attributes both have.
        '''
        attr_size = min(point.attr_size,self.centroid.attr_size)
        if attr_size == 0:
            return 0
        pairs = zip(point.attributes[:attr_size], self.centroid.attributes[:attr_size])
        if self.manhattan:
            return sum(abs(a-b) for a, b in pairs)
        return sqrt(sum((a-b)*(a-b) for a, b in pairs))


class Cluster_point(QgsPoint):
    '''
    QgsPoint extended by a list of attribute values.

    attr_size tracks how many attribute values are stored and is read
    directly by other code in this module.
    '''

    def __init__(self, point):
        super().__init__(point)
        self.attr_size = 0
        self.attributes = []

    def addAttribute(self, v):
        '''Append the single attribute value v.'''
        self.attr_size += 1
        self.attributes.append(v)

    def replaceAttributes(self, v):
        '''Replace all attribute values by v (stored as given, not copied).'''
        self.attr_size = len(v)
        self.attributes = v
