import numpy as np
import pytest
import scipy.sparse as sp

from quacc.data import ExtendedCollection


class TestExtendedCollection:
    """Tests for the prediction-based splitting helpers of ``ExtendedCollection``.

    Every fixture row has three columns. The expected splits below are
    consistent with assigning a row to group 0 when its second column is the
    larger of the last two, and to group 1 otherwise.
    NOTE(review): presumably the last two columns are classifier posteriors
    and the split is an argmax over them; confirm against ``quacc.data``.
    """

    @staticmethod
    def _same_instances(actual, expected):
        """Return True when ``actual`` and ``expected`` hold the same data.

        Sparse expectations are compared shape-first: ``!=`` between sparse
        matrices of different shapes yields a plain ``True`` (which has no
        ``nnz``), so the shape check turns a confusing ``AttributeError``
        into an ordinary assertion failure.
        """
        if sp.issparse(expected):
            return actual.shape == expected.shape and (actual != expected).nnz == 0
        return np.array_equal(actual, expected)

    @pytest.mark.parametrize(
        "instances,result",
        [
            (
                np.asarray(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                [np.asarray([1, 3]), np.asarray([0, 2])],
            ),
            (
                sp.csr_matrix(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                [np.asarray([1, 3]), np.asarray([0, 2])],
            ),
            (
                np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                [np.asarray([], dtype=int), np.asarray([0, 1])],
            ),
            (
                sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                [np.asarray([], dtype=int), np.asarray([0, 1])],
            ),
            (
                np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                [np.asarray([0, 1]), np.asarray([], dtype=int)],
            ),
            (
                sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                [np.asarray([0, 1]), np.asarray([], dtype=int)],
            ),
        ],
    )
    def test__split_index_by_pred(self, instances, result):
        """Row indexes must be grouped per class, for dense and sparse input."""
        ncl = 2
        indexes = list(ExtendedCollection._split_index_by_pred(ncl, instances))
        # zip() stops at the shorter sequence, so without this check a result
        # with missing groups would pass vacuously.
        assert len(indexes) == len(result)
        for actual, expected in zip(indexes, result):
            assert np.array_equal(actual, expected)

    @pytest.mark.parametrize(
        "instances,s_inst,norms",
        [
            (
                np.asarray(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                [
                    np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                    np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                ],
                [0.5, 0.5],
            ),
            (
                sp.csr_matrix(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                [
                    sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                    sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                ],
                [0.5, 0.5],
            ),
            (
                np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                [
                    np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                    np.asarray([], dtype=int),
                ],
                [1.0, 0.0],
            ),
            (
                sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                [
                    sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                    # Same empty-sparse convention as test_split_by_pred.
                    sp.csr_matrix(np.empty((0, 0), dtype=int)),
                ],
                [1.0, 0.0],
            ),
            (
                np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                [
                    np.asarray([], dtype=int),
                    np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                ],
                [0.0, 1.0],
            ),
            (
                sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                [
                    # Same empty-sparse convention as test_split_by_pred.
                    sp.csr_matrix(np.empty((0, 0), dtype=int)),
                    sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                ],
                [0.0, 1.0],
            ),
        ],
    )
    def test_split_inst_by_pred(self, instances, s_inst, norms):
        """Instances must be split per class and each split's norm reported.

        The comparison is driven by the type of each expected split; the
        previous check inspected the enclosing list, matched neither
        ``ndarray`` nor ``csr_matrix``, and so never compared the instances.
        """
        ncl = 2
        _s_inst, _norms = ExtendedCollection.split_inst_by_pred(ncl, instances)
        # Guard against zip() silently truncating to the shorter sequence.
        assert len(_s_inst) == len(s_inst)
        for actual, expected in zip(_s_inst, s_inst):
            assert self._same_instances(actual, expected)
        # approx() on a sequence also verifies that the lengths match.
        assert list(_norms) == pytest.approx(norms)

    @pytest.mark.parametrize(
        "instances,labels,inst0,lbl0,inst1,lbl1",
        [
            (
                np.asarray(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                np.asarray([3, 0, 1, 2]),
                np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0, 1]),
                np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([1, 0]),
            ),
            (
                sp.csr_matrix(
                    [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
                ),
                np.asarray([3, 0, 1, 2]),
                sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0, 1]),
                sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([1, 0]),
            ),
            (
                np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([3, 1]),
                np.asarray([], dtype=int),
                np.asarray([], dtype=int),
                np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([1, 0]),
            ),
            (
                sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([3, 1]),
                sp.csr_matrix(np.empty((0, 0), dtype=int)),
                np.asarray([], dtype=int),
                sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
                np.asarray([1, 0]),
            ),
            (
                np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0, 2]),
                np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0, 1]),
                np.asarray([], dtype=int),
                np.asarray([], dtype=int),
            ),
            (
                sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0, 2]),
                sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
                np.asarray([0, 1]),
                sp.csr_matrix(np.empty((0, 0), dtype=int)),
                np.asarray([], dtype=int),
            ),
        ],
    )
    def test_split_by_pred(self, instances, labels, inst0, lbl0, inst1, lbl1):
        """A 4-class collection must split into two collections with reduced labels.

        In these fixtures the expected labels of each split equal the
        original labels floor-divided by 2.
        """
        ec = ExtendedCollection(instances, labels, classes=range(0, 4))
        [ec0, ec1] = ec.split_by_pred()
        assert self._same_instances(ec0.X, inst0)
        assert self._same_instances(ec1.X, inst1)
        assert np.array_equal(ec0.y, lbl0)
        assert np.array_equal(ec1.y, lbl1)