# -*- coding: utf-8 -*-

__author__ = 'Sanda Takeru'
__date__ = '2025-02-14'
__copyright__ = '(C) 2025 by Sanda Takeru'
__revision__ = '$Format:%H$'

from qgis.PyQt.QtCore import QCoreApplication, QVariant
from qgis.core import (
    QgsProcessing, QgsVectorLayer, QgsProcessingAlgorithm, QgsProject,
    QgsProcessingParameterFeatureSource, QgsProcessingParameterFeatureSink,
    QgsProcessingParameterRasterLayer, QgsProcessingParameterNumber,
    QgsRectangle, QgsFeature, QgsGeometry, QgsField, QgsCoordinateTransform,
    QgsProcessingException, QgsFeatureRequest, QgsPointXY
)
import numpy as np
import concurrent.futures

# Placeholder help text.
# NOTE(review): not referenced anywhere in this module — presumably intended for
# an algorithm's shortHelpString(); confirm or remove.
HelpMessage = 'Help'

def reprojection_to_projectCRS(layer):
    """Reproject every geometry of ``layer`` in place into the project CRS.

    The layer is put in edit mode, each feature geometry is transformed and
    the edits are committed; afterwards the layer's CRS is set to the project
    CRS so that it matches the stored coordinates. In this module it is only
    called on layers produced by ``QgsFeatureSource.materialize``.
    Does nothing when the layer already uses the project CRS.

    Raises:
        QgsProcessingException: if the edits cannot be committed (the edit
            buffer is rolled back first).
    """
    layer_crs = layer.crs()
    project_crs = QgsProject.instance().crs()

    # Nothing to do when the CRSs already agree.
    if layer_crs == project_crs:
        return

    transform = QgsCoordinateTransform(layer_crs, project_crs, QgsProject.instance())
    layer.startEditing()
    for feature in layer.getFeatures():
        geom = feature.geometry()
        geom.transform(transform)
        feature.setGeometry(geom)
        layer.updateFeature(feature)

    if not layer.commitChanges():
        errors = '; '.join(layer.commitErrors())
        layer.rollBack()
        raise QgsProcessingException(f'Failed to reproject {layer.name()}: {errors}')

    # The coordinates are now in the project CRS; make the layer report that,
    # otherwise its CRS would silently disagree with its geometries.
    layer.setCrs(project_crs)

def check_layer_validity(layer, layer_name):
    """Abort processing when ``layer`` reports itself as invalid.

    ``layer_name`` is only used to build the error message.

    Raises:
        QgsProcessingException: if ``layer.isValid()`` is false.
    """
    if layer.isValid():
        return
    raise QgsProcessingException(f'Failed to load {layer_name}.')
    
# Band count of the raster most recently validated by check_raster_layer().
band_count = 0
def check_raster_layer(raster_layer, feedback):
    """Validate the raster input and log its per-band value ranges.

    Side effect: stores the raster's band count in the module-level
    ``band_count``.

    Args:
        raster_layer: the raster layer to validate (may be ``None`` when the
            parameter could not be resolved).
        feedback: processing feedback used for informational messages.

    Raises:
        QgsProcessingException: if the layer is missing or invalid, its CRS
            differs from the project CRS, it has no bands, or a band contains
            no valid cells.
    """
    global band_count
    if raster_layer is None:
        raise QgsProcessingException('Failed to load raster layer.')
    if not raster_layer.isValid():
        raise QgsProcessingException(f'Failed to load raster layer: {raster_layer.name()}.')

    project_crs = QgsProject.instance().crs()
    if project_crs != raster_layer.crs():
        raise QgsProcessingException('Raster layer CRS is difference from the project CRS.')

    feedback.pushInfo(f'This process is calculating based on the project CRS: {project_crs}.')

    band_count = raster_layer.bandCount()
    if band_count == 0:
        raise QgsProcessingException('Raster layer has no bands.')

    feedback.pushInfo(f'Raster layer has {band_count} band(s).')

    provider = raster_layer.dataProvider()
    band_ranges = []
    for band in range(1, band_count + 1):
        stats = provider.bandStatistics(band)
        # QgsRasterBandStats exposes min/max as floats, so the None checks alone
        # can never detect an empty band; elementCount is the number of valid
        # (non-NoData) cells that were counted.
        if stats.elementCount == 0 or stats.minimumValue is None or stats.maximumValue is None:
            raise QgsProcessingException(f'Band {band} has no data.')
        band_ranges.append((stats.minimumValue, stats.maximumValue))

    for i, (min_value, max_value) in enumerate(band_ranges, start=1):
        feedback.pushInfo(f"Band {i}: Min value = {min_value}, Max value = {max_value}")

