import unittest
import numpy as np
from scipy.sparse import csr_matrix

import quapy as qp


class LabelCollectionTestCase(unittest.TestCase):
    def test_split(self):
        x = np.arange(100)
        y = np.random.randint(0,5,100)
        data = qp.data.LabelledCollection(x,y)
        tr, te = data.split_random(0.7)
        check_prev = tr.prevalence()*0.7 + te.prevalence()*0.3

        self.assertEqual(len(tr), 70)
        self.assertEqual(len(te), 30)
        self.assertEqual(np.allclose(check_prev, data.prevalence()), True)
        self.assertEqual(len(tr+te), len(data))

    def test_join(self):
        x = np.arange(50)
        y = np.random.randint(2, 5, 50)
        data1 = qp.data.LabelledCollection(x, y)

        x = np.arange(200)
        y = np.random.randint(0, 3, 200)
        data2 = qp.data.LabelledCollection(x, y)

        x = np.arange(100)
        y = np.random.randint(0, 6, 100)
        data3 = qp.data.LabelledCollection(x, y)

        combined = qp.data.LabelledCollection.join(data1, data2, data3)
        self.assertEqual(len(combined), len(data1)+len(data2)+len(data3))
        self.assertEqual(all(combined.classes_ == np.arange(6)), True)

        x = np.random.rand(10, 3)
        y = np.random.randint(0, 1, 10)
        data4 = qp.data.LabelledCollection(x, y)
        with self.assertRaises(Exception):
            combined = qp.data.LabelledCollection.join(data1, data2, data3, data4)

        x = np.random.rand(20, 3)
        y = np.random.randint(0, 1, 20)
        data5 = qp.data.LabelledCollection(x, y)
        combined = qp.data.LabelledCollection.join(data4, data5)
        self.assertEqual(len(combined), len(data4)+len(data5))

        x = np.random.rand(10, 4)
        y = np.random.randint(0, 1, 10)
        data6 = qp.data.LabelledCollection(x, y)
        with self.assertRaises(Exception):
            combined = qp.data.LabelledCollection.join(data4, data5, data6)

        data4.instances = csr_matrix(data4.instances)
        with self.assertRaises(Exception):
            combined = qp.data.LabelledCollection.join(data4, data5)
        data5.instances = csr_matrix(data5.instances)
        combined = qp.data.LabelledCollection.join(data4, data5)
        self.assertEqual(len(combined), len(data4) + len(data5))


if __name__ == '__main__':
    unittest.main()