# coding=utf-8
"""Test for utilities module."""
import unittest
from copy import deepcopy
from inasafe import definitions

from ..definitions import (
    hazard_flood,
    hazard_tsunami,
    hazard_earthquake,
    hazard_volcano,
    hazard_volcanic_ash,
    hazard_generic,
    exposure_population,
    exposure_land_cover,
    exposure_road,
    exposure_structure,
    hazard_category_single_event,
    hazard_category_multiple_event,
    count_exposure_unit,
    unit_metres,
    unit_feet,
    unit_generic,
    exposure_fields,
    hazard_fields,
    flood_hazard_classes,
    flood_petabencana_hazard_classes,
    generic_hazard_classes,
    aggregation_fields,
    layer_purpose_hazard,
    layer_purpose_exposure,
    layer_geometry_raster,
    layer_geometry_line,
    layer_geometry_point,
    layer_geometry_polygon,
    cyclone_au_bom_hazard_classes,
    unit_knots
)
from ..definitions.hazard import hazard_cyclone

from ..definitions.utilities import (
    definition,
    purposes_for_layer,
    hazards_for_layer,
    exposures_for_layer,
    hazard_categories_for_layer,
    hazard_units,
    exposure_units,
    get_fields,
    get_classifications,
    get_allowed_geometries,
    all_default_fields,
    get_compulsory_fields,
    get_non_compulsory_fields,
    default_classification_thresholds,
    default_classification_value_maps
)


# Module metadata: copyright notice, license, contact e-mail and the
# revision identifier recorded for this file.
__copyright__ = "Copyright 2016, The InaSAFE Project"
__license__ = "GPL version 3"
__email__ = "info@inasafe.org"
__revision__ = '4c85bcb847131a3d634744b9ea01083b158493bf'