class ReuseSurveyAlgorithm(QgsProcessingAlgorithm):
    """Compare target polygons against a pooled sample of source raster values.

    Raster values under all source polygons are pooled into a single sample.
    Every target polygon is copied into a new 'Reuse Survey' memory layer whose
    ``SPS_difference`` field receives the weighted difference between the
    target's per-band means and the pooled source statistics.
    """

    INPUT_RASTER = 'INPUT_RASTER'
    INPUT_VECTOR_SOURCE = 'INPUT_VECTOR_SOURCE'
    INPUT_VECTOR_TARGET = 'INPUT_VECTOR_TARGET'
    OUTPUT = 'OUTPUT'

    def initAlgorithm(self, config=None):
        # QgsProcessingParameterRasterLayer takes (name, description, defaultValue, ...);
        # it has no layer-type argument, so nothing is passed after the description.
        self.addParameter(QgsProcessingParameterRasterLayer(self.INPUT_RASTER, 'Input Raster Layer'))
        self.addParameter(QgsProcessingParameterFeatureSource(self.INPUT_VECTOR_SOURCE, 'Input Source Polygon Layer', types=[QgsProcessing.TypeVectorPolygon]))
        self.addParameter(QgsProcessingParameterFeatureSource(self.INPUT_VECTOR_TARGET, 'Input Target Polygon Layer', types=[QgsProcessing.TypeVectorPolygon]))
        # NOTE(review): this sink is never created via parameterAsSink; results are
        # delivered by adding memory layers to the project instead.
        self.addParameter(QgsProcessingParameterFeatureSink(self.OUTPUT, 'Output Evaluated Polygon Layer'))

    def flags(self):
        """Force execution on the main thread.

        processAlgorithm adds layers to QgsProject.instance(), which must not be
        done from a background processing thread.
        """
        return super().flags() | QgsProcessingAlgorithm.FlagNoThreading

    def processAlgorithm(self, parameters, context, feedback):
        raster_layer = self.parameterAsRasterLayer(parameters, self.INPUT_RASTER, context)
        if raster_layer is None:
            raise QgsProcessingException(self.invalidRasterError(parameters, self.INPUT_RASTER))
        source = self.parameterAsSource(parameters, self.INPUT_VECTOR_SOURCE, context)
        if source is None:
            raise QgsProcessingException(self.invalidSourceError(parameters, self.INPUT_VECTOR_SOURCE))
        target = self.parameterAsSource(parameters, self.INPUT_VECTOR_TARGET, context)
        if target is None:
            raise QgsProcessingException(self.invalidSourceError(parameters, self.INPUT_VECTOR_TARGET))

        # Work on memory copies: reprojection below edits the layers in place.
        vector_layer_source = source.materialize(QgsFeatureRequest())
        vector_layer_target = target.materialize(QgsFeatureRequest())

        check_layer_validity(vector_layer_source, 'Source Polygon Layer')
        check_layer_validity(vector_layer_target, 'Target Polygon Layer')
        check_raster_layer(raster_layer, feedback)

        feature_count = vector_layer_source.featureCount() * vector_layer_target.featureCount()
        if feature_count == 0:
            raise QgsProcessingException('Layer(s) has no features. Exiting process.')

        feedback.pushInfo('Reprojecting to project CRS.')
        reprojection_to_projectCRS(vector_layer_source)
        reprojection_to_projectCRS(vector_layer_target)

        feedback.pushInfo('Making Target layer.')
        output_layers = self.create_new_target_layer(vector_layer_target, feedback)

        feedback.pushInfo('Calculating histogram of source layer.')
        histograms_source = self.calculate_histograms(raster_layer, vector_layer_source, feedback)
        if feedback.isCanceled():
            return {}
        new_histogram_source = self.merge_histograms(histograms_source)

        feedback.pushInfo('Calculating histogram of target layer.')
        histograms_target_dict = self.calculate_target_histograms(raster_layer, output_layers, feedback)
        if feedback.isCanceled():
            return {}

        # TODO(review): the scoring formula was flagged "needs checking" by the author.
        difference_dict = self.evaluate_differences(new_histogram_source, histograms_target_dict, feedback)

        self.add_attributes_to_features(output_layers, difference_dict)

        for output_layer in output_layers:
            QgsProject.instance().addMapLayer(output_layer)

        # The result is the last generated layer object, as before.
        return {self.OUTPUT: output_layers[-1]}

    def create_new_target_layer(self, layer, feedback):
        """Copy every polygon of ``layer`` into a new 'Reuse Survey' memory layer.

        The copy gets two leading fields -- ``SPS_ID`` (1-based sequence number)
        and ``SPS_difference`` (initialised to 0) -- followed by all original
        fields. Returns a one-element list holding the new layer.
        """
        project_crs = QgsProject.instance().crs()
        output_layer = QgsVectorLayer(f'Polygon?crs={project_crs.authid()}', 'Reuse Survey', 'memory')
        provider = output_layer.dataProvider()
        provider.addAttributes(
            [QgsField('SPS_ID', QVariant.Int),
             QgsField('SPS_difference', QVariant.Double, 'double', 20, 10)]
            + layer.fields().toList()
        )
        output_layer.updateFields()

        features = []
        for sps_id, polygon_feature in enumerate(layer.getFeatures(), start=1):
            feature = QgsFeature()
            feature.setGeometry(polygon_feature.geometry())
            feature.setAttributes([sps_id, 0] + polygon_feature.attributes())
            features.append(feature)

        provider.addFeatures(features)
        output_layer.updateExtents()
        return [output_layer]

    @staticmethod
    def _pixel_centres_inside(geometry, raster_extent, pixel_width, pixel_height):
        """Return sample points (QgsPointXY) inside ``geometry``.

        Points are centres of a grid with the raster's pixel size, anchored at
        the top-left of the geometry's bounding box clipped to the raster
        extent. Column-major order (x outer, y inner).
        NOTE(review): the grid is not snapped to the raster's own pixel grid.
        """
        # Null/empty geometries have a null bounding box whose infinite extent
        # would overflow int() below.
        if geometry.isEmpty():
            return []
        bbox = geometry.boundingBox()
        x_min = max(raster_extent.xMinimum(), bbox.xMinimum())
        x_max = min(raster_extent.xMaximum(), bbox.xMaximum())
        y_min = max(raster_extent.yMinimum(), bbox.yMinimum())
        y_max = min(raster_extent.yMaximum(), bbox.yMaximum())
        if x_max <= x_min or y_max <= y_min:
            return []  # no overlap with the raster

        subset_width = int((x_max - x_min) / pixel_width)
        subset_height = int((y_max - y_min) / pixel_height)
        points = []
        for i in range(subset_width):
            x = x_min + i * pixel_width + pixel_width / 2
            for j in range(subset_height):
                y = y_max - j * pixel_height - pixel_height / 2
                point = QgsPointXY(x, y)
                if geometry.contains(point):
                    points.append(point)
        return points

    def calculate_histograms(self, raster_layer, vector_layer, feedback):
        """Sample the raster under every polygon of ``vector_layer``.

        Returns ``{feature_id: {band: [values...]}}``. Samples that the
        provider reports as failed (NoData / outside the raster) are skipped.
        Progress is reported as "done/total"; stops early when cancelled.
        """
        feature_count = vector_layer.featureCount()
        pad = max(len(str(feature_count)), 2)

        def progress_label(done):
            return str(done).zfill(pad) + '/' + str(feature_count).zfill(pad)

        feedback.pushInfo(progress_label(0))

        histograms = {}

        raster_extent = raster_layer.extent()
        band_count = raster_layer.bandCount()
        pixel_width = raster_extent.width() / raster_layer.width()
        pixel_height = raster_extent.height() / raster_layer.height()
        provider = raster_layer.dataProvider()

        def process_feature(feature):
            # The sample points do not depend on the band: compute them once.
            points = self._pixel_centres_inside(feature.geometry(), raster_extent, pixel_width, pixel_height)
            feature_histograms = {}
            for band in range(1, band_count + 1):
                values = []
                for point in points:
                    # sample() returns (value, ok); ok is False and value NaN when
                    # the pixel could not be sampled, e.g. NoData.
                    value, ok = provider.sample(point, band)
                    if ok and not np.isnan(value):
                        values.append(value)
                feature_histograms[band] = values
            return feature.id(), feature_histograms

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(process_feature, feature) for feature in vector_layer.getFeatures()]
            for index, future in enumerate(concurrent.futures.as_completed(futures)):
                if feedback.isCanceled():
                    # Drop queued work; the executor still waits for running tasks.
                    for pending in futures:
                        pending.cancel()
                    break
                feature_id, feature_histograms = future.result()
                histograms[feature_id] = feature_histograms
                feedback.pushInfo(progress_label(index + 1))

        return histograms

    def merge_histograms(self, histograms_source):
        """Pool all per-feature value lists into ``{1: {band: values}}``."""
        merged = {}
        for feature_histograms in histograms_source.values():
            for band, values in feature_histograms.items():
                merged.setdefault(band, []).extend(values)
        return {1: merged}

    def calculate_target_histograms(self, raster_layer, vector_layer_target_dict, feedback):
        """Return ``{layer_number: calculate_histograms(...)}`` with 1-based layer numbers."""
        histograms_target_dict = {}
        for index, target_layer in enumerate(vector_layer_target_dict, start=1):
            if feedback.isCanceled():
                break
            histograms_target_dict[index] = self.calculate_histograms(raster_layer, target_layer, feedback)
        return histograms_target_dict

    def evaluate_differences(self, histograms_source, histograms_target_dict, feedback):
        """Score target features against the source histograms (see module helpers)."""
        band_count = len(next(iter(histograms_source.values())).keys())
        source_stats = compute_source_statistics(histograms_source, band_count)
        return calculate_weighted_difference(source_stats, histograms_target_dict, band_count)

    def add_attributes_to_features(self, vector_layer_target_dict, difference_dict):
        """Write each feature's score into its ``SPS_difference`` field.

        ``difference_dict`` is ``{layer_number: {feature_id: score}}`` with
        1-based layer numbers, as produced by evaluate_differences.
        """
        for layer_number, target_layer in enumerate(vector_layer_target_dict, start=1):
            layer_differences = difference_dict.get(layer_number, {})
            field_num = target_layer.fields().indexFromName('SPS_difference')
            changes = {}
            if field_num >= 0:
                # Scores are keyed by feature id (see calculate_histograms), so
                # look them up by id rather than by iteration position.
                for feature in target_layer.getFeatures():
                    fid = feature.id()
                    if fid in layer_differences:
                        changes[fid] = {field_num: float(layer_differences[fid])}
            if changes:
                target_layer.dataProvider().changeAttributeValues(changes)
            target_layer.updateFields()

    def name(self):
        return 'reuse_survey_algorithm'

    def displayName(self):
        return self.tr('Reuse Survey 襲用評価')

    def group(self):
        return ''

    def groupId(self):
        return ''

    def tr(self, string):
        return QCoreApplication.translate('Processing', string)

    def createInstance(self):
        return ReuseSurveyAlgorithm()

