import os
from contextlib import redirect_stderr

import numpy as np
import pytest

from quacc.dataset import Dataset


@pytest.mark.dataset
class TestDataset:
    """Tests for ``quacc.dataset.Dataset``.

    Every test replaces ``Dataset.__init__`` through ``monkeypatch`` with a
    small stand-in that only sets the attributes the method under test
    reads. Private methods are called through their name-mangled form
    (``_Dataset__<name>``).
    """

    @pytest.mark.slow
    @pytest.mark.parametrize(
        "name,target,prevalence",
        [
            ("spambase", None, [0.5, 0.5]),
            ("imdb", None, [0.5, 0.5]),
            ("rcv1", "CCAT", [0.5, 0.5]),
            ("cifar10", "dog", [0.5, 0.5]),
            ("twitter_gasp", None, [0.33, 0.33, 0.33]),
        ],
    )
    def test__resample_all_train(self, name, target, prevalence, monkeypatch):
        """Check the prevalence of ``all_train`` after ``__resample_all_train``.

        The patched ``__init__`` fills ``all_train`` and ``test`` by calling
        ``alltrain_test(name, target)``. After resampling, the prevalence of
        ``all_train`` rounded to two decimals must equal ``prevalence``.
        """

        def mockinit(instance):
            instance._name = name
            instance._target = target
            instance.all_train, instance.test = instance.alltrain_test(
                instance._name, instance._target
            )

        monkeypatch.setattr(Dataset, "__init__", mockinit)
        # Anything written to stderr while the dataset is built, resampled
        # and measured is discarded; the assertion runs after the redirect
        # has been closed.
        with open(os.devnull, "w") as devnull, redirect_stderr(devnull):
            d = Dataset()
            d._Dataset__resample_all_train()
            observed = np.around(d.all_train.prevalence(), decimals=2).tolist()

        assert observed == prevalence

    @pytest.mark.parametrize(
        "ncl,prevs,result",
        [
            (2, None, None),
            (2, [], None),
            (2, [[0.2, 0.1], [0.3, 0.2]], None),
            (2, [[0.2, 0.8], [0.3, 0.7]], [[0.2, 0.8], [0.3, 0.7]]),
            (2, [1.0, 2.0, 3.0], None),
            (2, [1, 2, 3], None),
            (2, [[1, 2], [2, 3], [3, 4]], None),
            (2, ["abc", "def"], None),
            (3, [[0.2, 0.3], [0.4, 0.1], [0.5, 0.2]], None),
            (3, [[0.2, 0.3, 0.2], [0.4, 0.1], [0.5, 0.6]], None),
            (2, [[0.2, 0.3, 0.1], [0.1, 0.5, 0.3]], None),
            (3, [[0.2, 0.3, 0.1], [0.1, 0.5, 0.3]], None),
            (3, [[0.2, 0.8], [0.1, 0.5]], None),
            (2, [[0.2, 0.9], [0.1, 0.5]], None),
            (2, 10, None),
            (2, [[0.2, 0.8], [0.5, 0.5]], [[0.2, 0.8], [0.5, 0.5]]),
            (3, [[0.2, 0.6], [0.3, 0.5]], None),
        ],
    )
    def test__check_prevs(self, ncl, prevs, result, monkeypatch):
        """Check what ``__check_prevs`` stores in ``d.prevs`` for ``prevs``.

        ``d.prevs`` starts as ``None`` and ``all_train`` reports ``ncl``
        classes. ``result`` is ``None`` for inputs expected to leave
        ``d.prevs`` unset; otherwise it is the list form of the stored value.
        """

        class MockLabelledCollection:
            def __init__(self):
                self.n_classes = ncl

        def mockinit(instance):
            instance.all_train = MockLabelledCollection()
            instance.prevs = None

        monkeypatch.setattr(Dataset, "__init__", mockinit)
        d = Dataset()
        d._Dataset__check_prevs(prevs)
        stored = None if d.prevs is None else d.prevs.tolist()
        assert stored == result

    # fmt: off
    @pytest.mark.parametrize(
        "ncl,nprevs,built,result",
        [
            (2, 3, None, [[0.25, 0.75], [0.5, 0.5], [0.75, 0.25]]),
            (2, 3, np.array([[0.8, 0.2], [0.6, 0.4], [0.4, 0.6]]), [[0.8, 0.2], [0.6, 0.4], [0.4, 0.6]]),
            (2, 3, np.array([[0.75, 0.25], [0.5, 0.5], [0.25, 0.75]]), [[0.75, 0.25], [0.5, 0.5], [0.25, 0.75]]),
            (3, 3, None, [[0.25, 0.25, 0.5], [0.25, 0.5, 0.25], [0.5, 0.25, 0.25]]),
            (
                3, 4, None,
                [[0.2, 0.2, 0.6], [0.2, 0.4, 0.4], [0.2, 0.6, 0.2], [0.4, 0.2, 0.4], [0.4, 0.4, 0.2], [0.6, 0.2, 0.2]],
            ),
        ],
    )
    def test__build_prevs(self, ncl, nprevs, built, result, monkeypatch):
        """Check the array returned by ``__build_prevs``.

        ``built`` is preset as ``d.prevs`` and ``nprevs`` as ``d._n_prevs``,
        with ``all_train`` reporting ``ncl`` classes. In the rows where
        ``built`` is an array the expected ``result`` is that same array.
        """

        class MockLabelledCollection:
            def __init__(self):
                self.n_classes = ncl

        def mockinit(instance):
            instance.all_train = MockLabelledCollection()
            instance.prevs = built
            instance._n_prevs = nprevs

        monkeypatch.setattr(Dataset, "__init__", mockinit)
        d = Dataset()
        produced = d._Dataset__build_prevs().tolist()
        assert produced == result

    # fmt: on

    @pytest.mark.parametrize(
        "ncl,prevs,atsize",
        [
            (2, np.array([[0.2, 0.8], [0.9, 0.1]]), 55),
            (3, np.array([[0.2, 0.7, 0.1], [0.9, 0.05, 0.05]]), 37),
        ],
    )
    def test_get(self, ncl, prevs, atsize, monkeypatch):
        """Check the ``at_size`` that ``get`` passes to ``__build_sample``.

        ``__build_sample`` is replaced by a stand-in that returns its
        ``at_size`` argument, so each element yielded by ``get()`` is the
        size that was requested. ``all_train`` reports ``ncl`` classes and a
        length of 100.
        """

        class MockLabelledCollection:
            def __init__(self):
                self.n_classes = ncl

            def __len__(self):
                return 100

        def mockinit(instance):
            instance.prevs = prevs
            instance.all_train = MockLabelledCollection()

        def mock_build_sample(instance, p, at_size):
            return at_size

        monkeypatch.setattr(Dataset, "__init__", mockinit)
        monkeypatch.setattr(Dataset, "_Dataset__build_sample", mock_build_sample)
        samples = list(Dataset().get())
        # all() is True for an empty iterable, so require at least one sample;
        # otherwise the size check could pass without checking anything.
        assert samples
        assert all(size == atsize for size in samples)