class TestDefinitionsUtilities(unittest.TestCase):
    """Test the helper functions of the definitions utilities module.

    Each test calls one helper imported from ``definitions.utilities`` and
    compares its result with values built directly from the definition
    dictionaries imported at the top of this file.
    """

    # ``assertItemsEqual`` only exists on Python 2; Python 3 renamed it to
    # ``assertCountEqual``. Alias it when missing so the tests below run
    # unchanged on both interpreters.
    if not hasattr(unittest.TestCase, 'assertItemsEqual'):
        assertItemsEqual = unittest.TestCase.assertCountEqual

    def test_definition(self):
        """Test we can get definitions for keywords.

        .. versionadded:: 3.2

        """
        keyword = 'hazards'
        keyword_definition = definition(keyword)
        self.assertIn('description', keyword_definition)

    def test_layer_purpose_for_layer(self):
        """Test purposes_for_layer for polygon, raster and line layers."""
        expected = ['aggregation', 'exposure', 'hazard']
        self.assertListEqual(expected, purposes_for_layer('polygon'))

        expected = ['exposure', 'hazard']
        self.assertListEqual(expected, purposes_for_layer('raster'))

        expected = ['exposure']
        self.assertListEqual(expected, purposes_for_layer('line'))

    def test_hazards_for_layer(self):
        """Test hazards_for_layer for polygon and raster geometries.

        Only the hazard keys are compared, and the order is ignored.
        """
        self.maxDiff = None

        # The same set of hazards is expected for a polygon layer with and
        # without the 'single_event' hazard category.
        polygon_expected = [
            hazard_flood['key'],
            hazard_tsunami['key'],
            hazard_earthquake['key'],
            hazard_volcano['key'],
            hazard_volcanic_ash['key'],
            hazard_cyclone['key'],
            hazard_generic['key']
        ]

        hazards = hazards_for_layer('polygon', 'single_event')
        hazard_keys = [hazard['key'] for hazard in hazards]
        self.assertItemsEqual(hazard_keys, polygon_expected)

        hazards = hazards_for_layer('polygon')
        hazard_keys = [hazard['key'] for hazard in hazards]
        self.assertItemsEqual(hazard_keys, polygon_expected)

        # The raster expectation does not include the volcano hazard.
        raster_expected = [
            hazard_flood['key'],
            hazard_tsunami['key'],
            hazard_earthquake['key'],
            hazard_volcanic_ash['key'],
            hazard_cyclone['key'],
            hazard_generic['key']
        ]
        hazards = hazards_for_layer('raster', 'single_event')
        hazard_keys = [hazard['key'] for hazard in hazards]
        self.assertItemsEqual(hazard_keys, raster_expected)

    def test_exposures_for_layer(self):
        """Test exposures_for_layer for polygon and line geometries."""
        exposures = exposures_for_layer('polygon')
        expected = [
            exposure_structure,
            exposure_population,
            exposure_land_cover,
        ]
        self.assertItemsEqual(exposures, expected)

        exposures = exposures_for_layer('line')
        expected = [exposure_road]
        self.assertItemsEqual(exposures, expected)

    def test_hazard_categories_for_layer(self):
        """Test hazard_categories_for_layer returns both categories in order.
        """
        hazard_categories = hazard_categories_for_layer()
        expected = [
            hazard_category_multiple_event,
            hazard_category_single_event]
        self.assertListEqual(hazard_categories, expected)

    def test_exposure_units(self):
        """Test exposure_units for the 'population' exposure."""
        expected = [count_exposure_unit]
        self.assertItemsEqual(exposure_units('population'), expected)

    def test_hazards_units(self):
        """Test hazard_units for the 'flood' hazard."""
        expected = [unit_metres, unit_feet, unit_generic]
        self.assertItemsEqual(hazard_units('flood'), expected)

    def test_hazards_classifications(self):
        """Test get_classifications for the 'flood' hazard."""
        self.maxDiff = None
        expected = [
            flood_hazard_classes,
            flood_petabencana_hazard_classes,
            generic_hazard_classes]
        self.assertItemsEqual(
            get_classifications('flood'), expected)

    def test_get_compulsory_field(self):
        """Test get_compulsory_fields for the structure exposure."""
        compulsory_field = get_compulsory_fields('exposure', 'structure')
        expected_fields = exposure_structure['compulsory_fields']
        # The helper result is wrapped in a list to compare it with the
        # 'compulsory_fields' list of the definition.
        self.assertListEqual([compulsory_field], expected_fields)

    def test_get_not_compulsory_field(self):
        """Test get_non_compulsory_fields for the structure exposure.

        The expected value is every entry of the definition's 'fields' and
        'extra_fields' whose 'replace_null' flag is falsy, in that order.
        """
        non_compulsory_fields = get_non_compulsory_fields(
            'exposure', 'structure')
        expected_fields = [
            field for field in exposure_structure['fields']
            if not field['replace_null']]
        expected_fields += [
            field for field in exposure_structure['extra_fields']
            if not field['replace_null']]
        self.assertListEqual(non_compulsory_fields, expected_fields)

    def test_get_fields(self):
        """Test get_fields with and without a subcategory and replace_null.
        """
        # With a subcategory: compulsory fields, then fields, then extras.
        fields = get_fields('exposure', 'structure')
        # deepcopy so that ``+=`` below does not extend the definition's own
        # list in place.
        expected_fields = deepcopy(exposure_structure['compulsory_fields'])
        expected_fields += exposure_structure['fields']
        expected_fields += exposure_structure['extra_fields']
        self.assertListEqual(fields, expected_fields)

        fields = get_fields('hazard', 'flood')
        expected_fields = deepcopy(hazard_flood['compulsory_fields'])
        expected_fields += hazard_flood['fields']
        expected_fields += hazard_flood['extra_fields']
        self.assertListEqual(fields, expected_fields)

        # With only a layer purpose.
        fields = get_fields('hazard')
        expected_fields = deepcopy(hazard_fields)
        self.assertListEqual(fields, expected_fields)

        fields = get_fields('exposure')
        expected_fields = deepcopy(exposure_fields)
        self.assertListEqual(fields, expected_fields)

        fields = get_fields('aggregation')
        expected_fields = deepcopy(aggregation_fields)
        self.assertListEqual(fields, expected_fields)

        # Filtering on the 'replace_null' flag.
        fields = get_fields('aggregation', replace_null=True)
        expected_fields = [f for f in aggregation_fields if f['replace_null']]
        self.assertListEqual(fields, expected_fields)

        fields = get_fields('aggregation', replace_null=False)
        expected_fields = [
            f for f in aggregation_fields if not f['replace_null']]
        self.assertListEqual(fields, expected_fields)

    def test_get_allowed_geometries(self):
        """Test get_allowed_geometries for hazard and exposure purposes."""
        allowed_geometries = get_allowed_geometries(
            layer_purpose_hazard['key'])
        expected = [
            layer_geometry_polygon,
            layer_geometry_raster
        ]
        self.assertEqual(allowed_geometries, expected)

        allowed_geometries = get_allowed_geometries(
            layer_purpose_exposure['key'])
        expected = [
            layer_geometry_point,
            layer_geometry_line,
            layer_geometry_polygon,
            layer_geometry_raster
        ]
        self.assertEqual(allowed_geometries, expected)

    def test_all_default_fields(self):
        """Test all_default_fields method.

        Every returned field must have a truthy 'replace_null' flag and a
        'default_value' that is not None.
        """
        default_fields = all_default_fields()
        for default_field in default_fields:
            self.assertTrue(default_field.get('replace_null', False))
            self.assertIsNotNone(default_field.get('default_value'))

    def test_classification_thresholds(self):
        """Test default_classification_thresholds with and without a unit."""
        thresholds = default_classification_thresholds(flood_hazard_classes)
        flood_classes = flood_hazard_classes['classes']
        wet_class = flood_classes[0]
        dry_class = flood_classes[1]

        expected = {
            'dry': [
                dry_class['numeric_default_min'],
                dry_class['numeric_default_max']
            ],
            'wet': [
                wet_class['numeric_default_min'],
                wet_class['numeric_default_max']
            ]
        }
        self.assertDictEqual(thresholds, expected)

        unit_knots_key = unit_knots['key']
        thresholds = default_classification_thresholds(
            cyclone_au_bom_hazard_classes, unit_knots_key)
        cyclone_classes = cyclone_au_bom_hazard_classes['classes']
        category_5_class = cyclone_classes[0]
        category_4_class = cyclone_classes[1]
        category_3_class = cyclone_classes[2]
        category_2_class = cyclone_classes[3]
        category_1_class = cyclone_classes[4]
        tropical_depression_class = cyclone_classes[5]
        # Bounds are looked up by unit key, except the minimum of the
        # tropical depression class and the maximum of the category 5 class,
        # which are used as they are.
        expected = {
            'tropical_depression': [
                tropical_depression_class['numeric_default_min'],
                tropical_depression_class['numeric_default_max'][
                    unit_knots_key]
            ],
            'category_1': [
                category_1_class['numeric_default_min'][unit_knots_key],
                category_1_class['numeric_default_max'][unit_knots_key]
            ],
            'category_2': [
                category_2_class['numeric_default_min'][unit_knots_key],
                category_2_class['numeric_default_max'][unit_knots_key]
            ],
            'category_3': [
                category_3_class['numeric_default_min'][unit_knots_key],
                category_3_class['numeric_default_max'][unit_knots_key]
            ],
            'category_4': [
                category_4_class['numeric_default_min'][unit_knots_key],
                category_4_class['numeric_default_max'][unit_knots_key]
            ],
            'category_5': [
                category_5_class['numeric_default_min'][unit_knots_key],
                category_5_class['numeric_default_max']
            ]
        }
        self.assertDictEqual(thresholds, expected)

    def test_classification_value_maps(self):
        """Test default_classification_value_maps for two classifications."""
        value_maps = default_classification_value_maps(flood_hazard_classes)
        flood_classes = flood_hazard_classes['classes']
        wet_class = flood_classes[0]
        dry_class = flood_classes[1]
        expected = {
            'dry': dry_class['string_defaults'],
            'wet': wet_class['string_defaults']
        }
        self.assertDictEqual(value_maps, expected)

        value_maps = default_classification_value_maps(
            cyclone_au_bom_hazard_classes)
        cyclone_classes = cyclone_au_bom_hazard_classes['classes']
        category_5_class = cyclone_classes[0]
        category_4_class = cyclone_classes[1]
        category_3_class = cyclone_classes[2]
        category_2_class = cyclone_classes[3]
        category_1_class = cyclone_classes[4]
        tropical_depression_class = cyclone_classes[5]
        # A class without 'string_defaults' is expected to map to an empty
        # list.
        expected = {
            'category_1': category_1_class.get('string_defaults', []),
            'category_2': category_2_class.get('string_defaults', []),
            'category_3': category_3_class.get('string_defaults', []),
            'category_4': category_4_class.get('string_defaults', []),
            'category_5': category_5_class.get('string_defaults', []),
            'tropical_depression': tropical_depression_class.get(
                'string_defaults', [])
        }
        self.assertDictEqual(value_maps, expected)

    def test_unique_definition_key(self):
        """Test to make sure all definitions have different key.

        Every dict attribute of the ``definitions`` package with a truthy
        'key' entry is grouped by that key; any key shared by more than one
        attribute is reported in the failure message.
        """
        definitions_by_key = {}
        for attribute_name in dir(definitions):
            if attribute_name.startswith("__"):
                continue
            candidate = getattr(definitions, attribute_name)
            if not isinstance(candidate, dict):
                continue
            key = candidate.get('key')
            if not key:
                continue
            definitions_by_key.setdefault(key, []).append(candidate)

        duplicate_keys = [
            key for key, found in definitions_by_key.items()
            if len(found) > 1]
        message = 'There are duplicate keys: %s\n' % ', '.join(duplicate_keys)
        for duplicate_key in duplicate_keys:
            message += 'Duplicate key: %s\n' % duplicate_key
            for duplicate in definitions_by_key[duplicate_key]:
                message += (
                    duplicate['name'] + ' ' + duplicate['description'] + '\n')
        self.assertEqual(len(duplicate_keys), 0, message)


# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