class SamplePlotSurveyAlgorithm(QgsProcessingAlgorithm):
    """Rank candidate sample plots inside each source polygon.

    Each source polygon is tiled with width x height rectangles that lie fully
    inside it. Every rectangle's ``SPS_difference`` field receives the weighted
    difference between its per-band raster means and the statistics of its
    parent polygon. One 'Sample Plots' memory layer is produced per polygon.
    """

    INPUT_RASTER = 'INPUT_RASTER'
    INPUT_VECTOR_SOURCE = 'INPUT_VECTOR_SOURCE'
    INPUT_WIDTH = 'INPUT_WIDTH'
    INPUT_HEIGHT = 'INPUT_HEIGHT'
    OUTPUT = 'OUTPUT'

    def initAlgorithm(self, config=None):
        # QgsProcessingParameterRasterLayer takes (name, description, defaultValue, ...);
        # it has no layer-type argument, so nothing is passed after the description.
        self.addParameter(QgsProcessingParameterRasterLayer(self.INPUT_RASTER, 'Input Raster Layer'))
        self.addParameter(QgsProcessingParameterFeatureSource(self.INPUT_VECTOR_SOURCE, 'Input Source Polygon Layer', types=[QgsProcessing.TypeVectorPolygon]))
        # Double (not the default Integer) so plot sizes like 17.5 m are accepted.
        self.addParameter(QgsProcessingParameterNumber(self.INPUT_WIDTH, 'Input Width of Sample Plot (m)', type=QgsProcessingParameterNumber.Double, defaultValue=30))
        self.addParameter(QgsProcessingParameterNumber(self.INPUT_HEIGHT, 'Input Height of Sample Plot (m)', type=QgsProcessingParameterNumber.Double, defaultValue=30))
        # NOTE(review): this sink is never created via parameterAsSink; results are
        # delivered by adding memory layers to the project instead.
        self.addParameter(QgsProcessingParameterFeatureSink(self.OUTPUT, 'Output Evaluated Polygon Layer'))

    def flags(self):
        """Force execution on the main thread.

        processAlgorithm adds layers to QgsProject.instance(), which must not be
        done from a background processing thread.
        """
        return super().flags() | QgsProcessingAlgorithm.FlagNoThreading

    def processAlgorithm(self, parameters, context, feedback):
        raster_layer = self.parameterAsRasterLayer(parameters, self.INPUT_RASTER, context)
        if raster_layer is None:
            raise QgsProcessingException(self.invalidRasterError(parameters, self.INPUT_RASTER))
        source = self.parameterAsSource(parameters, self.INPUT_VECTOR_SOURCE, context)
        if source is None:
            raise QgsProcessingException(self.invalidSourceError(parameters, self.INPUT_VECTOR_SOURCE))
        width_length = self.parameterAsDouble(parameters, self.INPUT_WIDTH, context)
        height_length = self.parameterAsDouble(parameters, self.INPUT_HEIGHT, context)
        # The grid step divides by these values.
        if width_length <= 0 or height_length <= 0:
            raise QgsProcessingException('Sample plot width and height must be greater than zero.')

        # Work on a memory copy: reprojection below edits the layer in place.
        vector_layer_source = source.materialize(QgsFeatureRequest())

        check_layer_validity(vector_layer_source, 'Source Polygon Layer')
        check_raster_layer(raster_layer, feedback)

        feature_count = vector_layer_source.featureCount()
        if feature_count == 0:
            raise QgsProcessingException('Layer(s) has no features. Exiting process.')

        feedback.pushInfo('Reprojecting to project CRS.')
        reprojection_to_projectCRS(vector_layer_source)

        feedback.pushInfo('Making Target layer.')
        output_layers = self.create_rectangles_within_polygon(vector_layer_source, width_length, height_length)

        feedback.pushInfo('Calculating histogram of source layer.')
        histograms_by_fid = self.calculate_histograms(raster_layer, vector_layer_source, feedback)
        if feedback.isCanceled():
            return {}
        # Target layers are numbered 1..N in source-feature order, and the scoring
        # looks the source statistics up by that number, so re-key the source
        # histograms by position instead of assuming feature ids are 1..N.
        histograms_source = {
            position: histograms_by_fid[feature.id()]
            for position, feature in enumerate(vector_layer_source.getFeatures(), start=1)
        }

        feedback.pushInfo('Calculating histogram of target layer.')
        histograms_target_dict = self.calculate_target_histograms(raster_layer, output_layers, feedback)
        if feedback.isCanceled():
            return {}

        difference_dict = self.evaluate_differences(histograms_source, histograms_target_dict, feedback)

        self.add_attributes_to_features(output_layers, difference_dict)

        for output_layer in output_layers:
            QgsProject.instance().addMapLayer(output_layer)

        # The result is the last generated layer object, as before.
        return {self.OUTPUT: output_layers[-1]}

    def create_rectangles_within_polygon(self, layer, width, height):
        """Tile each polygon of ``layer`` with ``width`` x ``height`` rectangles.

        For every polygon a separate 'Sample Plots' memory layer is created. It
        holds the grid cells that lie completely inside the polygon; the grid is
        anchored at the lower-left corner of the polygon's bounding box. Each
        cell carries ``SPS_ID`` (1-based within its layer), ``SPS_difference``
        (0) and a copy of the polygon's attributes.

        Returns:
            One layer per source feature, in iteration order. A feature without
            geometry still gets an (empty) layer so positions stay aligned.
        """
        project_crs = QgsProject.instance().crs()
        layer_uri = f'Polygon?crs={project_crs.authid()}'
        output_fields = [
            QgsField('SPS_ID', QVariant.Int),
            QgsField('SPS_difference', QVariant.Double, 'double', 20, 10),
        ] + layer.fields().toList()

        output_layers = []
        for polygon_feature in layer.getFeatures():
            output_layer = QgsVectorLayer(layer_uri, 'Sample Plots', 'memory')
            provider = output_layer.dataProvider()
            provider.addAttributes(output_fields)
            output_layer.updateFields()

            polygon_geom = polygon_feature.geometry()
            features = []
            # A null/empty geometry has a null bounding box (no usable extent).
            if not polygon_geom.isEmpty():
                attributes = polygon_feature.attributes()
                bbox = polygon_geom.boundingBox()
                x_min = bbox.xMinimum()
                y_min = bbox.yMinimum()
                columns = int((bbox.xMaximum() - x_min) / width) + 1
                rows = int((bbox.yMaximum() - y_min) / height) + 1

                sps_id = 1
                for column in range(columns):
                    rect_x_min = x_min + column * width
                    for row in range(rows):
                        rect_y_min = y_min + row * height
                        rect_geom = QgsGeometry.fromRect(
                            QgsRectangle(rect_x_min, rect_y_min, rect_x_min + width, rect_y_min + height))
                        if polygon_geom.contains(rect_geom):
                            feature = QgsFeature()
                            feature.setGeometry(rect_geom)
                            feature.setAttributes([sps_id, 0] + attributes)
                            features.append(feature)
                            sps_id += 1

            provider.addFeatures(features)
            output_layer.updateExtents()
            output_layers.append(output_layer)
        return output_layers

    @staticmethod
    def _pixel_centres_inside(geometry, raster_extent, pixel_width, pixel_height):
        """Return sample points (QgsPointXY) inside ``geometry``.

        Points are centres of a grid with the raster's pixel size, anchored at
        the top-left of the geometry's bounding box clipped to the raster
        extent. Column-major order (x outer, y inner).
        NOTE(review): the grid is not snapped to the raster's own pixel grid.
        """
        # Null/empty geometries have a null bounding box whose infinite extent
        # would overflow int() below.
        if geometry.isEmpty():
            return []
        bbox = geometry.boundingBox()
        x_min = max(raster_extent.xMinimum(), bbox.xMinimum())
        x_max = min(raster_extent.xMaximum(), bbox.xMaximum())
        y_min = max(raster_extent.yMinimum(), bbox.yMinimum())
        y_max = min(raster_extent.yMaximum(), bbox.yMaximum())
        if x_max <= x_min or y_max <= y_min:
            return []  # no overlap with the raster

        subset_width = int((x_max - x_min) / pixel_width)
        subset_height = int((y_max - y_min) / pixel_height)
        points = []
        for i in range(subset_width):
            x = x_min + i * pixel_width + pixel_width / 2
            for j in range(subset_height):
                y = y_max - j * pixel_height - pixel_height / 2
                point = QgsPointXY(x, y)
                if geometry.contains(point):
                    points.append(point)
        return points

    def calculate_histograms(self, raster_layer, vector_layer, feedback):
        """Sample the raster under every polygon of ``vector_layer``.

        Returns ``{feature_id: {band: [values...]}}``. Samples that the
        provider reports as failed (NoData / outside the raster) are skipped.
        Progress is reported as "done/total"; stops early when cancelled.
        """
        feature_count = vector_layer.featureCount()
        pad = max(len(str(feature_count)), 2)

        def progress_label(done):
            return str(done).zfill(pad) + '/' + str(feature_count).zfill(pad)

        feedback.pushInfo(progress_label(0))

        histograms = {}

        raster_extent = raster_layer.extent()
        band_count = raster_layer.bandCount()
        pixel_width = raster_extent.width() / raster_layer.width()
        pixel_height = raster_extent.height() / raster_layer.height()
        provider = raster_layer.dataProvider()

        def process_feature(feature):
            # The sample points do not depend on the band: compute them once.
            points = self._pixel_centres_inside(feature.geometry(), raster_extent, pixel_width, pixel_height)
            feature_histograms = {}
            for band in range(1, band_count + 1):
                values = []
                for point in points:
                    # sample() returns (value, ok); ok is False and value NaN when
                    # the pixel could not be sampled, e.g. NoData.
                    value, ok = provider.sample(point, band)
                    if ok and not np.isnan(value):
                        values.append(value)
                feature_histograms[band] = values
            return feature.id(), feature_histograms

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(process_feature, feature) for feature in vector_layer.getFeatures()]
            for index, future in enumerate(concurrent.futures.as_completed(futures)):
                if feedback.isCanceled():
                    # Drop queued work; the executor still waits for running tasks.
                    for pending in futures:
                        pending.cancel()
                    break
                feature_id, feature_histograms = future.result()
                histograms[feature_id] = feature_histograms
                feedback.pushInfo(progress_label(index + 1))

        return histograms

    def calculate_target_histograms(self, raster_layer, vector_layer_target_dict, feedback):
        """Return ``{layer_number: calculate_histograms(...)}`` with 1-based layer numbers."""
        histograms_target_dict = {}
        for index, target_layer in enumerate(vector_layer_target_dict, start=1):
            if feedback.isCanceled():
                break
            histograms_target_dict[index] = self.calculate_histograms(raster_layer, target_layer, feedback)
        return histograms_target_dict

    def evaluate_differences(self, histograms_source, histograms_target_dict, feedback):
        """Score target features against the source histograms (see module helpers)."""
        band_count = len(next(iter(histograms_source.values())).keys())
        source_stats = compute_source_statistics(histograms_source, band_count)
        return calculate_weighted_difference(source_stats, histograms_target_dict, band_count)

    def add_attributes_to_features(self, vector_layer_target_dict, difference_dict):
        """Write each feature's score into its ``SPS_difference`` field.

        ``difference_dict`` is ``{layer_number: {feature_id: score}}`` with
        1-based layer numbers, as produced by evaluate_differences.
        """
        for layer_number, target_layer in enumerate(vector_layer_target_dict, start=1):
            layer_differences = difference_dict.get(layer_number, {})
            field_num = target_layer.fields().indexFromName('SPS_difference')
            changes = {}
            if field_num >= 0:
                # Scores are keyed by feature id (see calculate_histograms), so
                # look them up by id rather than by iteration position.
                for feature in target_layer.getFeatures():
                    fid = feature.id()
                    if fid in layer_differences:
                        changes[fid] = {field_num: float(layer_differences[fid])}
            if changes:
                target_layer.dataProvider().changeAttributeValues(changes)
            target_layer.updateFields()

    def name(self):
        # NOTE(review): QGIS expects algorithm ids to be lowercase without spaces;
        # kept unchanged so existing references to this id keep working.
        return 'Sample Plot Survey Algorithm'

    def displayName(self):
        return self.tr('Sample Plot Survey 標準地選定')

    def group(self):
        return ''

    def groupId(self):
        return ''

    def tr(self, string):
        return QCoreApplication.translate('Processing', string)

    def createInstance(self):
        return SamplePlotSurveyAlgorithm()

