# -*- coding: utf-8 -*-
"""
/***************************************************************************
 LiDARForestryHeight
                                 A QGIS plugin. LiDAR Forestry Height 
                                 generates a DEM with the forest height, 
                                 calculated from a classified LiDAR point
                                 cloud using LasPy Library
                                 
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                             -------------------
        begin                : 2018-09-24
        copyright            : (C) 2019 by PANOimagen S.L.
        email                : info@panoimagen.com
        git sha              : $Format:%H$
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""
import numpy as np
import scipy
from scipy.interpolate import griddata

from . import height_calculator
from osgeo import gdal, osr
from gdal import gdalconst

class RasterizeLiDAR(object):
    """ Interpolate a LiDAR point cloud onto a regular grid and write the
        result to a single-band GeoTIFF.

        lidar_extent: indexable whose item [2] is the (min_x, max_y) corner
            and item [1] the (max_x, min_y) corner of the point cloud.
        pixel_size: side length of the square output pixels, in the same
            units as the point coordinates.
        method: interpolation method forwarded to
            scipy.interpolate.griddata ('nearest', 'linear' or 'cubic').
    """

    # Slack (in pixels) used when counting how many whole pixels fit in the
    # extent, so floating point noise such as 0.3 / 0.1 == 2.9999999999999996
    # does not drop a pixel.
    _PIXEL_COUNT_TOLERANCE = 1e-9

    def __init__(self, lidar_extent, pixel_size, method='nearest'):
        self.lidar_extent = lidar_extent
        self.pixel_size = pixel_size
        # EPSG code of the output CRS; while it stays falsy (the default
        # None) array_2_raster writes the raster without a projection.
        self.epsg_code = None
        self.method = method

    def _pixel_centers(self, edge_start, edge_stop, step):
        """ Return the centres of the whole pixels between two extent edges.

            step is the signed pixel size: positive when walking from
            edge_start towards a larger edge_stop, negative otherwise.
            A partial pixel at the far edge is dropped, but at least one
            pixel is always returned so the raster is never empty.
        """
        span_in_pixels = abs(edge_stop - edge_start) / abs(step)
        n_pixels = max(1, int(np.floor(span_in_pixels
                                       + self._PIXEL_COUNT_TOLERANCE)))
        return edge_start + step / 2.0 + step * np.arange(n_pixels)

    def makegrid(self):
        """ Generate the grid of pixel centres of the future raster.

            The grid is needed to interpolate between the LiDAR points.
            Sets self.grid_x and self.grid_y, both of shape
            (n_columns, n_rows): x grows along axis 0 starting at min_x and
            y decreases along axis 1 starting at max_y.

            Pixel centres are derived from an explicit pixel count rather
            than a float-stepped range (whose stop is exclusive), so an
            extent that is an exact multiple of pixel_size keeps its last
            row and column.
        """
        min_x, max_y = self.lidar_extent[2][0], self.lidar_extent[2][1]
        max_x, min_y = self.lidar_extent[1][0], self.lidar_extent[1][1]

        x_centers = self._pixel_centers(min_x, max_x, self.pixel_size)
        # Rows run from north (max_y) to south, matching the negative
        # pixel height returned by set_raster_geotransform.
        y_centers = self._pixel_centers(max_y, min_y, -self.pixel_size)

        # indexing='ij' reproduces the np.mgrid layout (x along axis 0).
        self.grid_x, self.grid_y = np.meshgrid(x_centers, y_centers,
                                               indexing='ij')

    def interpolate_grid(self, lidar_array):
        """ Interpolate the point cloud onto the pixel-centre grid.

            lidar_array[0] holds the x-y point coordinates (passed to
            griddata as its points argument) and lidar_array[-1] the z value
            of each point. Available methods are 'nearest', 'linear' and
            'cubic' (see scipy.interpolate.griddata); with 'linear' and
            'cubic', pixels outside the convex hull of the points are NaN.

            Returns a 2-D array with one row per raster row (north first)
            and one column per raster column.
        """
        self.lidar_xy_array = lidar_array[0]
        self.lidar_altitudes_array = lidar_array[-1]
        self.makegrid()
        interpolated = griddata(self.lidar_xy_array,
                                self.lidar_altitudes_array,
                                (self.grid_x, self.grid_y),
                                method=self.method)
        # The result follows the grid_x/grid_y layout (columns, rows);
        # transpose it into raster (rows, columns) order.
        return interpolated.T

    def array_2_raster(self, raster_array, dem_full_path,
        data_type=gdalconst.GDT_Float64, no_data_value=-99999):
        """ Create a raster file in GeoTIFF format from a numpy array.

            The geotransform comes from set_raster_geotransform (origin at
            the (min_x, max_y) corner of the LiDAR extent). When
            self.epsg_code is set, the matching CRS is written as well.

            raster_array: 2-D array, rows first (north at row 0).
            dem_full_path: path of the GeoTIFF to create.
            data_type: GDAL data type of the band (types are defined in
                gdalconst).
            no_data_value: value declared as the band's NoData; NaN cells of
                a floating point array are written with this value.

            Returns dem_full_path.
            Raises IOError if GDAL cannot create the output file.
        """
        raster_array = np.asarray(raster_array)
        if np.issubdtype(raster_array.dtype, np.floating):
            # griddata leaves NaN outside the point cloud's convex hull;
            # store those cells as NoData so GIS software masks them.
            raster_array = np.where(np.isnan(raster_array),
                                    no_data_value, raster_array)

        data_driver = gdal.GetDriverByName("GTiff")  # for QGIS3
        data_set_geotransform = self.set_raster_geotransform()

        rows = raster_array.shape[0]
        cols = raster_array.shape[-1]

        target_ds = data_driver.Create(
                dem_full_path, cols, rows, 1, data_type)
        if target_ds is None:
            raise IOError(
                "Unable to create raster file: {}".format(dem_full_path))
        try:
            target_ds.SetGeoTransform(data_set_geotransform)

            if self.epsg_code:
                data_set_out_SRS = self.set_crs(self.epsg_code)
                target_ds.SetProjection(data_set_out_SRS.ExportToWkt())

            data_set_out_band = target_ds.GetRasterBand(1)
            data_set_out_band.SetNoDataValue(no_data_value)
            data_set_out_band.WriteArray(raster_array)
            data_set_out_band.FlushCache()
        finally:
            # Release the band first, then the dataset; dropping the last
            # dataset reference closes the file even if writing failed.
            data_set_out_band = None
            target_ds = None

        return dem_full_path

    def set_raster_geotransform(self):
        """ Return the GDAL geotransform of the output raster.

            geotransform = (x_origin, pixel_width, 0, y_origin, 0,
            pixel_height). The origin is the (min_x, max_y) corner of the
            LiDAR extent, and pixel_height is negative because rows run
            from north to south. Also stores the origin in
            self.raster_origin.
        """
        self.raster_origin = self.lidar_extent[2]  # this is min_x and max_y
        return (self.raster_origin[0],
                self.pixel_size,
                0,
                self.raster_origin[-1],
                0,
                -self.pixel_size)

    def set_crs(self, epsg_code):
        """ Build the output raster CRS from a given EPSG code.

            epsg_code may be an int or a numeric string (e.g. '25830').
            Returns an osr.SpatialReference.
            Raises ValueError if GDAL does not recognise the code.
        """
        data_set_out_SRS = osr.SpatialReference()
        # ImportFromEPSG only accepts an int, and it reports failure through
        # its return code instead of raising (unless GDAL exceptions are on).
        error_code = data_set_out_SRS.ImportFromEPSG(int(epsg_code))
        if error_code != 0:
            raise ValueError("Unknown EPSG code: {}".format(epsg_code))

        return data_set_out_SRS