# -*- coding: utf-8 -*-
"""
/***************************************************************************
 LiDARForestryHeight
                                 A QGIS plugin. LiDAR Forestry Height 
                                 generates a DEM with the forest height, 
                                 calculated from a classified LiDAR point
                                 cloud using LasPy Library
                                 
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                             -------------------
        begin                : 2018-09-24
        copyright            : (C) 2019 by PANOimagen S.L.
        email                : info@panoimagen.com
        git sha              : $Format:%H$
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""

import numpy as np
import os

from .lfh_errors import LasPyNotFoundError

try:
    import laspy
except ModuleNotFoundError:
    raise LasPyNotFoundError

from laspy.file import File
from laspy.header import Header
from . import files_paths_funs as dir_fns

class LiDAR(object):
    """Read a classified LAS file and extract the point arrays used to
    compute the forest height.

    On construction the input file is read, its coordinates are scaled with
    the header scale/offset, and two coordinate sets are built:

    * ``terrain_arrays_list``: ``[xy, z]`` for points classified as ground
      (ASPRS class 2).
    * ``surfaces_arrays_list``: ``[xy, z]`` for first-return points
      (``return_num == 1``).

    ``xy`` is an ``(n, 2)`` array of scaled X/Y coordinates and ``z`` is the
    matching ``(n,)`` array of scaled Z values. ``lidar_results`` bundles both
    lists together with the file extent.
    """

    def __init__(self, in_las_path, out_path, partials_create):
        """Read ``in_las_path`` and compute the terrain and surface arrays.

        :param in_las_path: path of the input LiDAR file (LAS, not LAZ).
        :param out_path: passed to ``dir_fns.DirAndPaths`` to build the
            output/temporary paths.
        :param partials_create: when true, the filtered ground and
            first-return point clouds are also written as LAS files.
        :raises OSError: if the input file cannot be opened.
        """
        self.in_las_path = in_las_path
        self.out_path = out_path
        self.partials_create = partials_create

        _, self.filename = os.path.split(in_las_path)

        self.dirs = dir_fns.DirAndPaths(self.filename, out_path)

        self.read_las_file()
        # Release the input file handle even if any processing step fails.
        try:
            self.get_all_points()
            self.get_scaled_points()
            self.get_file_extent()

            self.get_terrain_points_array()
            self.get_first_returns_array()
        finally:
            self.in_file.close()

        self.lidar_results = [self.terrain_arrays_list,
                              self.surfaces_arrays_list,
                              self.las_file_extent]

    # TODO: if the input is a .laz file, decompress it first.

    def read_las_file(self):
        """Open the input LiDAR file for reading (LAS format, not LAZ) and
        cache its header scale and offset.

        :raises OSError: if laspy cannot open the file.
        """
        try:
            self.in_file = File(self.in_las_path, mode='r')
        except OSError as err:
            raise OSError(u"LiDAR Forestry Height can't open the file.\n" +
                          u"Please try again or with other LiDAR file") from err

        self.scale = self.in_file.header.scale
        self.offset = self.in_file.header.offset

    def get_all_points(self):
        """Store the raw point records and the number of points of the file."""
        self.points_array = self.in_file.get_points()
        self.points_number = len(self.in_file)

    def get_file_extent(self):
        """Compute the bounding-box corners of the scaled point cloud.

        Sets ``las_file_extent`` to the four corners, in the order
        (max x, max y), (max x, min y), (min x, max y), (min x, min y).
        Requires ``get_scaled_points`` to have been called.
        """
        # numpy reductions run in native code; the builtin max/min would
        # iterate the arrays element by element in Python.
        x_max, x_min = np.max(self.x_dimension), np.min(self.x_dimension)
        y_max, y_min = np.max(self.y_dimension), np.min(self.y_dimension)

        self.las_file_extent = [(x_max, y_max),
                                (x_max, y_min),
                                (x_min, y_max),
                                (x_min, y_min)]

    def get_scaled_points(self):
        """Convert the raw integer X/Y/Z values into real coordinates.

        Each coordinate is ``raw * scale + offset`` using the header values
        cached by ``read_las_file``.
        """
        x = self.in_file.X
        y = self.in_file.Y
        z = self.in_file.Z

        self.x_dimension = x * self.scale[0] + self.offset[0]
        self.y_dimension = y * self.scale[1] + self.offset[1]
        self.z_dimension = z * self.scale[-1] + self.offset[-1]

    def get_points_by_class(self, classif=2):
        """Select the points with the given classification id (ASPRS classes).

        :param classif: classification id to select (default 2, ground).
        :return: tuple ``(points, mask)`` with the selected point records and
            the boolean mask used to select them.
        """
        class_points_bool = self.in_file.Classification == classif
        return self.points_array[class_points_bool], class_points_bool

    def get_terrain_points_array(self):
        """Build ``terrain_arrays_list = [xy, z]`` for the ground points
        (classification flag 2).

        When ``partials_create`` is true the ground points are also written
        to a LAS file through ``write_las_file``.
        """
        self.class_flag = 2
        _, class_2_bool = self.get_points_by_class(self.class_flag)

        xy_array = np.column_stack((self.x_dimension[class_2_bool],
                                    self.y_dimension[class_2_bool]))
        z_array = self.z_dimension[class_2_bool]
        self.terrain_arrays_list = [xy_array, z_array]

        if self.partials_create:
            self.dirs.create_dir(self.dirs.out_dirs['las'])
            self.write_las_file(class_2_bool)

    def get_first_returns_array(self):
        """Build ``surfaces_arrays_list = [xy, z]`` for the first-return
        points (``return_num == 1``).

        The coordinates are taken from the already-scaled in-memory arrays.
        Writing the points out would copy the same header, so the scale and
        offset are unchanged, and reading them back would give the same
        values. The points are written to a LAS file only when
        ``partials_create`` is true.
        """
        first_return_bool = self.in_file.return_num == 1

        if self.partials_create:
            full_path = self.dirs.out_paths['las_surfaces']
            self.dirs.create_dir(self.dirs.out_dirs['las'])
            self._write_points(full_path, first_return_bool)

        xy_array = np.column_stack((self.x_dimension[first_return_bool],
                                    self.y_dimension[first_return_bool]))
        z_array = self.z_dimension[first_return_bool]
        self.surfaces_arrays_list = [xy_array, z_array]

    def write_las_file(self, class_bool=None):
        """Write the points of class ``class_flag`` to the terrain LAS output.

        :param class_bool: optional precomputed boolean mask of the points to
            write; when omitted it is computed from ``class_flag``.
        """
        self.dirs.set_output_dir()
        full_path = self.dirs.out_paths['las_terrain']
        self.dirs.create_dir(self.dirs.out_dirs['las'])

        if class_bool is None:
            _, class_bool = self.get_points_by_class(self.class_flag)
        self._write_points(full_path, class_bool)

    def _write_points(self, full_path, points_bool):
        """Write the input points selected by ``points_bool`` to a new LAS
        file at ``full_path``, reusing the input file header."""
        out_file = File(full_path, mode='w', header=self.in_file.header)
        # Close the writer even if the point assignment fails.
        try:
            out_file.points = self.in_file.points[points_bool]
        finally:
            out_file.close()