def calculate_statistics(arr):
    """Return ``(mean, standard_deviation)`` of ``arr``.

    Both values are computed with NumPy (``np.mean`` and ``np.std`` with its
    default ``ddof=0``, i.e. the population standard deviation).
    """
    return np.mean(arr), np.std(arr)

def calculate_weighted_difference(source_stats, target_stats, band_count):
    """Score every target feature against the matching source statistics.

    For target layer ``key1`` and feature ``key2`` the score is the sum over
    bands ``b`` of
    ``(mean(target values of b) - source_stats[key1][b]['mean']) * source_stats[key1][b]['weight']``.

    Args:
        source_stats: ``{key1: {band: {'mean', 'std_dev', 'weight'}}}`` as built
            by ``compute_source_statistics``.
        target_stats: ``{key1: {key2: {band: values}}}`` raw sampled values.
            It is read only and no longer overwritten with statistics.
        band_count: unused; kept for interface compatibility.

    Returns:
        ``{key1: {key2: score}}``.
    """
    difference_dict = {}
    for key1, histograms in target_stats.items():
        layer_differences = {}
        difference_dict[key1] = layer_differences
        for key2, histogram in histograms.items():
            difference_sum = 0
            for key3, band in histogram.items():
                band_source = source_stats[key1][key3]
                # NOTE(review): signed per-band differences can cancel each other
                # out; confirm whether absolute differences are intended.
                difference_sum += (np.mean(band) - band_source['mean']) * band_source['weight']
            layer_differences[key2] = difference_sum
    return difference_dict

