import datetime
import math

from qgis.core import (QgsProcessing,
                       QgsProcessingAlgorithm,
                       QgsProcessingParameterFeatureSource,
                       QgsProcessingParameterField,
                       QgsProcessingParameterString,
                       QgsProcessingParameterNumber,
                       QgsProcessingParameterBoolean,
                       QgsProcessingParameterCrs,
                       QgsProcessingParameterFeatureSink,
                       QgsFeatureSink,
                       QgsFeature,
                       QgsGeometry,
                       QgsPoint,
                       QgsField,
                       QgsFields,
                       QgsWkbTypes,
                       QgsProcessingException,
                       QgsCoordinateReferenceSystem,
                       QgsCoordinateTransform,
                       QgsCsException,
                       QgsProject)
from PyQt5.QtCore import QDate, QDateTime, Qt, QVariant

class Icesat2Production(QgsProcessingAlgorithm):
    """Processing algorithm that bins ICESat-2 photon rows into 3D points.

    Rows of a tabular (e.g. CSV) layer are filtered by a minimum confidence
    and, optionally, by an inclusive YYYY-MM-DD date range. Surviving rows
    are grouped on a regular lon/lat grid of ``BIN_SIZE`` degrees. Each
    non-empty grid cell becomes one PointZ feature placed at the mean
    lon/lat/height of its rows, optionally reprojected to ``TARGET_CRS``.
    """

    INPUT_LAYER = 'INPUT_LAYER'
    DATE_FIELD = 'DATE_FIELD'
    X_FIELD = 'X_FIELD'
    Y_FIELD = 'Y_FIELD'
    HEIGHT_FIELD = 'HEIGHT_FIELD'
    CONF_FIELD = 'CONF_FIELD'

    USE_DATE = 'USE_DATE'
    START_DATE = 'START_DATE'
    END_DATE = 'END_DATE'

    MIN_CONF = 'MIN_CONF'
    BIN_SIZE = 'BIN_SIZE'
    TARGET_CRS = 'TARGET_CRS'
    OUTPUT_POINTS = 'OUTPUT_POINTS'

    def name(self):
        """Return the unique algorithm id used by the Processing registry."""
        return 'icesat2_production'

    def displayName(self):
        """Return the user-visible algorithm name."""
        return 'ICESat-2 Processor'

    def group(self):
        """Return the user-visible toolbox group name."""
        return 'ICESat-2 Processor'

    def groupId(self):
        """Return the toolbox group id."""
        return 'icesat2_processor'

    def createInstance(self):
        """Return a fresh instance, as required by the Processing framework."""
        return Icesat2Production()

    def shortHelpString(self):
        """Return the HTML help shown in the algorithm dialog."""
        return """
        <h3>ICESat-2 Data Processor</h3>
        <p>Processes SlideRule/ATL03 photon data with spatial binning, filtering, and coordinate transformation.</p>
        <p><b>Input Requirements:</b> X/Y fields must be Longitude/Latitude (WGS84).</p>
        <p><b>Output:</b> 3D Points (PointZ) with aggregated height values.</p>
        """

    def initAlgorithm(self, config=None):
        """Declare the algorithm's input and output parameters."""
        self.addParameter(QgsProcessingParameterFeatureSource(
            self.INPUT_LAYER, 'Input Layer (CSV)', types=[QgsProcessing.TypeVector]))

        self.addParameter(QgsProcessingParameterField(
            self.DATE_FIELD, 'Date Column', '', self.INPUT_LAYER, optional=True))

        self.addParameter(QgsProcessingParameterField(
            self.X_FIELD, 'Longitude (X) Column [lon_ph]', 'lon_ph', self.INPUT_LAYER))
        self.addParameter(QgsProcessingParameterField(
            self.Y_FIELD, 'Latitude (Y) Column [lat_ph]', 'lat_ph', self.INPUT_LAYER))

        self.addParameter(QgsProcessingParameterField(
            self.HEIGHT_FIELD, 'Height Column [ortho_h]', 'ortho_h', self.INPUT_LAYER))
        self.addParameter(QgsProcessingParameterField(
            self.CONF_FIELD, 'Confidence Column', 'confidence', self.INPUT_LAYER))

        self.addParameter(QgsProcessingParameterNumber(
            self.MIN_CONF, 'Min Confidence (0.0 - 1.0)', QgsProcessingParameterNumber.Double, 0.9))

        self.addParameter(QgsProcessingParameterBoolean(
            self.USE_DATE, 'Filter by Date?', False))
        self.addParameter(QgsProcessingParameterString(
            self.START_DATE, 'Start Date (YYYY-MM-DD)', '2019-01-01'))
        self.addParameter(QgsProcessingParameterString(
            self.END_DATE, 'End Date (YYYY-MM-DD)', '2022-12-31'))

        self.addParameter(QgsProcessingParameterNumber(
            self.BIN_SIZE, 'Bin Size (Degrees)', QgsProcessingParameterNumber.Double, 0.0001))

        self.addParameter(QgsProcessingParameterCrs(
            self.TARGET_CRS, 'Output Coordinate System', 'EPSG:4326'))

        self.addParameter(QgsProcessingParameterFeatureSink(
            self.OUTPUT_POINTS, 'Final Output (3D Points)'))

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _passes_confidence(value, min_conf):
        """Return True if a confidence attribute is >= ``min_conf``.

        A list value (array-valued column) is judged by its first element.
        NULL, empty, non-numeric and NaN values never pass.
        """
        if isinstance(value, list):
            if not value:
                return False
            value = value[0]
        try:
            return float(value) >= min_conf
        except (TypeError, ValueError):
            return False

    @staticmethod
    def _date_text(value):
        """Return the leading date part of a date attribute as text.

        QDate/QDateTime values are rendered as ISO 8601 first; plain
        ``str()`` of those would give their Python repr and break the
        lexical YYYY-MM-DD comparison. Both 'T'- and space-separated
        datetimes are cut down to their date part.
        """
        if isinstance(value, (QDate, QDateTime)):
            value = value.toString(Qt.ISODate)
        text = str(value).strip()
        return text.split('T', 1)[0].split(' ', 1)[0]

    @staticmethod
    def _normalize_date_range(start_str, end_str):
        """Validate a YYYY-MM-DD date range and return it as ISO strings.

        Raises:
            QgsProcessingException: if a date is malformed or start > end.
        """
        try:
            start = datetime.date.fromisoformat(start_str.strip())
            end = datetime.date.fromisoformat(end_str.strip())
        except ValueError as err:
            raise QgsProcessingException(
                f"Dates must use the YYYY-MM-DD format: {err}") from err
        if start > end:
            raise QgsProcessingException("Start date must not be after end date.")
        return start.isoformat(), end.isoformat()

    @staticmethod
    def _ensure_field(out_fields, name, field_type):
        """Return the index of ``name`` in ``out_fields``, appending it if absent."""
        idx = out_fields.indexOf(name)
        if idx == -1:
            out_fields.append(QgsField(name, field_type))
            idx = out_fields.indexOf(name)
        return idx

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------

    def processAlgorithm(self, parameters, context, feedback):
        """Filter, bin and write the input photons as aggregated PointZ features.

        Each output feature carries the attributes of the first row that fell
        into its bin. The X, Y and height columns are replaced by the bin
        means, and ``aggregated_count`` / ``mean_height`` are set. If the
        run is cancelled, the bins collected so far are still written.

        Returns:
            dict mapping ``OUTPUT_POINTS`` to the destination layer id.

        Raises:
            QgsProcessingException: if the input or output cannot be opened,
                a required column is missing, the bin size is not positive,
                or date filtering is enabled without a valid date column or
                with an invalid date range.
        """
        source = self.parameterAsSource(parameters, self.INPUT_LAYER, context)
        if source is None:
            raise QgsProcessingException("Invalid input layer")

        x_name = self.parameterAsString(parameters, self.X_FIELD, context)
        y_name = self.parameterAsString(parameters, self.Y_FIELD, context)
        h_name = self.parameterAsString(parameters, self.HEIGHT_FIELD, context)
        c_name = self.parameterAsString(parameters, self.CONF_FIELD, context)
        d_name = self.parameterAsString(parameters, self.DATE_FIELD, context)

        min_conf = self.parameterAsDouble(parameters, self.MIN_CONF, context)
        bin_size = self.parameterAsDouble(parameters, self.BIN_SIZE, context)
        use_date = self.parameterAsBool(parameters, self.USE_DATE, context)
        start_str = self.parameterAsString(parameters, self.START_DATE, context)
        end_str = self.parameterAsString(parameters, self.END_DATE, context)

        # A non-positive bin size would divide by zero (or build nonsense
        # bins) for every row and silently yield an empty output.
        if bin_size <= 0:
            raise QgsProcessingException("Bin size must be greater than zero.")

        target_crs = self.parameterAsCrs(parameters, self.TARGET_CRS, context)
        # Input X/Y are required to be WGS84 longitude/latitude.
        source_crs = QgsCoordinateReferenceSystem("EPSG:4326")
        transform = QgsCoordinateTransform(source_crs, target_crs, context.project())
        do_transform = source_crs != target_crs

        fields = source.fields()
        required = {'X': x_name, 'Y': y_name, 'Height': h_name, 'Confidence': c_name}
        missing = [f"{label} ('{col}')" for label, col in required.items()
                   if fields.indexOf(col) == -1]
        if missing:
            raise QgsProcessingException(
                "Field Error: Critical columns missing: " + ", ".join(missing))
        x_idx = fields.indexOf(x_name)
        y_idx = fields.indexOf(y_name)
        h_idx = fields.indexOf(h_name)
        c_idx = fields.indexOf(c_name)

        d_idx = -1
        if use_date:
            d_idx = fields.indexOf(d_name)
            if d_idx == -1:
                raise QgsProcessingException(
                    "Date filtering is enabled but no valid Date Column was selected.")
            start_str, end_str = self._normalize_date_range(start_str, end_str)

        # Reuse existing summary columns if the input already has them; keep
        # their indices so values are written to the right slot either way.
        out_fields = QgsFields(fields)
        count_idx = self._ensure_field(out_fields, "aggregated_count", QVariant.Int)
        mean_idx = self._ensure_field(out_fields, "mean_height", QVariant.Double)
        n_out_fields = out_fields.count()

        sink, dest_id = self.parameterAsSink(parameters, self.OUTPUT_POINTS,
                                             context, out_fields, QgsWkbTypes.PointZ, target_crs)
        if sink is None:
            raise QgsProcessingException(self.invalidSinkError(parameters, self.OUTPUT_POINTS))

        # --- Pass 1: filter rows and accumulate per-bin sums ---------------
        bins = {}
        total = source.featureCount()
        # featureCount() can be -1 when the provider cannot count cheaply.
        progress_step = 100.0 / total if total > 0 else 0.0

        for i, feature in enumerate(source.getFeatures()):
            if feedback.isCanceled():
                break
            # Report for every row, not just the ones that survive filtering.
            if i % 10000 == 0:
                feedback.setProgress(i * progress_step)

            attrs = feature.attributes()

            if not self._passes_confidence(attrs[c_idx], min_conf):
                continue
            if use_date and not (start_str <= self._date_text(attrs[d_idx]) <= end_str):
                continue

            try:
                vx = float(attrs[x_idx])
                vy = float(attrs[y_idx])
                vh = float(attrs[h_idx])
                # round() snaps to the nearest multiple of bin_size, so bins
                # are centred on grid nodes.
                key = (round(vx / bin_size), round(vy / bin_size))
            except (TypeError, ValueError, OverflowError):
                # NULL, non-numeric, NaN or infinite coordinates can't be binned.
                continue
            if not math.isfinite(vh):
                # A single NaN/inf height would poison the whole bin's mean.
                continue

            cell = bins.get(key)
            if cell is None:
                # The first row in a bin supplies the template attributes.
                bins[key] = {'sh': vh, 'sx': vx, 'sy': vy, 'cnt': 1, 'atts': list(attrs)}
            else:
                cell['sh'] += vh
                cell['sx'] += vx
                cell['sy'] += vy
                cell['cnt'] += 1

        # --- Pass 2: emit one averaged point per bin -----------------------
        written = 0
        failed_transforms = 0
        for data in bins.values():
            if feedback.isCanceled():
                break

            cnt = data['cnt']
            avg_x = data['sx'] / cnt
            avg_y = data['sy'] / cnt
            avg_h = data['sh'] / cnt

            geom = QgsGeometry(QgsPoint(avg_x, avg_y, avg_h))
            if do_transform:
                try:
                    failed = geom.transform(transform) != 0
                except QgsCsException:
                    failed = True
                if failed:
                    failed_transforms += 1
                    continue

            final_attrs = data['atts']
            # Pad to the output schema so appended fields can be set by index.
            final_attrs.extend([None] * (n_out_fields - len(final_attrs)))
            # Keep attribute lon/lat/height consistent with the averaged geometry.
            final_attrs[x_idx] = avg_x
            final_attrs[y_idx] = avg_y
            final_attrs[h_idx] = avg_h
            final_attrs[count_idx] = cnt
            final_attrs[mean_idx] = avg_h

            feat = QgsFeature()
            feat.setGeometry(geom)
            feat.setAttributes(final_attrs)
            sink.addFeature(feat, QgsFeatureSink.FastInsert)
            written += 1

        if failed_transforms:
            feedback.pushInfo(
                f"Skipped {failed_transforms} bin(s) that could not be transformed "
                f"to {target_crs.authid()}.")
        feedback.pushInfo(f"Wrote {written} aggregated point(s) from {len(bins)} bin(s).")

        return {self.OUTPUT_POINTS: dest_id}