# -*- coding: utf-8 -*-
from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import (QgsProcessing, QgsProcessingAlgorithm,
                       QgsProcessingParameterFile,
                       QgsProcessingParameterFileDestination,
                       QgsProcessingParameterString,
                       QgsProcessingParameterNumber,
                       QgsProcessingParameterBoolean,
                       QgsProcessingParameterEnum,
                       QgsProcessingException)
import os
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import re
from PIL import Image
import tempfile
import shutil
import matplotlib.colors as mcolors
import urllib.request
import io
import math
import requests
from osgeo import gdal
import gzip
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

class FurunoWR110ToGIFAlgorithm(QgsProcessingAlgorithm):
    INPUT = 'INPUT'
    OUTPUT = 'OUTPUT'
    TITLE = 'TITLE'
    DURATION = 'DURATION'
    USE_BASEMAP = 'USE_BASEMAP'
    TRANSPARENCY_THRESHOLD = 'TRANSPARENCY_THRESHOLD'
    AGGREGATION_METHOD = 'AGGREGATION_METHOD'
    APPLY_INTERPOLATION = 'APPLY_INTERPOLATION'
    
    def initAlgorithm(self, config=None):
        self.addParameter(
            QgsProcessingParameterFile(
                self.INPUT,
                self.tr('Directorio con archivos H5 Furuno WR110'),
                behavior=QgsProcessingParameterFile.Folder,
                defaultValue=None
            )
        )
        
        self.addParameter(
            QgsProcessingParameterString(
                self.TITLE,
                self.tr('Título para el GIF'),
                defaultValue='Radar Doppler Furuno WR110'
            )
        )
        
        self.addParameter(
            QgsProcessingParameterNumber(
                self.DURATION,
                self.tr('Duración por frame (ms)'),
                type=QgsProcessingParameterNumber.Integer,
                defaultValue=850,
                minValue=100,
                maxValue=5000
            )
        )
        
        # Parámetro para el umbral de transparencia
        self.addParameter(
            QgsProcessingParameterNumber(
                self.TRANSPARENCY_THRESHOLD,
                self.tr('Umbral de transparencia (mm/h)'),
                type=QgsProcessingParameterNumber.Double,
                defaultValue=0.5,  # 0.5 mm/h para precipitación
                minValue=0.0,
                maxValue=10.0
            )
        )
        
        # Método de agregación de elevaciones
        self.addParameter(
            QgsProcessingParameterEnum(
                self.AGGREGATION_METHOD,
                self.tr('Método de agregación de elevaciones'),
                options=[
                    'Máximo (más conservador)',
                    'Promedio ponderado por elevación',
                    'Compuesto vertical inteligente'
                ],
                defaultValue=2  # Compuesto vertical por defecto
            )
        )
        
        self.addParameter(
            QgsProcessingParameterBoolean(
                self.USE_BASEMAP,
                self.tr('Usar mapa base (OpenStreetMap)'),
                defaultValue=True
            )
        )
        
        self.addParameter(
            QgsProcessingParameterBoolean(
                self.APPLY_INTERPOLATION,
                self.tr('Aplicar interpolación para suavizar'),
                defaultValue=True
            )
        )
        
        self.addParameter(
            QgsProcessingParameterFileDestination(
                self.OUTPUT,
                self.tr('GIF de salida'),
                fileFilter='GIF Files (*.gif)',
                optional=True
            )
        )
    
    def processAlgorithm(self, parameters, context, feedback):
        input_folder = self.parameterAsFile(parameters, self.INPUT, context)
        output_gif = self.parameterAsFileOutput(parameters, self.OUTPUT, context)
        title = self.parameterAsString(parameters, self.TITLE, context)
        frame_duration = self.parameterAsInt(parameters, self.DURATION, context)
        use_basemap = self.parameterAsBool(parameters, self.USE_BASEMAP, context)
        transparency_threshold = self.parameterAsDouble(parameters, self.TRANSPARENCY_THRESHOLD, context)
        aggregation_method = self.parameterAsEnum(parameters, self.AGGREGATION_METHOD, context)
        apply_interpolation = self.parameterAsBool(parameters, self.APPLY_INTERPOLATION, context)
        
        if not input_folder:
            raise QgsProcessingException(self.tr("No se ha seleccionado un directorio válido."))
            
        # Si no se especifica una salida, creamos un nombre basado en la entrada
        if not output_gif:
            output_gif = os.path.join(input_folder, "furuno_wr110_animation.gif")
            
        feedback.pushInfo(f"Procesando directorio: {input_folder}")
        feedback.pushInfo(f"GIF de salida: {output_gif}")
        feedback.pushInfo(f"Título: {title}")
        feedback.pushInfo(f"Duración por frame: {frame_duration}ms")
        feedback.pushInfo(f"Usar mapa base: {'Sí' if use_basemap else 'No'}")
        feedback.pushInfo(f"Umbral de transparencia: {transparency_threshold} mm/h")
        
        # Crear un directorio temporal para los archivos PNG intermedios
        temp_png_folder = tempfile.mkdtemp()
        feedback.pushInfo(f"Directorio temporal para PNG: {temp_png_folder}")
        
        try:
            # Importar el algoritmo WR110 para reutilizar sus métodos
            from .furuno_wr110_algorithm import FurunoWR110ToGeoTIFFAlgorithm
            wr110_algorithm = FurunoWR110ToGeoTIFFAlgorithm()
            
            # Función para descargar mosaicos de OpenStreetMap
            def download_osm_tiles(min_lon, max_lon, min_lat, max_lat, zoom=10):
                """Descarga mosaicos de OpenStreetMap para un área y zoom específicos."""
                def deg2num(lat_deg, lon_deg, zoom):
                    lat_rad = math.radians(lat_deg)
                    n = 2.0 ** zoom
                    xtile = int((lon_deg + 180.0) / 360.0 * n)
                    ytile = int((1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * n)
                    return (xtile, ytile)
                
                def num2deg(xtile, ytile, zoom):
                    n = 2.0 ** zoom
                    lon_deg = xtile / n * 360.0 - 180.0
                    lat_rad = math.atan(math.sinh(math.pi * (1 - 2 * ytile / n)))
                    lat_deg = math.degrees(lat_rad)
                    return (lon_deg, lat_deg)
                
                x_min, y_max = deg2num(min_lat, min_lon, zoom)
                x_max, y_min = deg2num(max_lat, max_lon, zoom)
                
                # Limitar el número de mosaicos
                if (x_max - x_min + 1) > 6:
                    center_x = (x_min + x_max) // 2
                    x_min = max(0, center_x - 3)
                    x_max = center_x + 3
                
                if (y_max - y_min + 1) > 6:
                    center_y = (y_min + y_max) // 2
                    y_min = max(0, center_y - 3)
                    y_max = center_y + 3
                
                feedback.pushInfo(f"Descargando mosaicos OSM: zoom={zoom}, x={x_min}-{x_max}, y={y_min}-{y_max}")
                
                tile_size = 256
                width = (x_max - x_min + 1) * tile_size
                height = (y_max - y_min + 1) * tile_size
                map_img = Image.new('RGBA', (width, height))
                
                url_template = "https://tile.openstreetmap.org/{z}/{x}/{y}.png"
                headers = {'User-Agent': 'QGIS_Plugin/1.0 (Furuno_WR110_to_GIF; fapucha@utpl.edu.ec)'}
                
                for x in range(x_min, x_max + 1):
                    for y in range(y_min, y_max + 1):
                        url = url_template.format(z=zoom, x=x, y=y)
                        try:
                            response = requests.get(url, headers=headers, timeout=5)
                            if response.status_code == 200:
                                tile = Image.open(io.BytesIO(response.content))
                                map_img.paste(tile, ((x - x_min) * tile_size, (y - y_min) * tile_size))
                        except Exception as e:
                            blank = Image.new('RGBA', (tile_size, tile_size), (255, 255, 255, 0))
                            map_img.paste(blank, ((x - x_min) * tile_size, (y - y_min) * tile_size))
                
                nw_lon, nw_lat = num2deg(x_min, y_min, zoom)
                se_lon, se_lat = num2deg(x_max + 1, y_max + 1, zoom)
                map_extent = (nw_lon, se_lon, se_lat, nw_lat)
                
                return map_img, map_extent
            
            # Definir la paleta de colores para precipitación (mm/h)
            # Ajustada para mejor visualización con transiciones en puntos clave
            rain_colors = [
                (0, 0, 0, 0),              # Transparente (0 mm/h)
                (0.12, 0.56, 1, 0.6),      # Azul claro (2 mm/h)
                (0.53, 0.81, 0.98, 0.7),   # Azul cielo (5 mm/h)
                (0.2, 0.8, 0.2, 0.8),      # Verde (10 mm/h)
                (1, 1, 0, 0.85),           # Amarillo (15 mm/h)
                (1, 0.6, 0, 0.9),          # Naranja (20 mm/h)
                (1, 0, 0, 0.9),            # Rojo (30 mm/h)
                (0.7, 0, 0.7, 0.95),       # Púrpura (40 mm/h)
                (0.4, 0, 0.4, 1.0)         # Morado oscuro (50+ mm/h)
            ]
            
            # Crear colormap con los puntos de anclaje específicos
            positions = [0, 0.04, 0.1, 0.2, 0.3, 0.4, 0.6, 0.8, 1.0]  # Posiciones normalizadas (0-50 mm/h)
            cmap_colors = [(positions[i], rain_colors[i]) for i in range(len(positions))]
            rain_cmap = plt.cm.colors.LinearSegmentedColormap.from_list("rain_cmap", 
                [(pos, color[:3] + (color[3],)) for pos, color in cmap_colors])
            
            # Normalización para valores de precipitación 0-50 mm/h (más realista)
            norm = plt.cm.colors.Normalize(vmin=0, vmax=50)
            
            # Función para procesar un archivo H5
            def process_h5_file(file_path, output_path, radar_title, thresh):
                try:
                    # Descomprimir si es necesario
                    h5_file = file_path
                    temp_dir = None
                    
                    if file_path.endswith('.gz'):
                        h5_file = wr110_algorithm.decompress_gz_file(file_path)
                        temp_dir = os.path.dirname(h5_file)
                    
                    # Extraer datos de múltiples elevaciones
                    all_rate_data = []
                    all_metadata = []
                    
                    for dataset_number in range(1, 6):
                        data, metadata = wr110_algorithm.extract_rate_data_gdal(h5_file, dataset_number, 1)
                        if data is not None and metadata is not None:
                            all_rate_data.append(data)
                            all_metadata.append(metadata)
                    
                    if not all_rate_data:
                        feedback.pushWarning(f"No se pudieron extraer datos de {os.path.basename(file_path)}")
                        return False
                    
                    # Agregar elevaciones según el método seleccionado
                    # Mapear el índice del método para compatibilidad
                    method_mapping = {0: 0, 1: 2, 2: 4}  # Máximo, Promedio ponderado, Compuesto vertical
                    mapped_method = method_mapping.get(aggregation_method, 2)
                    
                    rain_data = wr110_algorithm.aggregate_elevations(all_rate_data, all_metadata, mapped_method)
                    
                    if rain_data is None:
                        return False
                    
                    # Aplicar interpolación si está habilitada
                    if apply_interpolation:
                        rain_data = wr110_algorithm.interpolate_small_gaps(rain_data, max_gap_size=3)
                    
                    # Crear máscara para valores menores al umbral
                    mask = (~np.isnan(rain_data)) & (rain_data < thresh)
                    rain_masked = np.ma.masked_array(rain_data, mask=mask)
                    
                    # Obtener metadatos del radar
                    metadata = all_metadata[0]
                    center_lat = metadata.get('lat', -3.9869)
                    center_lon = metadata.get('lon', -79.1969)
                    nrays = metadata.get('nrays', rain_data.shape[0])
                    nbins = metadata.get('nbins', rain_data.shape[1])
                    rscale = metadata.get('rscale', 75.0)
                    
                    # Calcular extensión aproximada del radar
                    max_range_km = (nbins * rscale) / 1000.0
                    # Aproximación: 1 grado de latitud ≈ 111 km
                    lat_range = max_range_km / 111.0
                    lon_range = max_range_km / (111.0 * np.cos(np.radians(center_lat)))
                    
                    min_lon = center_lon - lon_range
                    max_lon = center_lon + lon_range
                    min_lat = center_lat - lat_range
                    max_lat = center_lat + lat_range
                    
                    # Crear figura
                    fig, ax = plt.subplots(figsize=(12, 10), dpi=150)
                    
                    # Configurar límites
                    data_extent = [min_lon, max_lon, min_lat, max_lat]
                    ax.set_xlim(min_lon, max_lon)
                    ax.set_ylim(min_lat, max_lat)
                    
                    # Añadir mapa base si está disponible
                    if use_basemap and base_map_image is not None:
                        try:
                            map_array = np.array(base_map_image)
                            ax.imshow(map_array, extent=base_map_extent, alpha=0.7, aspect='auto')
                        except Exception as e:
                            feedback.pushWarning(f"Error al mostrar mapa base: {str(e)}")
                    
                    # Convertir datos polares a cartesianos para visualización
                    # Crear grid de coordenadas
                    angles = np.linspace(0, 2*np.pi, nrays, endpoint=False)
                    ranges = np.arange(nbins) * rscale / 1000.0  # en km
                    
                    # Crear meshgrid polar
                    theta, r = np.meshgrid(angles, ranges, indexing='ij')
                    
                    # Convertir a cartesianas relativas al centro
                    x_rel = r * np.sin(theta)  # Este-Oeste en km
                    y_rel = r * np.cos(theta)  # Norte-Sur en km
                    
                    # Convertir a coordenadas geográficas
                    lon_grid = center_lon + x_rel / (111.0 * np.cos(np.radians(center_lat)))
                    lat_grid = center_lat + y_rel / 111.0
                    
                    # Mostrar datos de precipitación
                    rain_plot = ax.pcolormesh(
                        lon_grid,
                        lat_grid,
                        rain_masked,
                        cmap=rain_cmap,
                        norm=norm,
                        alpha=0.85,
                        shading='auto'
                    )
                    
                    # Función para convertir coordenadas
                    def decimal_to_dms(decimal, is_latitude=True):
                        direction = 'N' if decimal >= 0 and is_latitude else 'S' if is_latitude else 'E' if decimal >= 0 else 'W'
                        decimal = abs(decimal)
                        degrees = int(decimal)
                        minutes = int((decimal - degrees) * 60)
                        seconds = ((decimal - degrees) * 60 - minutes) * 60
                        return f"{degrees}°{minutes}'{seconds:.0f}\"{direction}"
                    
                    # Añadir etiquetas
                    x_ticks = np.linspace(min_lon, max_lon, 5)
                    y_ticks = np.linspace(min_lat, max_lat, 5)
                    ax.set_xticks(x_ticks)
                    ax.set_yticks(y_ticks)
                    ax.set_xticklabels([decimal_to_dms(x, False) for x in x_ticks], fontsize=8)
                    ax.set_yticklabels([decimal_to_dms(y) for y in y_ticks], fontsize=8)
                    
                    # Añadir barra de color
                    cbar = plt.colorbar(rain_plot, shrink=0.6)
                    cbar.set_label('Intensidad de precipitación (mm/h)')
                    
                    # Niveles para la barra de colores - menos etiquetas, mejor espaciadas
                    tick_levels = [0, 5, 10, 20, 30, 50]
                    cbar.set_ticks(tick_levels)
                    
                    tick_labels = [
                        '0',
                        '5\n(Moderada)', 
                        '10\n(Fuerte)', 
                        '20\n(Intensa)',
                        '30\n(Muy intensa)',
                        '50\n(Torrencial)'
                    ]
                    cbar.set_ticklabels(tick_labels)
                    
                    # Información sobre el umbral
                    plt.figtext(0.5, 0.01, f"Umbral de transparencia: {thresh} mm/h", 
                             ha="center", fontsize=10, bbox={"facecolor":"white", "alpha":0.5, "pad":5})
                    
                    # Añadir marca de tiempo
                    file_name = os.path.basename(file_path)
                    match = re.search(r'(\d{4})(\d{2})(\d{2})_(\d{2})(\d{2})(\d{2})', file_name)
                    if match:
                        year, month, day, hour, minute, second = map(int, match.groups())
                        utc_time = datetime(year, month, day, hour, minute, second)
                        ecuador_time = utc_time - timedelta(hours=5)
                        time_str = ecuador_time.strftime("%Y-%m-%d %H:%M:%S (UTC-5)")
                        plt.title(f"{radar_title}\n{time_str}", fontsize=12, pad=20)
                    
                    # Añadir información del radar
                    info_text = f"Radar: Furuno WR110 | Alcance: {max_range_km:.0f} km | Resolución: {rscale} m"
                    plt.figtext(0.5, 0.96, info_text, ha="center", fontsize=9, 
                             bbox={"facecolor":"white", "alpha":0.7, "pad":3})
                    
                    # Guardar imagen
                    plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=150)
                    plt.close()
                    
                    # Limpiar
                    if temp_dir and os.path.exists(temp_dir):
                        shutil.rmtree(temp_dir)
                    
                    return True
                    
                except Exception as e:
                    feedback.pushWarning(f"Error al procesar archivo {file_path}: {str(e)}")
                    import traceback
                    feedback.pushWarning(traceback.format_exc())
                    return False
            
            # Procesar archivos
            png_files = []
            h5_files = [f for f in sorted(os.listdir(input_folder)) 
                       if f.endswith(('.h5', '.hdf5', '.h5.gz'))]
            total_files = len(h5_files)
            
            if total_files == 0:
                feedback.pushWarning("No se encontraron archivos H5 en el directorio.")
                return {self.OUTPUT: None}
            
            feedback.pushInfo(f"Se encontraron {total_files} archivos H5 para procesar.")
            
            # Cargar mapa base una sola vez si se usa
            base_map_image = None
            base_map_extent = None
            
            if use_basemap and total_files > 0:
                try:
                    # Usar el primer archivo para obtener coordenadas aproximadas
                    first_file = os.path.join(input_folder, h5_files[0])
                    
                    # Coordenadas por defecto para Loja, Ecuador
                    center_lat = -3.9869
                    center_lon = -79.1969
                    
                    # Intentar extraer coordenadas del archivo
                    try:
                        h5_file = first_file
                        if first_file.endswith('.gz'):
                            h5_file = wr110_algorithm.decompress_gz_file(first_file)
                        
                        _, metadata = wr110_algorithm.extract_rate_data_gdal(h5_file, 1, 1)
                        if metadata:
                            center_lat = metadata.get('lat', center_lat)
                            center_lon = metadata.get('lon', center_lon)
                        
                        if first_file.endswith('.gz') and os.path.exists(os.path.dirname(h5_file)):
                            shutil.rmtree(os.path.dirname(h5_file))
                    except:
                        pass
                    
                    # Calcular extensión aproximada (150 km de radio típico)
                    lat_range = 150.0 / 111.0
                    lon_range = 150.0 / (111.0 * np.cos(np.radians(center_lat)))
                    
                    min_lon = center_lon - lon_range
                    max_lon = center_lon + lon_range
                    min_lat = center_lat - lat_range
                    max_lat = center_lat + lat_range
                    
                    feedback.pushInfo(f"Descargando mapa base para área: {min_lat:.4f}, {min_lon:.4f} - {max_lat:.4f}, {max_lon:.4f}")
                    base_map_image, base_map_extent = download_osm_tiles(min_lon, max_lon, min_lat, max_lat, zoom=9)
                    
                except Exception as e:
                    feedback.pushWarning(f"Error al obtener mapa base: {str(e)}")
                    use_basemap = False
            
            # Procesar cada archivo H5
            for idx, file in enumerate(h5_files):
                if feedback.isCanceled():
                    break
                
                feedback.setProgress(int((idx + 1) / total_files * 100))
                
                file_path = os.path.join(input_folder, file)
                output_path = os.path.join(temp_png_folder, f"frame_{idx:03d}.png")
                
                feedback.pushInfo(f"Procesando archivo {idx+1}/{total_files}: {file}")
                
                if process_h5_file(file_path, output_path, title, transparency_threshold):
                    png_files.append(output_path)
                    feedback.pushInfo(f"✓ Imagen generada: frame_{idx:03d}.png")
            
            # Crear GIF animado
            if png_files:
                feedback.pushInfo(f"Generando GIF con {len(png_files)} imágenes...")
                
                png_files = sorted(png_files)
                images = [Image.open(png) for png in png_files]
                
                images[0].save(
                    output_gif,
                    save_all=True,
                    append_images=images[1:],
                    optimize=False,
                    duration=frame_duration,
                    loop=0
                )
                
                feedback.pushInfo(f"GIF creado exitosamente: {output_gif}")
            else:
                feedback.pushWarning("No se generaron imágenes para el GIF")
        
        except Exception as e:
            feedback.reportError(f"Error en el procesamiento: {str(e)}")
            import traceback
            feedback.reportError(traceback.format_exc())
            raise QgsProcessingException(str(e))
        finally:
            # Limpiar archivos temporales
            try:
                shutil.rmtree(temp_png_folder)
                feedback.pushInfo("Archivos temporales eliminados")
            except Exception as e:
                feedback.pushWarning(f"No se pudieron eliminar todos los archivos temporales: {str(e)}")
        
        return {self.OUTPUT: output_gif}
        
    def name(self):
        return 'furunowr110togif'
        
    def displayName(self):
        return self.tr('Furuno WR110 to GIF')
        
    def group(self):
        return self.tr('Radar Meteorológico')
        
    def groupId(self):
        return 'radarmeteo'
        
    def shortHelpString(self):
        return self.tr('Crea un GIF animado a partir de archivos H5 del radar Doppler Furuno WR110. '
                     'Permite visualizar la evolución temporal de la precipitación con mapa base opcional. '
                     'Soporta archivos comprimidos (.h5.gz) y múltiples métodos de agregación de elevaciones. '
                     'El umbral de transparencia está en mm/h para datos de precipitación. '
                     'Departamento de Ingeniería Civil - UTPL')
        
    def tr(self, string):
        return QCoreApplication.translate('Processing', string)
        
    def createInstance(self):
        return FurunoWR110ToGIFAlgorithm()