# -*- coding: utf-8 -*-

"""
/***************************************************************************
ViewshedAnalysis
A QGIS plugin
begin : 2013-05-22
copyright : (C) 2013 by Zoran Čučković
email : /
***************************************************************************/

/***************************************************************************
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation version 2 of the License, or *
* any later version. *
* *
***************************************************************************/
"""

from qgis.PyQt.QtCore import QCoreApplication
from qgis.core import *
from osgeo import gdal
import numpy as np

from os import path


# Buffer modes: how analysed windows are combined into the output buffer.
#   SINGLE -- the new window overwrites the buffer
#   ADD    -- values are summed
#   MIN    -- the lowest value is kept
#   MAX    -- the highest value is kept
SINGLE, ADD, MIN, MAX = range(4)


# The Raster class below handles input and output of raster data.
# It doesn't do any calculations besides combining analysed patches.
class Raster:
    """
    Handle input and output of raster data.

    The class does no terrain analysis itself. It reads windows of
    elevation data around observer points (optionally corrected for earth
    curvature). It then combines analysed windows into an output buffer,
    either in memory or in a GDAL file, using one of the buffer modes
    SINGLE / ADD / MIN / MAX.

    Only one window can be open at a time. Its size and position are
    stored on the instance by open_window() and reused by add_to_buffer()
    and read_buffer().
    """

    def __init__(self, raster, output=None, crs=None):
        """
        Open *raster* with GDAL and register its geometry and statistics.

        Parameters:
            raster: name of the raster passed to ``gdal.Open``; only band 1
                    is used for statistics.
            output: optional output file name (see write_output).
            crs: optional CRS string (WKT); defaults to the projection of
                 the raster itself.
        """
        gdal_raster = gdal.Open(raster)

        self.crs = crs if crs else gdal_raster.GetProjection()

        # Kept open: windows are read from this dataset repeatedly.
        self.rst = gdal_raster

        # Size is (rows, columns), i.e. y first, like numpy.
        self.size = (gdal_raster.RasterYSize, gdal_raster.RasterXSize)

        # GDAL geotransform:
        #   gt[0] top left x          gt[3] top left y
        #   gt[1] w-e pixel size      gt[5] n-s pixel size
        #   gt[2], gt[4] rotation terms (0 if the image is "north up")
        gt = gdal_raster.GetGeoTransform()

        self.pix_x, self.pix_y = abs(gt[1]), abs(gt[5])
        self.pix = self.pix_x  # for compatibility with older code ...

        raster_x_min = gt[0]
        raster_y_max = gt[3]  # top left y, so the maximum

        # Each axis uses its own resolution, so rasters with rectangular
        # pixels get a correct extent.
        raster_y_min = raster_y_max - self.size[0] * self.pix_y
        raster_x_max = raster_x_min + self.size[1] * self.pix_x

        self.extent = [raster_x_min, raster_y_min,
                       raster_x_max, raster_y_max]

        srcband = gdal_raster.GetRasterBand(1)
        # GetStatistics(approx_ok, force) -> [min, max, mean, stddev]
        self.min, self.max = srcband.GetStatistics(True, True)[:2]
        self.nodata = srcband.GetNoDataValue()

        self.output = output

    def pixel_coords(self, x, y):
        """
        Convert map coordinates (x, y) to pixel coordinates (column, row).

        The row axis is reversed with respect to map y: row 0 is the top
        edge of the raster. Values are truncated with int().
        """
        x_min, y_max = self.extent[0], self.extent[3]
        return (int((x - x_min) / self.pix_x),
                int((y_max - y) / self.pix_y))  # reversed !

    def set_master_window(self, radius,
                          size_factor=1,
                          curvature=False,
                          refraction=0,
                          background_value=0,
                          pad=False):
        """
        Create the master window, the largest window used for all analyses.

        Smaller windows are slices of this one. In theory a window should
        be a subclass, but only one window can exist at a time.

        Parameters:
            radius: analysis radius in map units; converted to pixels with
                    the x resolution (self.pix).
            size_factor: not used.
            curvature: if True, precompute the earth-curvature correction
                       that open_window() subtracts from elevations.
            refraction: refraction coefficient for the curvature correction.
            background_value: initial value of window cells. This is not
                              the mask fill value (self.fill).
            pad: if True, open_window() copies the border row/column of the
                 raster into the adjacent out-of-raster cells.
        """
        self.radius = radius
        radius_pix = int(radius / self.pix)
        self.radius_pix = radius_pix

        full_size = radius_pix * 2 + 1
        self.window = np.zeros((full_size, full_size))
        self.initial_value = background_value
        self.pad = pad

        self.mx_dist = self.distance_matrix()

        if curvature:
            self.curvature = self.curvature_matrix(refraction)
        else:
            self.curvature = 0

        self.angles = self.angular_matrix()

    def set_buffer(self, mode=ADD, live_memory=False):
        """
        Prepare the result buffer and choose how results are combined.

        Parameters:
            mode: SINGLE, ADD, MIN or MAX (module constants).
            live_memory: if True the buffer is an in-memory array the size
                         of the whole raster. Otherwise results are read
                         from and written to self.gdal_output (created by
                         write_output) one window at a time. The latter
                         approach is 15 - 20% slower.
        """
        # Masked cells get 0 when summing; other modes use nan
        # ("no value yet").
        self.fill = 0 if mode == ADD else np.nan

        self.mode = mode

        if live_memory:
            self.result = np.zeros(self.size)
            if mode != ADD:
                self.result[:] = np.nan
        else:
            self.result = None

    def get_curvature_earth(self):
        """
        Return the earth "diameter" used by the curvature formula.

        The diameter is the semi-major plus the semi-minor axis, i.e. half
        of the major axis plus half of the minor. This should work best for
        moderate latitudes. The semi-major axis and inverse flattening are
        parsed from the SPHEROID (WKT1) or ELLIPSOID (WKT2) node of
        self.crs. Missing or implausible values fall back to WGS 84.
        """
        default_semi_major = 6378137
        default_inv_flattening = 298.257223563

        crs = self.crs
        tmp = []
        for key in ("SPHEROID[", "ELLIPSOID["):
            pos = crs.find(key)
            if pos != -1:
                # e.g. SPHEROID["WGS 84",6378137,298.257223563,...
                tmp = crs[pos + len(key):].split(",", 3)
                break

        try:
            semiMajor = float(tmp[1])
        except (IndexError, ValueError):
            semiMajor = default_semi_major
        if not 6000000 < semiMajor < 7000000:
            semiMajor = default_semi_major

        try:
            flattening = float(tmp[2])
        except (IndexError, ValueError):
            flattening = default_inv_flattening
        if not 296 < flattening < 301:
            flattening = default_inv_flattening

        # the WKT value is the inverse flattening (1/f)
        semiMinor = semiMajor - semiMajor / flattening

        return semiMajor + semiMinor

    def curvature_matrix(self, refraction=0):
        """
        Model the vertical drop from a plane to the earth's surface.

        The drop is squared distance / earth diameter, scaled by
        (1 - refraction). Distances in the window are in pixels, so the
        diameter is divided by the squared pixel size. The result is then
        expressed in map units: (dist_pix * pix)**2 / diameter.
        NOTE(review): this presumes map units match the ellipsoid units
        (metres); confirm for geographic CRSs.
        See https://www.usna.edu/Users/oceano/pguth/md_help/html/demb30q0.htm
        """
        dist_squared = self.distance_matrix(squared=True)
        D = self.get_curvature_earth() / (self.pix ** 2)

        return (dist_squared / D) * (1 - refraction)

    def set_mask(self,
                 radius_out,
                 radius_in=None,
                 azimuth_1=None,
                 azimuth_2=None):
        """
        Calculate the analysis mask (can be set for each point).

        A cell is analysed when its pixel distance from the centre is below
        *radius_out* and, if given, above *radius_in*. When both azimuths
        are given, the cell must also lie strictly between them (degrees,
        see angular_matrix). If azimuth_1 > azimuth_2 the sector wraps
        through 0.

        self.mask is True on the cells OUTSIDE the analysed area.
        """
        inside = self.mx_dist < radius_out

        if radius_in:
            inside &= self.mx_dist > radius_in

        if azimuth_1 is not None and azimuth_2 is not None:
            combine = (np.logical_and if azimuth_1 < azimuth_2
                       else np.logical_or)
            inside &= combine(self.angles > azimuth_1,
                              self.angles < azimuth_2)

        self.mask = ~inside

    def distance_matrix(self, squared=False):
        """
        Return a map of distances from the central pixel.

        Attention: these are pixel distances, not geographical ones.
        To convert to geographical distances, multiply by the pixel size.
        With squared=True the (integer) squared distances are returned.
        """
        r = self.radius_pix
        offsets_sq = (np.arange(self.window.shape[0]) - r) ** 2
        dist_sq = offsets_sq[:, None] + offsets_sq[None, :]

        return dist_sq if squared else np.sqrt(dist_sq)

    def angular_matrix(self):
        """
        Return the azimuth (degrees, in [0, 360)) of each window cell.

        Azimuths are measured clockwise from the top of the window (north
        for north-up rasters) around the central pixel.
        """
        r = self.radius_pix
        window = self.window.shape[0]

        temp_x = np.arange(window)[::-1] - r  # rows: +r at the top
        temp_y = np.arange(window) - r        # columns: +r on the right

        angles = np.arctan2(temp_y[None, :], temp_x[:, None]) * 180 / np.pi

        angles[angles < 0] += 360

        return angles

    def open_window(self, pixel_coord):
        """
        Extract a square window from the raster, centred on *pixel_coord*.

        The observer point (column, row) is always in the centre of the
        master window. Cells falling outside the raster keep
        self.initial_value. With self.pad, the border row/column of the
        raster is instead copied into the adjacent outside cells, so that
        interpolation near the edge does not use the background value.
        Earth curvature, if enabled, is subtracted from the elevations.

        The window geometry is stored on the instance and reused when
        writing results:
            window_slice        -- numpy slice of the window in the raster
            inside_window_slice -- [(y0, y1), (x0, x1)], the data region
                                   inside the master window
            gdal_slice          -- [x_offset, y_offset, x_size, y_size]

        Returns:
            numpy.ndarray: self.window (filled in place).
        """
        rx = self.radius_pix
        x, y = pixel_coord

        # Difference between the master window and a full-radius window;
        # 0 as long as the window comes from set_master_window().
        diff_x = self.window.shape[1] - (rx * 2 + 1)
        diff_y = self.window.shape[0] - (rx * 2 + 1)

        if x <= rx:  # cropped at the left raster edge
            x_offset = 0
            x_offset_dist_mx = (rx - x) + diff_x
        else:
            x_offset = x - rx
            x_offset_dist_mx = 0
        # the radius may be enormous, so always check the far end too
        x_offset2 = min(x + rx + 1, self.size[1])

        if y <= rx:  # cropped at the top raster edge
            y_offset = 0
            y_offset_dist_mx = (rx - y) + diff_y
        else:
            y_offset = y - rx
            y_offset_dist_mx = 0
        y_offset2 = min(y + rx + 1, self.size[0])

        window_size_y = y_offset2 - y_offset
        window_size_x = x_offset2 - x_offset

        self.window_slice = np.s_[y_offset: y_offset + window_size_y,
                                  x_offset: x_offset + window_size_x]

        in_slice_y = (y_offset_dist_mx, y_offset_dist_mx + window_size_y)
        in_slice_x = (x_offset_dist_mx, x_offset_dist_mx + window_size_x)

        self.inside_window_slice = [in_slice_y, in_slice_x]

        self.gdal_slice = [x_offset, y_offset, window_size_x, window_size_y]

        inside = (slice(*in_slice_y), slice(*in_slice_x))

        self.window[:] = self.initial_value
        self.window[inside] = \
            self.rst.ReadAsArray(*self.gdal_slice).astype(float)

        if isinstance(self.curvature, np.ndarray):
            self.window[inside] -= self.curvature[inside]

        # When the window stretches outside the raster, the last row/column
        # would be interpolated with the background value: copy the
        # border values outward instead.
        if self.pad:
            if x_offset_dist_mx:
                self.window[:, in_slice_x[0] - 1] = \
                    self.window[:, in_slice_x[0]]
            # slice end is exclusive, so end - 1 is the last data index
            if x + rx + 1 > self.size[1]:
                self.window[:, in_slice_x[1]] = \
                    self.window[:, in_slice_x[1] - 1]

            if y_offset_dist_mx:
                self.window[in_slice_y[0] - 1, :] = \
                    self.window[in_slice_y[0], :]

            if y + rx + 1 > self.size[0]:
                self.window[in_slice_y[1], :] = \
                    self.window[in_slice_y[1] - 1, :]

        return self.window

    def open_raster(self):
        """Read the entire raster as float, store it in self.raster and return it."""
        self.raster = self.rst.ReadAsArray().astype(float)
        return self.raster

    def add_to_buffer(self, in_array, report=False):
        """
        Insert a numpy matrix where the current window was extracted.

        Data is combined with the buffer according to self.mode. SINGLE
        overwrites, ADD sums, and MIN/MAX keep the lowest/highest value;
        nan buffer cells always take the new value. All placement
        parameters come from the last open_window() call, because only one
        window is possible at a time.

        Note: masked cells of *in_array* are overwritten with self.fill in
        place.

        Parameters:
            in_array: array shaped like the master window.
            report: if True, return (non-zero analysed cells, total
                    analysed cells) for the window.
        """
        mask = getattr(self, "mask", None)
        if mask is not None:
            try:
                in_array[mask] = self.fill
            except (IndexError, ValueError, TypeError):
                pass  # mask does not fit this array: use it unmasked

        y_in = slice(*self.inside_window_slice[0])
        x_in = slice(*self.inside_window_slice[1])

        m_in = in_array[y_in, x_in]

        live = isinstance(self.result, np.ndarray)
        if live:
            # a view: in-place edits go straight into self.result
            m = self.result[self.window_slice]
        else:
            m = self.gdal_output.ReadAsArray(*self.gdal_slice).astype(float)

        if self.mode == SINGLE:
            # assign in place (a plain rebinding would never reach the buffer)
            m[:] = m_in

        elif self.mode == ADD:
            m += m_in

        else:
            flt = m_in < m if self.mode == MIN else m_in > m
            # nan always gives False in comparisons: take the new value
            # wherever the buffer is still empty.
            flt[np.isnan(m)] = True

            m[flt] = m_in[flt]

        if not live:  # write to file
            bd = self.gdal_output.GetRasterBand(1)
            # WriteArray takes only the x and y offsets
            bd.WriteArray(m, *self.gdal_slice[:2])
            bd.FlushCache()

        if report:
            if mask is None:  # unmasked array
                return (np.count_nonzero(m_in), m_in.size)

            # mask is True on the outside of the analysed area
            crop = np.count_nonzero(mask[y_in, x_in])

            c = np.count_nonzero(m_in)
            # nan fill values in the masked cells count as non-zero
            if self.fill != 0:
                c -= crop
            # second value: total area analysed
            return (c, m_in.size - crop)

    def read_buffer(self, apply_mask=True, fill_value=None):
        """
        Read the buffer region of the current window into a full window.

        Uses the layout of the last open_window() call (the same one as
        add_to_buffer), so it cannot read arbitrary points.

        Parameters:
            apply_mask (bool): whether to set masked cells to fill_value.
            fill_value (float or None): value for cells outside the raster
                                        (and masked cells). Defaults to
                                        self.fill.

        Returns:
            numpy.ndarray: float array shaped like the master window.
        """
        if fill_value is None:
            fill_value = self.fill

        # dtype=float: np.full infers an integer dtype from a fill value of
        # 0 (ADD mode) and would truncate the buffer values.
        full_window = np.full(self.window.shape, fill_value, dtype=float)

        y_in = slice(*self.inside_window_slice[0])
        x_in = slice(*self.inside_window_slice[1])

        if isinstance(self.result, np.ndarray):
            data_subset = self.result[self.window_slice]
        else:
            data_subset = self.gdal_output.ReadAsArray(
                *self.gdal_slice).astype(float)

        full_window[y_in, x_in] = data_subset

        # The mask has the master-window shape, so apply it to the full
        # window (the subset is smaller when the window is cropped).
        mask = getattr(self, "mask", None)
        if apply_mask and mask is not None:
            try:
                full_window[mask] = fill_value
            except IndexError:
                pass  # stale mask from a different master window

        return full_window

    def write_output(self, file_name=None,
                     no_data=np.nan,
                     dataFormat=gdal.GDT_Float32):
        """
        Write the analysis result.

         - With a file name (argument or self.output), a new GeoTIFF is
           created, filled with self.fill (if set_buffer was called) and
           kept as self.gdal_output for buffered operations.
         - Without a file name, the previously created self.gdal_output
           is used.
         - If an in-memory result exists it is written to the file;
           otherwise the file is left as created (empty).
        """
        if file_name:
            self.output = file_name

        if self.output:
            driver = gdal.GetDriverByName('GTiff')
            ds = driver.Create(self.output, self.size[1], self.size[0],
                               1, dataFormat)
            ds.SetProjection(self.crs)
            ds.SetGeoTransform(self.rst.GetGeoTransform())
            band = ds.GetRasterBand(1)
            band.SetNoDataValue(no_data)
            if hasattr(self, "fill"):
                band.Fill(self.fill)
                # flush now; otherwise the dataset must be deleted to flush
                ds.FlushCache()
            # for buffered operations (...hacky ...)
            self.gdal_output = ds

        else:
            ds = self.gdal_output

        result = getattr(self, "result", None)
        if isinstance(result, np.ndarray):
            ds.GetRasterBand(1).WriteArray(result)
            ds.FlushCache()

    @staticmethod
    def span_window_not_used(point1, point2, padding_ratio):
        """
        Compute a window spanning two points (currently unused).

        Parameters:
            point1, point2: points as [x, y].
            padding_ratio: padding added on each side, as a fraction of the
                           span between the points (0.5: half the span).

        Returns:
            (center, half_size): integer numpy arrays [x, y].
        """
        p1 = np.asarray(point1)
        p2 = np.asarray(point2)

        center = (p1 + p2) // 2

        # full span between the points, padded on both sides
        span = np.abs(p2 - p1)
        span = (span + 2 * span * padding_ratio).astype(int)

        return center, span // 2

    @staticmethod
    def merge_windows_NOT_IMPLEMENTED(w1, w2):
        """
        Sum two overlapping windows and place them in a common array.

        Parameters:
        - w1, w2: objects (e.g. Raster instances after open_window) with:
            - window: the extracted window data
            - gdal_slice: [x_offset, y_offset, width, height] in the raster
            - inside_window_slice (optional): position of the data inside
              window; if absent the data is assumed to start at [0, 0]

        Returns:
        - combined_array: summed array covering both windows
        - combined_offset: (x, y) offset of its top-left corner in the raster
        """
        x1, y1, w1_w, w1_h = w1.gdal_slice
        x2, y2, w2_w, w2_h = w2.gdal_slice

        # overall bounding box
        x_min = min(x1, x2)
        y_min = min(y1, y2)
        x_max = max(x1 + w1_w, x2 + w2_w)
        y_max = max(y1 + w1_h, y2 + w2_h)

        combined_array = np.zeros((y_max - y_min, x_max - x_min), dtype=float)

        for w in (w1, w2):
            x, y, width, height = w.gdal_slice
            inside = getattr(w, "inside_window_slice", None)
            if inside is None:
                data = w.window[:height, :width]
            else:
                # data of a cropped window is offset inside the master window
                data = w.window[slice(*inside[0]), slice(*inside[1])]
            combined_array[y - y_min: y - y_min + height,
                           x - x_min: x - x_min + width] += data

        return combined_array, (x_min, y_min)
        
        

