# -*- coding: utf-8 -*-

"""
/***************************************************************************
 gridindex
                                 A QGIS plugin
 This plugin generates a grid of rectangular polygons for map book page indexing.
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2025-06-07
        copyright            : (C) 2025 by Kapildev Adhikari
        email                : kapildevadk@gmail.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

# Plugin authorship metadata exposed as module-level dunder attributes.
__author__ = "Kapildev Adhikari"
__date__ = "2025-06-07"
__copyright__ = "(C) 2025 by Kapildev Adhikari"

# Placeholder that is swapped for the commit's git SHA1 when the source is
# exported with `git archive`; otherwise it remains the literal format string.
__revision__ = "$Format:%H$"

import math
from qgis.PyQt.QtCore import QCoreApplication, QVariant
from qgis.core import (QgsProcessing,
                       QgsFeatureSink,
                       QgsProcessingAlgorithm,
                       QgsProcessingParameterMapLayer,
                       QgsProcessingParameterFeatureSink,
                       QgsProcessingParameterDistance,
                       QgsProcessingParameterEnum,
                       QgsProcessingParameterBoolean,
                       QgsProcessingParameterNumber,
                       QgsRectangle,
                       QgsGeometry,
                       QgsFeature,
                       QgsField,
                       QgsFields,
                       QgsWkbTypes,
                       QgsSpatialIndex,
                       QgsProject,
                       QgsProcessingException,
                       QgsFeatureRequest,
                       QgsVectorLayer,
                       QgsRasterLayer,  # Added for better raster support
                       QgsUnitTypes,
                       QgsCoordinateReferenceSystem,
                       QgsCoordinateTransform,
                       QgsDistanceArea,  # Added for geographic calculations
                       QgsRasterDataProvider,  # Added for raster data access
                       QgsPointXY)  # Added for point operations
from qgis.PyQt.QtGui import QIcon
import os


class gridindexAlgorithm(QgsProcessingAlgorithm):
    """
    Create a map-book grid index over an intersection layer.

    A regular grid of rectangular cells is laid over the extent of the
    selected Intersection Layer (vector or raster). Only cells that intersect
    a vector feature, or in which at least one sample point of raster band 1
    holds valid (non-NaN) data, are written to the output. Every kept cell
    receives a sequential ``PageNumber`` and a ``PageName`` made of a row
    letter plus a column number; raster inputs additionally receive band
    count, pixel size and sampled coverage attributes.
    """

    INPUT_LAYER = 'INPUT_LAYER'
    CELL_WIDTH = 'CELL_WIDTH'
    CELL_HEIGHT = 'CELL_HEIGHT'
    USE_ABSOLUTE_NAMING = 'USE_ABSOLUTE_NAMING'
    LABEL_ORIGIN = 'LABEL_ORIGIN'
    NUM_ROWS = 'NUM_ROWS'
    NUM_COLS = 'NUM_COLS'
    START_PAGE = 'START_PAGE'
    OUTPUT = 'OUTPUT'

    # Rough length of one degree at the equator, in metres. Used to turn
    # degree-valued sizes into units of the metre-based EPSG:3395 calculation CRS.
    _METERS_PER_DEGREE = 111320

    def initAlgorithm(self, config=None):
        """Define the parameters for the tool."""

        self.addParameter(
            QgsProcessingParameterMapLayer(
                self.INPUT_LAYER, self.tr('Intersection Layer'),
                [QgsProcessing.TypeVector, QgsProcessing.TypeRaster]
            )
        )

        self.addParameter(
            QgsProcessingParameterDistance(
                self.CELL_WIDTH, self.tr('Grid Cell Width'),
                parentParameterName=self.INPUT_LAYER, defaultValue=500.0, optional=False
            )
        )

        self.addParameter(
            QgsProcessingParameterDistance(
                self.CELL_HEIGHT, self.tr('Grid Cell Height'),
                parentParameterName=self.INPUT_LAYER, defaultValue=500.0, optional=False
            )
        )

        self.addParameter(
            QgsProcessingParameterBoolean(
                self.USE_ABSOLUTE_NAMING,
                self.tr('Use absolute grid position for Page Names'),
                defaultValue=False
            )
        )
        self.parameterDefinition(self.USE_ABSOLUTE_NAMING).setHelp(
            self.tr("If checked, names are based on the overall grid column (e.g., C5). If unchecked, they are numbered sequentially within each row (e.g., C1, C2...).")
        )

        self.addParameter(
            QgsProcessingParameterEnum(
                self.LABEL_ORIGIN,
                self.tr('Labeling starts from'),
                options=['Top-Left', 'Top-Right', 'Bottom-Left', 'Bottom-Right'],
                defaultValue=0
            )
        )
        self.parameterDefinition(self.LABEL_ORIGIN).setHelp(
            self.tr("Controls which corner of the grid both the PageNumber and PageName sequences begin from.")
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.NUM_ROWS, self.tr('Number of rows (optional override)'),
                type=QgsProcessingParameterNumber.Integer, optional=True, defaultValue=0
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.NUM_COLS, self.tr('Number of columns (optional override)'),
                type=QgsProcessingParameterNumber.Integer, optional=True, defaultValue=0
            )
        )

        self.addParameter(
            QgsProcessingParameterNumber(
                self.START_PAGE, self.tr('Starting page number'),
                type=QgsProcessingParameterNumber.Integer, optional=True, defaultValue=1
            )
        )

        self.addParameter(
            QgsProcessingParameterFeatureSink(self.OUTPUT, self.tr('Output Grid Index'))
        )

    def processAlgorithm(self, parameters, context, feedback):
        """
        Generate the grid and write every cell that touches the input layer.

        The grid is built in a "calculation CRS": the layer CRS for projected
        layers, or EPSG:3395 (World Mercator) for geographic layers, in which
        case each cell is transformed back to the layer CRS before output.

        :returns: ``{OUTPUT: dest_id}``, or ``{}`` when the run is cancelled.
        :raises QgsProcessingException: if the input layer or the output sink
            cannot be resolved, or the effective cell size is not positive.
        """
        source = self.parameterAsLayer(parameters, self.INPUT_LAYER, context)
        if source is None:
            raise QgsProcessingException(self.invalidSourceError(parameters, self.INPUT_LAYER))

        use_absolute_naming = self.parameterAsBoolean(parameters, self.USE_ABSOLUTE_NAMING, context)
        label_origin_index = self.parameterAsEnum(parameters, self.LABEL_ORIGIN, context)
        num_rows_override = self.parameterAsInt(parameters, self.NUM_ROWS, context)
        num_cols_override = self.parameterAsInt(parameters, self.NUM_COLS, context)
        start_page = self.parameterAsInt(parameters, self.START_PAGE, context)
        cell_width = self.parameterAsDouble(parameters, self.CELL_WIDTH, context)
        cell_height = self.parameterAsDouble(parameters, self.CELL_HEIGHT, context)

        source_crs = source.crs()
        is_geographic = source_crs.isGeographic()

        if is_geographic:
            feedback.pushInfo("Geographic CRS detected. Using degree-based calculations with geodesic distance.")

            # Build the grid in WGS 84 / World Mercator (metres) and project
            # each cell back to the layer CRS afterwards.
            calculation_crs = QgsCoordinateReferenceSystem('EPSG:3395')
            transform_to_calc = QgsCoordinateTransform(source_crs, calculation_crs, context.transformContext())
            transform_to_source = QgsCoordinateTransform(calculation_crs, source_crs, context.transformContext())
            calc_extent = transform_to_calc.transform(source.extent())

            # Heuristic: sizes below 1 are assumed to be degrees and converted
            # to metres with the equatorial approximation; larger values are
            # assumed to already be metres.
            if source_crs.mapUnits() == QgsUnitTypes.DistanceDegrees:
                if cell_width < 1:
                    cell_width *= self._METERS_PER_DEGREE
                if cell_height < 1:
                    cell_height *= self._METERS_PER_DEGREE

            feedback.pushInfo(f"Using cell dimensions: {cell_width}m x {cell_height}m")
        else:
            feedback.pushInfo("Projected CRS detected. Using layer units for calculation.")
            calculation_crs = source_crs
            transform_to_source = None
            calc_extent = source.extent()

        feedback.pushInfo(f"Calculations will be performed in {calculation_crs.authid()}")

        # A non-positive size would divide by zero when sizing the grid.
        if cell_width <= 0 or cell_height <= 0:
            raise QgsProcessingException(
                self.tr('Grid cell width and height must be greater than zero.'))

        fields = QgsFields()
        fields.append(QgsField('PageNumber', QVariant.Int))
        fields.append(QgsField('PageName', QVariant.String))

        is_raster = isinstance(source, QgsRasterLayer)
        is_vector = isinstance(source, QgsVectorLayer)
        if is_raster:
            fields.append(QgsField('RasterBands', QVariant.Int))
            fields.append(QgsField('PixelSizeX', QVariant.Double))
            fields.append(QgsField('PixelSizeY', QVariant.Double))
            fields.append(QgsField('RasterCoverage', QVariant.Double))
            feedback.pushInfo(f"Raster layer detected with {source.bandCount()} bands")

        (sink, dest_id) = self.parameterAsSink(parameters, self.OUTPUT, context, fields, QgsWkbTypes.Polygon, source_crs)
        if sink is None:
            raise QgsProcessingException(self.invalidSinkError(parameters, self.OUTPUT))

        origin_x, origin_y = calc_extent.xMinimum(), calc_extent.yMinimum()
        num_rows = num_rows_override if num_rows_override > 0 else self._cells_needed(calc_extent.height(), cell_height)
        num_cols = num_cols_override if num_cols_override > 0 else self._cells_needed(calc_extent.width(), cell_width)
        feedback.pushInfo(f"Grid will have a maximum of {num_rows} rows and {num_cols} columns.")

        # Passing feedback lets the user cancel a long index build.
        spatial_index = QgsSpatialIndex(source.getFeatures(), feedback) if is_vector else None

        if is_raster:
            raster_pixel_size_x = source.rasterUnitsPerPixelX()
            raster_pixel_size_y = source.rasterUnitsPerPixelY()
            raster_data_provider = source.dataProvider()
            raster_band_count = source.bandCount()
            feedback.pushInfo(f"Raster pixel size: {raster_pixel_size_x} x {raster_pixel_size_y}")

            # The coverage test runs in the calculation CRS, so express the
            # pixel size in its units (degrees -> approx. metres if geographic).
            if is_geographic:
                sample_pixel_x = raster_pixel_size_x * self._METERS_PER_DEGREE
                sample_pixel_y = raster_pixel_size_y * self._METERS_PER_DEGREE
            else:
                sample_pixel_x = raster_pixel_size_x
                sample_pixel_y = raster_pixel_size_y

        row_iter, col_iter = self._iteration_order(label_origin_index, num_rows, num_cols)
        top_down = label_origin_index in (0, 1)
        total_cells = num_rows * num_cols
        processed = 0
        page_counter = start_page

        for r in row_iter:
            row_letter = self._row_label(r, num_rows, top_down)
            # Relative names count the kept cells of this row, starting at 1
            # on the side the labeling origin was chosen from.
            position_in_row = 0

            for c in col_iter:
                if feedback.isCanceled():
                    return {}
                processed += 1

                cell_rect_calc = QgsRectangle(origin_x + (c * cell_width),
                                              origin_y + (r * cell_height),
                                              origin_x + ((c + 1) * cell_width),
                                              origin_y + ((r + 1) * cell_height))
                if transform_to_source is not None:
                    cell_rect_source = transform_to_source.transform(cell_rect_calc)
                else:
                    cell_rect_source = cell_rect_calc

                found_intersection = False
                raster_attrs = []

                if is_vector:
                    found_intersection = self._cell_intersects_vector(source, spatial_index, cell_rect_source)
                elif is_raster:
                    # calc_extent is the raster extent expressed in the
                    # calculation CRS, the same CRS as cell_rect_calc.
                    has_data, coverage = self._raster_coverage(
                        cell_rect_calc, calc_extent, sample_pixel_x, sample_pixel_y,
                        raster_data_provider, transform_to_source)
                    if has_data:
                        found_intersection = True
                        raster_attrs = [raster_band_count, raster_pixel_size_x, raster_pixel_size_y, coverage]
                        feedback.pushInfo(f"Grid cell {r},{c} contains {coverage:.1f}% raster data coverage")

                if found_intersection:
                    position_in_row += 1
                    column_number = c + 1 if use_absolute_naming else position_in_row

                    feat = QgsFeature(fields)
                    feat.setGeometry(QgsGeometry.fromRect(cell_rect_source))
                    feat.setAttributes([page_counter, f"{row_letter}{column_number}"] + raster_attrs)
                    sink.addFeature(feat, QgsFeatureSink.FastInsert)
                    page_counter += 1

                feedback.setProgress(int(100 * processed / total_cells))

        return {self.OUTPUT: dest_id}

    @staticmethod
    def _cells_needed(span, cell_size):
        """
        Return how many cells of ``cell_size`` are needed to cover ``span``.

        Always returns at least 1. A degenerate extent (a single point, or a
        perfectly horizontal/vertical line) therefore still gets one covering
        cell, and a non-finite span falls back to 1 instead of raising.
        """
        count = span / cell_size
        if not math.isfinite(count) or count <= 0:
            return 1
        return math.ceil(count)

    @staticmethod
    def _iteration_order(label_origin_index, num_rows, num_cols):
        """
        Return ``(row_range, col_range)`` visiting cells from the chosen corner.

        Row 0 is the bottom row and column 0 the left column of the grid, so
        the top origins (0, 1) walk rows downwards and the right origins
        (1, 3) walk columns leftwards.
        """
        top_down = label_origin_index in (0, 1)
        right_to_left = label_origin_index in (1, 3)
        rows = range(num_rows - 1, -1, -1) if top_down else range(num_rows)
        cols = range(num_cols - 1, -1, -1) if right_to_left else range(num_cols)
        return rows, cols

    @staticmethod
    def _row_label(row_index, total_rows, top_down):
        """
        Return the spreadsheet-style letter label (A..Z, AA, AB, ...) for a row.

        ``row_index`` counts from the bottom of the grid. With ``top_down``
        the top row is labelled "A"; otherwise the bottom row is.
        """
        effective_row = (total_rows - 1) - row_index if top_down else row_index
        if effective_row < 0:
            return ""
        label = ""
        while effective_row >= 0:
            label = chr(ord('A') + effective_row % 26) + label
            effective_row = effective_row // 26 - 1
        return label

    @staticmethod
    def _cell_intersects_vector(layer, spatial_index, cell_rect):
        """Return True if any feature of ``layer`` intersects ``cell_rect``."""
        # The spatial index only compares bounding boxes; it is a cheap filter.
        if not spatial_index.intersects(cell_rect):
            return False
        # Attributes are not needed for the geometry test.
        request = QgsFeatureRequest().setFilterRect(cell_rect).setSubsetOfAttributes([])
        return any(feat.geometry().intersects(cell_rect) for feat in layer.getFeatures(request))

    @staticmethod
    def _raster_coverage(cell_rect, raster_extent, pixel_size_x, pixel_size_y,
                         data_provider, transform_to_source=None):
        """
        Estimate how much of ``cell_rect`` holds valid raster data.

        Band 1 is point-sampled on a regular lattice over the part of the cell
        that overlaps ``raster_extent`` (roughly one sample every 10 pixels,
        with at least 4 x 4 samples). ``cell_rect``, ``raster_extent`` and
        the pixel sizes must share one CRS; ``transform_to_source``, when
        given, maps each sample point into the raster's own CRS.

        :returns: ``(has_data, coverage_percentage)`` where ``has_data`` is
            True if at least one sample returned a valid value.
        """
        if not raster_extent.intersects(cell_rect):
            return False, 0.0
        overlap = raster_extent.intersect(cell_rect)
        if overlap.isEmpty():
            return False, 0.0

        steps_x = max(3, int(overlap.width() / pixel_size_x / 10))
        steps_y = max(3, int(overlap.height() / pixel_size_y / 10))
        step_x = overlap.width() / steps_x
        step_y = overlap.height() / steps_y

        valid_samples = 0
        total_samples = 0
        for i in range(steps_x + 1):
            x = overlap.xMinimum() + i * step_x
            for j in range(steps_y + 1):
                y = overlap.yMinimum() + j * step_y
                point = QgsPointXY(x, y)
                if transform_to_source is not None:
                    point = transform_to_source.transform(point)

                # sample() yields (value, ok); no-data comes back as NaN.
                # A plain truthiness test would count NaN as data and a real
                # 0.0 pixel as missing, so check both explicitly.
                value, ok = data_provider.sample(point, 1)
                total_samples += 1
                if ok and not math.isnan(value):
                    valid_samples += 1

        coverage_percentage = (valid_samples / total_samples) * 100.0
        return valid_samples > 0, coverage_percentage

    def description(self):
        """Return the HTML help text describing the algorithm."""
        return self.tr(
            """
            <p>This algorithm creates a grid of rectangular polygon features for a map book index with advanced labeling controls.</p>
            <p><b>How it works:</b></p>
            <ol>
            <li>A grid is generated over the full extent of the selected <b>Intersection Layer</b>.</li>
            <li>Only grid cells that intersect with the Intersection Layer are kept in the final output.</li>
            </ol>
            <p><b><u>CRS & Unit Handling:</u></b></p>
            <p>This tool intelligently handles both projected and geographic coordinate systems. For geographic CRS (e.g., WGS 84), the tool automatically uses World Mercator projection for accurate calculations. For projected layers, you can specify the cell size in any supported unit.</p>
            <p><b><u>Enhanced Raster Support:</u></b></p>
            <p>The tool now provides enhanced support for raster layers, including band count and pixel size information in the output attributes.</p>
            <p><b><u>Naming and Ordering:</u></b></p>
            <ul>
            <li><b>Use absolute grid position for Page Names:</b> When checked, the number in the name (e.g., the '5' in 'C5') corresponds to the grid's absolute column number. When unchecked, it's a sequential number for created cells within that row.</li>
            <li><b>Labeling starts from:</b> Controls which corner of the grid both the PageNumber and PageName sequences begin from.</li>
            </ul>
            """
        )

    def shortHelpString(self):
        """
        Return the help text shown in the Processing dialog.

        QGIS reads algorithm help from ``shortHelpString()``, so this delegates
        to ``description()`` to surface that text.
        """
        return self.description()

    def name(self):
        """Return the algorithm's identifier."""
        return 'Grid Index'

    def displayName(self):
        """Return the translated, user-visible algorithm name."""
        return self.tr(self.name())

    def group(self):
        """Return the group name (none; listed at the provider root)."""
        return ''

    def groupId(self):
        """Return the group identifier (none)."""
        return ''

    def tr(self, string):
        """Translate ``string`` in the 'Processing' context."""
        return QCoreApplication.translate('Processing', string)

    def createInstance(self):
        """Return a fresh instance, as required by the Processing framework."""
        return gridindexAlgorithm()

    def icon(self):
        """Return the icon stored next to this module as ``icon.png``."""
        return QIcon(os.path.join(os.path.dirname(__file__), 'icon.png'))
