# -*- coding: utf-8 -*-
"""
/***************************************************************************
 DeepLearningTools
                                 A QGIS plugin
 QGIS plugin to aid training Deep Learning Models
 Generated by Plugin Builder: http://g-sherman.github.io/Qgis-Plugin-Builder/
                              -------------------
        begin                : 2020-03-13
        copyright            : (C) 2020 by Philipe Borba
        email                : philipeborba@gmail.com
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/
"""
import sys
import os
THIS_FOLDER = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(THIS_FOLDER, '..', '..'))
import hashlib
from qgis.testing import unittest, start_app
from qgis.core import QgsVectorLayer
import processing
# from qgis import core
from DeepLearningTools.core.image_processing.image_utils import ImageUtils

class TestImageProcessing(unittest.TestCase):
    """Tests for the label-generation helpers of ``ImageUtils``.

    ``setUpClass`` calls ``start_app()`` and initializes the Processing
    framework once for the whole class; ``tearDownClass`` undoes that and
    removes any directory registered in ``cleanup_paths``.
    """

    # Directories that tests want removed once the whole class has run.
    # Tests may append to this list; it is consumed by ``tearDownClass``.
    cleanup_paths = []

    @classmethod
    def setUpClass(cls):
        """Call ``start_app()`` and initialize the Processing framework."""
        start_app()
        from processing.core.Processing import Processing
        Processing.initialize()

    @classmethod
    def tearDownClass(cls):
        """Deinitialize Processing and delete registered directories.

        Both steps are best effort: a failure while shutting Processing
        down must not prevent the directory cleanup, and a directory that
        cannot be removed must not turn a finished test run into an error.
        """
        # Local import keeps this method self-contained.
        import shutil
        try:
            from processing.core.Processing import Processing
            Processing.deinitialize()
        except Exception:
            # Deliberately ignored: teardown is best effort (see docstring).
            pass
        for path in cls.cleanup_paths:
            shutil.rmtree(path, ignore_errors=True)

    @staticmethod
    def _remove_file(path):
        """Delete *path* if it exists; do nothing otherwise."""
        if os.path.exists(path):
            os.remove(path)

    def test_passes(self):
        """Smoke test: the module imports and the fixtures run."""
        # x = ImageUtils()
        self.assertTrue(True)

    # NOTE(review): disabled variant of test_create_image_label_png that uses
    # the '0.tif' fixtures; kept for reference, the reason it was disabled is
    # not recorded here.
    # def test_create_image_label(self):
    #     image_utils = ImageUtils()
    #     current_folder = os.path.dirname(os.path.abspath(__file__))
    #     test_data_dir = os.path.join(current_folder, 'test_data')
    #     test_dataset_dir = os.path.join(test_data_dir, 'test_dataset')

    #     input_image = os.path.join(test_dataset_dir, 'images/0.tif')
    #     input_polygon_lyr_path = os.path.join(
    #         test_data_dir,
    #         'test_polygons.geojson'
    #     )
    #     polygon_lyr = QgsVectorLayer(
    #         input_polygon_lyr_path, 'polygon_lyr', 'ogr'
    #     )
    #     expected_label = os.path.join(test_dataset_dir, 'labels/0.tif')
    #     expected_hash = hash_file(expected_label)

    #     generated_label = os.path.join(current_folder, '0_output.tif')
    #     image_utils.create_image_label(
    #         input_image,
    #         generated_label,
    #         polygon_lyr
    #     )
    #     generated_hash = hash_file(generated_label)
    #     os.remove(generated_label)
    #     self.assertEqual(
    #         expected_hash,
    #         generated_hash
    #     )

    def test_create_image_label_png(self):
        """Generate a label for ``images/10.png`` and compare it to the fixture.

        The comparison is byte-exact: the SHA-1 digest of the generated file
        must equal the digest of ``labels/10.png``.
        """
        image_utils = ImageUtils()
        current_folder = os.path.dirname(os.path.abspath(__file__))
        test_data_dir = os.path.join(current_folder, 'test_data')
        test_dataset_dir = os.path.join(test_data_dir, 'test_dataset')

        input_image = os.path.join(test_dataset_dir, 'images/10.png')
        input_polygon_lyr_path = os.path.join(
            test_data_dir,
            'test_polygons2.geojson'
        )
        polygon_lyr = QgsVectorLayer(
            input_polygon_lyr_path, 'test_polygons2', 'ogr'
        )
        # Fail with a clear message when the fixture cannot be loaded instead
        # of failing later inside the label generation.
        self.assertTrue(
            polygon_lyr.isValid(),
            'Could not load polygon layer from ' + input_polygon_lyr_path
        )
        expected_label = os.path.join(test_dataset_dir, 'labels/10.png')
        expected_hash = hash_file(expected_label)

        generated_label = os.path.join(current_folder, '10_output.png')
        # Registered before generation so the output is removed even when
        # create_image_label or the hashing below raises.
        self.addCleanup(self._remove_file, generated_label)
        image_utils.create_image_label(
            input_image,
            generated_label,
            polygon_lyr
        )
        generated_hash = hash_file(generated_label)
        self.assertEqual(
            expected_hash,
            generated_hash
        )

def hash_file(filename, chunk_size=65536):
    """Return the SHA-1 hash of the file passed into it.

    The file is read in binary mode, ``chunk_size`` bytes at a time, so
    files of any size can be hashed without loading them fully in memory.

    :param filename: path of the file to hash.
    :param chunk_size: number of bytes read per iteration; must be positive.
    :return: hexadecimal SHA-1 digest as a string.
    :raises ValueError: if ``chunk_size`` is not a positive number.
    :raises OSError: if the file cannot be opened or read.
    """
    if chunk_size < 1:
        # read(0) would return b'' immediately and yield the empty-file hash.
        raise ValueError('chunk_size must be a positive integer')
    h = hashlib.sha1()
    with open(filename, 'rb') as stream:
        # iter(callable, sentinel) stops as soon as read() returns b'' (EOF).
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

def run_all():
    """Default function that is called by the runner if nothing else is specified.

    Collects every ``test*`` method of ``TestImageProcessing`` into a suite
    and runs it with a verbose text runner that writes to ``sys.stdout``.
    """
    suite = unittest.TestSuite()
    # unittest.makeSuite() is deprecated since Python 3.11 and removed in
    # 3.13; the default loader collects the same 'test'-prefixed methods.
    suite.addTests(
        unittest.defaultTestLoader.loadTestsFromTestCase(TestImageProcessing)
    )
    unittest.TextTestRunner(verbosity=3, stream=sys.stdout).run(suite)

# Run the tests of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()