def compute_source_statistics(histograms_source, band_count):
    """Build per-band mean, standard deviation and weight for each source entry.

    The weight of a band is proportional to the inverse of its standard
    deviation, normalised by the sum of the inverses over the entry's bands.
    When that sum is zero, or the band itself has zero spread, the band's
    weight falls back to ``1 / band_count``.

    Returns:
        ``{key: {band: {'mean': ..., 'std_dev': ..., 'weight': ...}}}``.
    """
    result = {}
    for entry_key, bands in histograms_source.items():
        per_band = {
            band_key: {'mean': np.mean(values), 'std_dev': np.std(values)}
            for band_key, values in bands.items()
        }

        inverse_total = 0
        for band_info in per_band.values():
            if band_info['std_dev'] != 0:
                inverse_total += 1 / band_info['std_dev']

        for band_info in per_band.values():
            spread = band_info['std_dev']
            if inverse_total == 0 or spread == 0:
                band_info['weight'] = 1 / band_count
            else:
                band_info['weight'] = 1 / spread / inverse_total

        result[entry_key] = per_band
    return result

def evaluate_differences(histograms_source, histograms_target_dict, feedback):
    """Score target histograms against source histograms.

    The band count is taken from the first source entry. ``feedback`` is
    accepted for interface symmetry but not used.
    """
    first_entry = next(iter(histograms_source.values()))
    bands = len(first_entry)
    stats = compute_source_statistics(histograms_source, bands)
    return calculate_weighted_difference(stats, histograms_target_dict, bands)

def add_attributes_to_features(vector_layer_target_dict, difference_dict):
    """Write calculated differences into each target feature's ``SPS_difference`` field.

    Args:
        vector_layer_target_dict: iterable of target layers; the i-th layer
            (1-based) corresponds to ``difference_dict[i]``.
        difference_dict: ``{layer_number: {feature_id: score}}``. Inner keys
            are feature ids, because the histograms are keyed by
            ``feature.id()``. Enumeration position is not used as a key.
    """
    for layer_number, target_layer in enumerate(vector_layer_target_dict, start=1):
        layer_differences = difference_dict.get(layer_number, {})
        field_num = target_layer.fields().indexFromName('SPS_difference')
        changes = {}
        if field_num >= 0:  # skip layers without the result field
            for feature in target_layer.getFeatures():
                fid = feature.id()
                if fid in layer_differences:
                    changes[fid] = {field_num: float(layer_differences[fid])}
        # One batched provider call per layer instead of one per feature.
        if changes:
            target_layer.dataProvider().changeAttributeValues(changes)
        target_layer.updateFields()