from enmapboxprocessing.algorithm.randomsamplesfromclassificationdatasetalgorithm import RandomSamplesFromClassificationDatasetAlgorithm
from enmapboxprocessing.test.algorithm.testcase import TestCase
from enmapboxprocessing.utils import Utils
from enmapboxunittestdata import (classifierDumpPkl)

# Switch for where the test outputs go. When True, every output path is prefixed with
# the Windows drive 'c:' (e.g. 'c:/vsimem/sample.pkl'). When False, the bare
# '/vsimem/...' path is used (presumably GDAL's in-memory file system -- TODO confirm
# that the pickle writer/reader supports it).
writeToDisk = True
c = 'c:' if writeToDisk else ''


class TestSubsampleClassificationSampleAlgorithm(TestCase):
    """Tests for ``RandomSamplesFromClassificationDatasetAlgorithm``.

    Every test draws a random subsample from the ``classifierDumpPkl`` test dataset via
    ``self.runalg`` and checks how many samples (the length of the pickled ``'X'``
    entry) were written to the sample and/or complement output.
    """

    def _runSubsampling(self, n, replace=None, proportional=None, complementName='sample2.pkl'):
        """Run the sampling algorithm on ``classifierDumpPkl``.

        :param n: value passed as ``P_N`` (an int, or a stringified list of ints)
        :param replace: when not None, passed as ``P_REPLACE``
        :param proportional: when not None, passed as ``P_PROPORTIONAL``
        :param complementName: file name of the complement output below ``/vsimem/``
        :return: tuple ``(samplePath, complementPath)`` read back from the parameters
        """
        alg = RandomSamplesFromClassificationDatasetAlgorithm()
        alg.initAlgorithm()

        # Keep the same key order the individual tests used to spell out by hand.
        parameters = {alg.P_DATASET: classifierDumpPkl, alg.P_N: n}
        if proportional is not None:
            parameters[alg.P_PROPORTIONAL] = proportional
        if replace is not None:
            parameters[alg.P_REPLACE] = replace
        parameters[alg.P_OUTPUT_DATASET] = c + '/vsimem/sample.pkl'
        parameters[alg.P_OUTPUT_COMPLEMENT] = c + '/vsimem/' + complementName

        self.runalg(alg, parameters)
        return parameters[alg.P_OUTPUT_DATASET], parameters[alg.P_OUTPUT_COMPLEMENT]

    @staticmethod
    def _numberOfSamples(filename):
        """Return the number of samples (length of ``'X'``) in a pickled dataset."""
        dump = Utils.pickleLoad(filename)
        return len(dump['X'])

    def test_N(self):
        samplePath, complementPath = self._runSubsampling(10)
        self.assertEqual(10, self._numberOfSamples(samplePath))
        self.assertEqual(48, self._numberOfSamples(complementPath))

    def test_N_asList(self):
        # P_N given as a stringified list.
        samplePath, _ = self._runSubsampling(str([3]))
        self.assertEqual(3 * 5, self._numberOfSamples(samplePath))

    def test_N_withReplacemant(self):
        samplePath, _ = self._runSubsampling(100, replace=True)
        self.assertEqual(500, self._numberOfSamples(samplePath))

    def test_P(self):
        samplePath, _ = self._runSubsampling(
            10, replace=True, proportional=True, complementName='sample_complement.pkl'
        )
        self.assertEqual(193, self._numberOfSamples(samplePath))
