From db5064dbafe9e22d49beafd7d1225a0bc0923ecf Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Thu, 21 Dec 2023 16:31:06 +0100 Subject: [PATCH] tests added, passing --- tests/test_data.py | 633 +++++++++++++----- tests/test_dataset.py | 134 +++- tests/test_error.py | 95 +++ tests/test_evaluation/test_report.py | 425 ++++++++++++ .../test_base.cpython-311-pytest-7.4.2.pyc | Bin 5919 -> 0 bytes ...del_selection.cpython-311-pytest-7.4.2.pyc | Bin 554 -> 0 bytes tests/test_method/test_base.py | 100 +++ .../test_BQAE.cpython-311-pytest-7.4.2.pyc | Bin 5871 -> 0 bytes .../test_MCAE.cpython-311-pytest-7.4.2.pyc | Bin 543 -> 0 bytes tests/test_method/test_base/test_BQAE.py | 66 -- tests/test_method/test_base/test_MCAE.py | 2 - 11 files changed, 1207 insertions(+), 248 deletions(-) create mode 100644 tests/test_error.py create mode 100644 tests/test_evaluation/test_report.py delete mode 100644 tests/test_method/__pycache__/test_base.cpython-311-pytest-7.4.2.pyc delete mode 100644 tests/test_method/__pycache__/test_model_selection.cpython-311-pytest-7.4.2.pyc create mode 100644 tests/test_method/test_base.py delete mode 100644 tests/test_method/test_base/__pycache__/test_BQAE.cpython-311-pytest-7.4.2.pyc delete mode 100644 tests/test_method/test_base/__pycache__/test_MCAE.cpython-311-pytest-7.4.2.pyc delete mode 100644 tests/test_method/test_base/test_BQAE.py delete mode 100644 tests/test_method/test_base/test_MCAE.py diff --git a/tests/test_data.py b/tests/test_data.py index 69124a5..56ecc0c 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,130 +1,33 @@ +from unittest import mock + import numpy as np import pytest import scipy.sparse as sp from quacc.data import ( + ExtBinPrev, ExtendedCollection, ExtendedData, ExtendedLabels, ExtendedPrev, ExtensionPolicy, + ExtMulPrev, + _split_index_by_pred, ) -@pytest.mark.ext -@pytest.mark.extpol -class TestExtendedPolicy: - @pytest.mark.parametrize( - "extpol,nbcl,result", - [ - (ExtensionPolicy(), 2, np.array([0, 1, 2, 3])), - (ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2])), - (ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])), - (ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2, 3])), - ], - ) - def test_qclasses(self, extpol, nbcl, result): - assert (result == extpol.qclasses(nbcl)).all() +@pytest.fixture +def nd_1(): + return np.arange(12).reshape((4, 3)) - @pytest.mark.parametrize( - "extpol,nbcl,result", - [ - (ExtensionPolicy(), 2, np.array([0, 1, 2, 3])), - (ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2, 3])), - (ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])), - ( - ExtensionPolicy(collapse_false=True), - 3, - np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]), - ), - ], - ) - def test_eclasses(self, extpol, nbcl, result): - assert (result == extpol.eclasses(nbcl)).all() - @pytest.mark.parametrize( - "extpol,nbcl,result", - [ - ( - ExtensionPolicy(), - 2, - ( - np.array([0, 0, 1, 1]), - np.array([0, 1, 0, 1]), - ), - ), - ( - ExtensionPolicy(collapse_false=True), - 2, - ( - np.array([0, 1, 0]), - np.array([0, 1, 1]), - ), - ), - ( - ExtensionPolicy(), - 3, - ( - np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]), - np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]), - ), - ), - ( - ExtensionPolicy(collapse_false=True), - 3, - ( - np.array([0, 1, 2, 0]), - np.array([0, 1, 2, 1]), - ), - ), - ], - ) - def test_matrix_idx(self, extpol, nbcl, result): - _midx = extpol.matrix_idx(nbcl) - assert len(_midx) == len(result) - assert all((idx == r).all() for idx, r in zip(_midx, result)) - - @pytest.mark.parametrize( - "extpol,nbcl,true,pred,result", - [ - ( - ExtensionPolicy(), - 2, - np.array([1, 0, 1, 1, 0, 0]), - np.array([1, 0, 0, 1, 1, 0]), - np.array([3, 0, 2, 3, 1, 0]), - ), - ( - ExtensionPolicy(collapse_false=True), - 2, - np.array([1, 0, 1, 1, 0, 0]), - np.array([1, 0, 0, 1, 1, 0]), - np.array([1, 0, 2, 1, 2, 0]), - ), - ( - ExtensionPolicy(), - 3, - np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), - np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), - np.array([4, 6, 0, 3, 1, 7, 2, 5, 8]), - ), - ( - ExtensionPolicy(collapse_false=True), - 3, - np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), - np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), - np.array([1, 3, 0, 3, 3, 3, 3, 3, 2]), - ), - ], - ) - def test_ext_lbl(self, extpol, nbcl, true, pred, result): - vfun = extpol.ext_lbl(nbcl) - assert (vfun(true, pred) == result).all() +@pytest.fixture +def csr_1(nd_1): + return sp.csr_matrix(nd_1) @pytest.mark.ext -@pytest.mark.extd -class TestExtendedData: +class TestData: @pytest.mark.parametrize( "pred_proba,result", [ @@ -153,17 +56,219 @@ class TestExtendedData: ), ], ) - def test__split_index_by_pred(self, monkeypatch, pred_proba, result): - def mockinit(self, pred_proba): - self.pred_proba_ = pred_proba - - monkeypatch.setattr(ExtendedData, "__init__", mockinit) - ed = ExtendedData(pred_proba) - _split_index = ed._ExtendedData__split_index_by_pred() + def test_split_index_by_pred(self, pred_proba, result): + _split_index = _split_index_by_pred(pred_proba) assert len(_split_index) == len(result) assert all((a == b).all() for (a, b) in zip(_split_index, result)) +@pytest.mark.ext +@pytest.mark.extpol +class TestExtendedPolicy: + # fmt: off + @pytest.mark.parametrize( + "extpol,nbcl,result", + [ + (ExtensionPolicy(), 2, np.array([0, 1, 2, 3])), + (ExtensionPolicy(group_false=True), 2, np.array([0, 1, 2, 3])), + (ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2])), + (ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])), + (ExtensionPolicy(group_false=True), 3, np.array([0, 1, 2, 3, 4, 5])), + (ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2, 3])), + ], + ) + def test_qclasses(self, extpol, nbcl, result): + assert (result == extpol.qclasses(nbcl)).all() + + @pytest.mark.parametrize( + "extpol,nbcl,result", + [ + (ExtensionPolicy(), 2, np.array([0, 1, 2, 3])), + (ExtensionPolicy(group_false=True), 2, np.array([0, 1, 2, 3])), + (ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2, 3])), + (ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])), + (ExtensionPolicy(group_false=True), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])), + (ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])), + ], + ) + def test_eclasses(self, extpol, nbcl, result): + assert (result == extpol.eclasses(nbcl)).all() + + @pytest.mark.parametrize( + "extpol,nbcl,result", + [ + (ExtensionPolicy(), 2, np.array([0, 1])), + (ExtensionPolicy(group_false=True), 2, np.array([0, 1])), + (ExtensionPolicy(collapse_false=True), 2, np.array([0, 1])), + (ExtensionPolicy(), 3, np.array([0, 1, 2])), + (ExtensionPolicy(group_false=True), 3, np.array([0, 1])), + (ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2])), + ], + ) + def test_tfp_classes(self, extpol, nbcl, result): + assert (result == extpol.tfp_classes(nbcl)).all() + + @pytest.mark.parametrize( + "extpol,nbcl,result", + [ + ( + ExtensionPolicy(), 2, + (np.array([0, 0, 1, 1]), np.array([0, 1, 0, 1])), + ), + ( + ExtensionPolicy(group_false=True), 2, + (np.array([0, 1, 1, 0]), np.array([0, 1, 0, 1])), + ), + ( + ExtensionPolicy(collapse_false=True), 2, + (np.array([0, 1, 0]), np.array([0, 1, 1])), + ), + ( + ExtensionPolicy(), 3, + (np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]), np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])), + ), + ( + ExtensionPolicy(group_false=True), 3, + (np.array([0, 1, 2, 1, 2, 0]), np.array([0, 1, 2, 0, 1, 2])), + ), + ( + ExtensionPolicy(collapse_false=True), 3, + (np.array([0, 1, 2, 0]), np.array([0, 1, 2, 1])), + ), + ], + ) + def test_matrix_idx(self, extpol, nbcl, result): + _midx = extpol.matrix_idx(nbcl) + assert len(_midx) == len(result) + assert all((idx == r).all() for idx, r in zip(_midx, result)) + + @pytest.mark.parametrize( + "extpol,nbcl,true,pred,result", + [ + ( + ExtensionPolicy(), 2, + np.array([1, 0, 1, 1, 0, 0]), + np.array([1, 0, 0, 1, 1, 0]), + np.array([3, 0, 2, 3, 1, 0]), + ), + ( + ExtensionPolicy(group_false=True), 2, + np.array([1, 0, 1, 1, 0, 0]), + np.array([1, 0, 0, 1, 1, 0]), + np.array([1, 0, 2, 1, 3, 0]), + ), + ( + ExtensionPolicy(collapse_false=True), 2, + np.array([1, 0, 1, 1, 0, 0]), + np.array([1, 0, 0, 1, 1, 0]), + np.array([1, 0, 2, 1, 2, 0]), + ), + ( + ExtensionPolicy(), 3, + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), + np.array([4, 6, 0, 3, 1, 7, 2, 5, 8]), + ), + ( + ExtensionPolicy(group_false=True), 3, + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), + np.array([1, 3, 0, 3, 4, 4, 5, 5, 2]), + ), + ( + ExtensionPolicy(collapse_false=True), 3, + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), + np.array([1, 3, 0, 3, 3, 3, 3, 3, 2]), + ), + ], + ) + def test_ext_lbl(self, extpol, nbcl, true, pred, result): + vfun = extpol.ext_lbl(nbcl) + assert (vfun(true, pred) == result).all() + + @pytest.mark.parametrize( + "extpol,nbcl,true,pred,result", + [ + ( + ExtensionPolicy(), 2, + np.array([1, 0, 1, 1, 0, 0]), + np.array([1, 0, 0, 1, 1, 0]), + np.array([1, 0, 1, 1, 0, 0]), + ), + ( + ExtensionPolicy(group_false=True), 2, + np.array([1, 0, 1, 1, 0, 0]), + np.array([1, 0, 0, 1, 1, 0]), + np.array([0, 0, 1, 0, 1, 0]), + ), + ( + ExtensionPolicy(collapse_false=True), 2, + np.array([1, 0, 1, 1, 0, 0]), + np.array([1, 0, 0, 1, 1, 0]), + np.array([1, 0, 1, 1, 0, 0]), + ), + ( + ExtensionPolicy(), 3, + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + ), + ( + ExtensionPolicy(group_false=True), 3, + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), + np.array([0, 1, 0, 1, 1, 1, 1, 1, 0]), + ), + ( + ExtensionPolicy(collapse_false=True), 3, + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + ), + ], + ) + def test_true_lbl_from_pred(self, extpol, nbcl, true, pred, result): + vfun = extpol.true_lbl_from_pred(nbcl) + assert (vfun(true, pred) == result).all() + # fmt: on + + +@pytest.mark.ext +@pytest.mark.extd +class TestExtendedData: + @pytest.mark.parametrize( + "instances_name,indexes,result", + [ + ( + "nd_1", + [np.array([0, 2]), np.array([1, 3])], + [ + np.array([[0, 1, 2], [6, 7, 8]]), + np.array([[3, 4, 5], [9, 10, 11]]), + ], + ), + ( + "nd_1", + [np.array([0]), np.array([1, 3]), np.array([2])], + [ + np.array([[0, 1, 2]]), + np.array([[3, 4, 5], [9, 10, 11]]), + np.array([[6, 7, 8]]), + ], + ), + ], + ) + def test_split_by_pred(self, instances_name, indexes, result, monkeypatch, request): + def mockinit(self): + self.instances = request.getfixturevalue(instances_name) + + monkeypatch.setattr(ExtendedData, "__init__", mockinit) + d = ExtendedData() + split = d.split_by_pred(indexes) + assert all([(s == r).all() for s, r in zip(split, result)]) + + @pytest.mark.ext @pytest.mark.extl class TestExtendedLabels: @@ -177,6 +282,13 @@ class TestExtendedLabels: ExtensionPolicy(), np.array([3, 1, 0, 2, 3]), ), + ( + np.array([1, 0, 0, 1, 1]), + np.array([1, 1, 0, 0, 1]), + 2, + ExtensionPolicy(group_false=True), + np.array([1, 3, 0, 2, 1]), + ), ( np.array([1, 0, 0, 1, 1]), np.array([1, 1, 0, 0, 1]), @@ -184,92 +296,128 @@ class TestExtendedLabels: ExtensionPolicy(collapse_false=True), np.array([1, 2, 0, 2, 1]), ), + ( + np.array([1, 0, 0, 1, 0, 1, 2, 2, 2]), + np.array([1, 1, 0, 0, 2, 2, 2, 0, 1]), + 3, + ExtensionPolicy(), + np.array([4, 1, 0, 3, 2, 5, 8, 6, 7]), + ), + ( + np.array([1, 0, 0, 1, 0, 1, 2, 2, 2]), + np.array([1, 1, 0, 0, 2, 2, 2, 0, 1]), + 3, + ExtensionPolicy(group_false=True), + np.array([1, 4, 0, 3, 5, 5, 2, 3, 4]), + ), + ( + np.array([1, 0, 0, 1, 0, 1, 2, 2, 2]), + np.array([1, 1, 0, 0, 2, 2, 2, 0, 1]), + 3, + ExtensionPolicy(collapse_false=True), + np.array([1, 3, 0, 3, 3, 3, 2, 3, 3]), + ), ], ) def test_y(self, true, pred, nbcl, extpol, result): el = ExtendedLabels(true, pred, nbcl, extpol) assert (el.y == result).all() + @pytest.mark.parametrize( + "extpol,nbcl,indexes,true,pred,result,rcls", + [ + ( + ExtensionPolicy(), + 2, + [np.array([1, 2, 5]), np.array([0, 3, 4])], + np.array([1, 0, 1, 1, 0, 0]), + np.array([1, 0, 0, 1, 1, 0]), + [np.array([0, 1, 0]), np.array([1, 1, 0])], + np.array([0, 1]), + ), + ( + ExtensionPolicy(group_false=True), + 2, + [np.array([1, 2, 5]), np.array([0, 3, 4])], + np.array([1, 0, 1, 1, 0, 0]), + np.array([1, 0, 0, 1, 1, 0]), + [np.array([0, 1, 0]), np.array([0, 0, 1])], + np.array([0, 1]), + ), + ( + ExtensionPolicy(collapse_false=True), + 2, + [np.array([1, 2, 5]), np.array([0, 3, 4])], + np.array([1, 0, 1, 1, 0, 0]), + np.array([1, 0, 0, 1, 1, 0]), + [np.array([0, 1, 0]), np.array([1, 1, 0])], + np.array([0, 1]), + ), + ( + ExtensionPolicy(), + 3, + [np.array([1, 2, 3]), np.array([0, 4, 5]), np.array([6, 7, 8])], + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), + [np.array([2, 0, 1]), np.array([1, 0, 2]), np.array([0, 1, 2])], + np.array([0, 1, 2]), + ), + ( + ExtensionPolicy(group_false=True), + 3, + [np.array([1, 2, 3]), np.array([0, 4, 5]), np.array([6, 7, 8])], + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), + [np.array([1, 0, 1]), np.array([0, 1, 1]), np.array([1, 1, 0])], + np.array([0, 1]), + ), + ( + ExtensionPolicy(collapse_false=True), + 3, + [np.array([1, 2, 3]), np.array([0, 4, 5]), np.array([6, 7, 8])], + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), + [np.array([2, 0, 1]), np.array([1, 0, 2]), np.array([0, 1, 2])], + np.array([0, 1, 2]), + ), + ], + ) + def test_split_by_pred(self, extpol, nbcl, indexes, true, pred, result, rcls): + el = ExtendedLabels(true, pred, nbcl, extpol) + labels, cls = el.split_by_pred(indexes) + assert (cls == rcls).all() + assert all([(lbl == r).all() for lbl, r in zip(labels, result)]) + @pytest.mark.ext @pytest.mark.extp class TestExtendedPrev: - @pytest.mark.parametrize( - "flat,nbcl,extpol,q_classes,result", - [ - ( - np.array([0.2, 0, 0.8, 0]), - 2, - ExtensionPolicy(), - [0, 1, 2, 3], - np.array([0.2, 0, 0.8, 0]), - ), - ( - np.array([0.2, 0.8]), - 2, - ExtensionPolicy(), - [0, 3], - np.array([0.2, 0, 0, 0.8]), - ), - ( - np.array([0.2, 0.8]), - 2, - ExtensionPolicy(collapse_false=True), - [0, 2], - np.array([0.2, 0, 0.8]), - ), - ( - np.array([0.1, 0.1, 0.6, 0.2]), - 3, - ExtensionPolicy(), - [0, 1, 3, 5], - np.array([0.1, 0.1, 0, 0.6, 0, 0.2, 0, 0, 0]), - ), - ( - np.array([0.1, 0.1, 0.6]), - 3, - ExtensionPolicy(collapse_false=True), - [0, 1, 2], - np.array([0.1, 0.1, 0.6, 0]), - ), - ], - ) - def test__check_q_classes(self, monkeypatch, flat, nbcl, extpol, q_classes, result): - def mockinit(self, flat, nbcl, extpol): - self.flat = flat - self.nbcl = nbcl - self.extpol = extpol - - monkeypatch.setattr(ExtendedPrev, "__init__", mockinit) - ep = ExtendedPrev(flat, nbcl, extpol) - ep._ExtendedPrev__check_q_classes(q_classes) - assert (ep.flat == result).all() - + # fmt: off @pytest.mark.parametrize( "flat,nbcl,extpol,result", [ ( - np.array([0.05, 0.1, 0.6, 0.25]), - 2, - ExtensionPolicy(), + np.array([0.05, 0.1, 0.6, 0.25]), 2, ExtensionPolicy(), np.array([[0.05, 0.1], [0.6, 0.25]]), ), ( - np.array([0.05, 0.1, 0.85]), - 2, - ExtensionPolicy(collapse_false=True), + np.array([0.05, 0.1, 0.6, 0.25]), 2, ExtensionPolicy(group_false=True), + np.array([[0.05, 0.25], [0.6, 0.1]]), + ), + ( + np.array([0.05, 0.1, 0.85]), 2, ExtensionPolicy(collapse_false=True), np.array([[0.05, 0.85], [0, 0.1]]), ), ( - np.array([0.05, 0.1, 0.2, 0.15, 0.04, 0.06, 0.15, 0.14, 0.1]), - 3, - ExtensionPolicy(), + np.array([0.05, 0.1, 0.2, 0.15, 0.04, 0.06, 0.15, 0.14, 0.1]), 3, ExtensionPolicy(), np.array([[0.05, 0.1, 0.2], [0.15, 0.04, 0.06], [0.15, 0.14, 0.1]]), ), ( - np.array([0.05, 0.2, 0.65, 0.1]), - 3, - ExtensionPolicy(collapse_false=True), + np.array([0.15, 0.2, 0.15, 0.1, 0.15, 0.25]), 3, ExtensionPolicy(group_false=True), + np.array([[0.15, 0.0, 0.25], [0.1, 0.2, 0.0], [0.0, 0.15, 0.15]]), + ), + ( + np.array([0.05, 0.2, 0.65, 0.1]), 3, ExtensionPolicy(collapse_false=True), np.array([[0.05, 0.1, 0], [0, 0.2, 0], [0, 0, 0.65]]), ), ], @@ -285,3 +433,130 @@ class TestExtendedPrev: _matrix = ep._ExtendedPrev__build_matrix() assert _matrix.shape == result.shape assert (_matrix == result).all() + + # fmt: on + + +@pytest.mark.ext +@pytest.mark.extp +class TestExtMulPrev: + # fmt: off + @pytest.mark.parametrize( + "flat,nbcl,extpol,q_classes,result", + [ + (np.array([0.2, 0, 0.8, 0]), 2, ExtensionPolicy(), [0, 1, 2, 3], np.array([0.2, 0, 0.8, 0])), + (np.array([0.2, 0.8]), 2, ExtensionPolicy(), [0, 3], np.array([0.2, 0, 0, 0.8])), + (np.array([0.2, 0.8]), 2, ExtensionPolicy(group_false=True), [0, 3], np.array([0.2, 0, 0, 0.8])), + (np.array([0.2, 0.8]), 2, ExtensionPolicy(collapse_false=True), [0, 2], np.array([0.2, 0, 0.8])), + (np.array([0.1, 0.1, 0.6, 0.2]), 3, ExtensionPolicy(), [0, 1, 3, 5], np.array([0.1, 0.1, 0, 0.6, 0, 0.2, 0, 0, 0])), + (np.array([0.1, 0.1, 0.6, 0.2]), 3, ExtensionPolicy(group_false=True), [0, 1, 3, 5], np.array([0.1, 0.1, 0, 0.6, 0, 0.2])), + (np.array([0.1, 0.1, 0.6]), 3, ExtensionPolicy(collapse_false=True), [0, 1, 2], np.array([0.1, 0.1, 0.6, 0])), + ], + ) + def test__check_q_classes(self, monkeypatch, flat, nbcl, extpol, q_classes, result): + def mockinit(self, nbcl, extpol): + self.nbcl = nbcl + self.extpol = extpol + + monkeypatch.setattr(ExtMulPrev, "__init__", mockinit) + ep = ExtMulPrev(nbcl, extpol) + _flat = ep._ExtMulPrev__check_q_classes(q_classes, flat) + assert (_flat == result).all() + + # fmt: on + + +@pytest.mark.ext +@pytest.mark.extp +class TestExtBinPrev: + # fmt: off + @pytest.mark.parametrize( + "flat,nbcl,extpol,q_classes,result", + [ + ([np.array([0.2, 0]), np.array([0.8, 0])], 2, ExtensionPolicy(), [[0, 1], [0, 1]], np.array([[0.2, 0], [0.8, 0]])), + ([np.array([0.2]), np.array([0.8])], 2, ExtensionPolicy(), [[0], [1]], np.array([[0.2, 0], [0, 0.8]])), + ([np.array([0.2]), np.array([0.8])], 2, ExtensionPolicy(group_false=True), [[0], [1]], np.array([[0.2, 0], [0, 0.8]])), + ([np.array([0.2]), np.array([0.8])], 2, ExtensionPolicy(collapse_false=True), [[0], [1]], np.array([[0.2, 0], [0, 0.8]])), + ([np.array([0.1, 0.1]), np.array([0.6]), np.array([0.2])], 3, ExtensionPolicy(), [[0, 1], [0], [2]], np.array([[0.1, 0.1, 0], [0.6, 0, 0], [0, 0, 0.2]])), + ([np.array([0.1, 0.1]), np.array([0.6]), np.array([0.2])], 3, ExtensionPolicy(group_false=True), [[0, 1], [0], [1]], np.array([[0.1, 0.1], [0.6, 0], [0, 0.2]])), + ([np.array([0.1, 0.1]), np.array([0.6]), np.array([0.2])], 3, ExtensionPolicy(collapse_false=True), [[0, 1], [0], [2]], np.array([[0.1, 0.1, 0], [0.6, 0, 0], [0, 0, 0.2]])), + ], + ) + def test__check_q_classes(self, monkeypatch, flat, nbcl, extpol, q_classes, result): + def mockinit(self, nbcl, extpol): + self.nbcl = nbcl + self.extpol = extpol + + monkeypatch.setattr(ExtBinPrev, "__init__", mockinit) + ep = ExtBinPrev(nbcl, extpol) + _flat = ep._ExtBinPrev__check_q_classes(q_classes, flat) + assert (_flat == result).all() + + @pytest.mark.parametrize( + "flat,result", + [ + (np.array([[0.2, 0], [0.8, 0]]), np.array([0.2, 0.8, 0, 0])), + (np.array([[0.2, 0], [0, 0.8]]), np.array([0.2, 0, 0, 0.8])), + (np.array([[0.1, 0.1, 0], [0.6, 0, 0], [0, 0, 0.2]]), np.array([0.1, 0.6, 0, 0.1, 0, 0, 0, 0, 0.2])), + (np.array([[0.1, 0.1], [0.6, 0], [0, 0.2]]), np.array([0.1, 0.6, 0, 0.1, 0, 0.2])), + ], + ) + def test__build_flat(self, monkeypatch, flat, result): + def mockinit(self): + pass + + monkeypatch.setattr(ExtBinPrev, "__init__", mockinit) + ep = ExtBinPrev() + _flat = ep._ExtBinPrev__build_flat(flat) + assert (_flat == result).all() + # fmt: on + + +@pytest.mark.ext +@pytest.mark.extc +class TestExtendedCollection: + @pytest.mark.parametrize( + "instances_name,labels,pred_proba,extpol,result", + [ + ( + "nd_1", + np.array([0, 1, 1, 0]), + np.array([[0.2, 0.8], [0.3, 0.7], [0.9, 0.1], [0.45, 0.55]]), + ExtensionPolicy(), + np.array([0, 0.5, 0.25, 0.25]), + ), + ( + "nd_1", + np.array([0, 1, 1, 0]), + np.array([[0.2, 0.8], [0.3, 0.7], [0.9, 0.1], [0.45, 0.55]]), + ExtensionPolicy(collapse_false=True), + np.array([0, 0.25, 0.75]), + ), + ( + "csr_1", + np.array([0, 1, 1, 0]), + np.array([[0.2, 0.8], [0.3, 0.7], [0.9, 0.1], [0.45, 0.55]]), + ExtensionPolicy(), + np.array([0, 0.5, 0.25, 0.25]), + ), + ( + "csr_1", + np.array([0, 1, 1, 0]), + np.array([[0.2, 0.8], [0.3, 0.7], [0.9, 0.1], [0.45, 0.55]]), + ExtensionPolicy(collapse_false=True), + np.array([0, 0.25, 0.75]), + ), + ], + ) + def test_prevalence( + self, instances_name, labels, pred_proba, extpol, result, request + ): + instances = request.getfixturevalue(instances_name) + ec = ExtendedCollection( + instances=instances, + labels=labels, + pred_proba=pred_proba, + ext=pred_proba, + extpol=extpol, + ) + assert (ec.prevalence() == result).all() diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 9b2a72a..bd0c837 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,3 +1,135 @@ +import os +from contextlib import redirect_stderr +import numpy as np +import pytest + +from quacc.dataset import Dataset + + +@pytest.mark.dataset class TestDataset: - pass \ No newline at end of file + @pytest.mark.slow + @pytest.mark.parametrize( + "name,target,prevalence", + [ + ("spambase", None, [0.5, 0.5]), + ("imdb", None, [0.5, 0.5]), + ("rcv1", "CCAT", [0.5, 0.5]), + ("cifar10", "dog", [0.5, 0.5]), + ("twitter_gasp", None, [0.33, 0.33, 0.33]), + ], + ) + def test__resample_all_train(self, name, target, prevalence, monkeypatch): + def mockinit(self): + self._name = name + self._target = target + self.all_train, self.test = self.alltrain_test(self._name, self._target) + + monkeypatch.setattr(Dataset, "__init__", mockinit) + with open(os.devnull, "w") as dn: + with redirect_stderr(dn): + d = Dataset() + d._Dataset__resample_all_train() + assert ( + np.around(d.all_train.prevalence(), decimals=2).tolist() + == prevalence + ) + + @pytest.mark.parametrize( + "ncl, prevs,result", + [ + (2, None, None), + (2, [], None), + (2, [[0.2, 0.1], [0.3, 0.2]], None), + (2, [[0.2, 0.8], [0.3, 0.7]], [[0.2, 0.8], [0.3, 0.7]]), + (2, [1.0, 2.0, 3.0], None), + (2, [1, 2, 3], None), + (2, [[1, 2], [2, 3], [3, 4]], None), + (2, ["abc", "def"], None), + (3, [[0.2, 0.3], [0.4, 0.1], [0.5, 0.2]], None), + (3, [[0.2, 0.3, 0.2], [0.4, 0.1], [0.5, 0.6]], None), + (2, [[0.2, 0.3, 0.1], [0.1, 0.5, 0.3]], None), + (3, [[0.2, 0.3, 0.1], [0.1, 0.5, 0.3]], None), + (3, [[0.2, 0.8], [0.1, 0.5]], None), + (2, [[0.2, 0.9], [0.1, 0.5]], None), + (2, 10, None), + (2, [[0.2, 0.8], [0.5, 0.5]], [[0.2, 0.8], [0.5, 0.5]]), + (3, [[0.2, 0.6], [0.3, 0.5]], None), + ], + ) + def test__check_prevs(self, ncl, prevs, result, monkeypatch): + class MockLabelledCollection: + def __init__(self): + self.n_classes = ncl + + def mockinit(self): + self.all_train = MockLabelledCollection() + self.prevs = None + + monkeypatch.setattr(Dataset, "__init__", mockinit) + d = Dataset() + d._Dataset__check_prevs(prevs) + _prevs = d.prevs if d.prevs is None else d.prevs.tolist() + assert _prevs == result + + # fmt: off + + + @pytest.mark.parametrize( + "ncl,nprevs,built,result", + [ + (2, 3, None, [[0.25, 0.75], [0.5, 0.5], [0.75, 0.25]]), + (2, 3, np.array([[0.8, 0.2], [0.6, 0.4], [0.4, 0.6]]), [[0.8, 0.2], [0.6, 0.4], [0.4, 0.6]]), + (2, 3, np.array([[0.75, 0.25], [0.5, 0.5], [0.25, 0.75]]), [[0.75, 0.25], [0.5, 0.5], [0.25, 0.75]]), + (3, 3, None, [[0.25, 0.25, 0.5], [0.25, 0.5, 0.25], [0.5, 0.25, 0.25]]), + ( + 3, 4, None, + [[0.2, 0.2, 0.6], [0.2, 0.4, 0.4], [0.2, 0.6, 0.2], [0.4, 0.2, 0.4], [0.4, 0.4, 0.2], [0.6, 0.2, 0.2]], + ), + ], + ) + def test__build_prevs(self, ncl, nprevs, built, result, monkeypatch): + class MockLabelledCollection: + def __init__(self): + self.n_classes = ncl + + def mockinit(self): + self.all_train = MockLabelledCollection() + self.prevs = built + self._n_prevs = nprevs + + monkeypatch.setattr(Dataset, "__init__", mockinit) + d = Dataset() + _prevs = d._Dataset__build_prevs().tolist() + assert _prevs == result + + # fmt: on + + @pytest.mark.parametrize( + "ncl,prevs,atsize", + [ + (2, np.array([[0.2, 0.8], [0.9, 0.1]]), 55), + (3, np.array([[0.2, 0.7, 0.1], [0.9, 0.05, 0.05]]), 37), + ], + ) + def test_get(self, ncl, prevs, atsize, monkeypatch): + class MockLabelledCollection: + def __init__(self): + self.n_classes = ncl + + def __len__(self): + return 100 + + def mockinit(self): + self.prevs = prevs + self.all_train = MockLabelledCollection() + + def mock_build_sample(self, p, at_size): + return at_size + + monkeypatch.setattr(Dataset, "__init__", mockinit) + monkeypatch.setattr(Dataset, "_Dataset__build_sample", mock_build_sample) + d = Dataset() + _get = d.get() + assert all(s == atsize for s in _get) diff --git a/tests/test_error.py b/tests/test_error.py new file mode 100644 index 0000000..8d9cb58 --- /dev/null +++ b/tests/test_error.py @@ -0,0 +1,95 @@ +import numpy as np +import pytest + +from quacc import error +from quacc.data import ExtendedPrev, ExtensionPolicy + + +@pytest.mark.err +class TestError: + @pytest.mark.parametrize( + "prev,result", + [ + (np.array([[1, 4], [4, 4]]), 0.5), + (np.array([[6, 2, 4], [2, 4, 2], [4, 2, 6]]), 0.5), + ], + ) + def test_f1(self, prev, result): + ep = ExtendedPrev(prev.flatten(), prev.shape[0], extpol=ExtensionPolicy()) + assert error.f1(prev) == result + assert error.f1(ep) == result + + @pytest.mark.parametrize( + "prev,result", + [ + (np.array([[4, 4], [4, 4]]), 0.5), + (np.array([[2, 4, 2], [2, 2, 4], [4, 2, 2]]), 0.25), + ], + ) + def test_acc(self, prev, result): + ep = ExtendedPrev(prev.flatten(), prev.shape[0], extpol=ExtensionPolicy()) + assert error.acc(prev) == result + assert error.acc(ep) == result + + @pytest.mark.parametrize( + "true_prev,estim_prev,nbcl,extpol,result", + [ + ( + [ + np.array([0.2, 0.4, 0.1, 0.3]), + np.array([0.1, 0.5, 0.1, 0.3]), + ], + [ + np.array([0.3, 0.4, 0.2, 0.1]), + np.array([0.5, 0.3, 0.1, 0.1]), + ], + 2, + ExtensionPolicy(), + np.array([0.1, 0.2]), + ), + ( + [ + np.array([0.2, 0.4, 0.4]), + np.array([0.1, 0.5, 0.4]), + ], + [ + np.array([0.3, 0.4, 0.3]), + np.array([0.5, 0.3, 0.2]), + ], + 2, + ExtensionPolicy(collapse_false=True), + np.array([0.1, 0.2]), + ), + ( + [ + np.array([0.02, 0.04, 0.16, 0.38, 0.1, 0.05, 0.15, 0.08, 0.02]), + np.array([0.04, 0.02, 0.14, 0.40, 0.1, 0.03, 0.17, 0.07, 0.03]), + ], + [ + np.array([0.02, 0.04, 0.16, 0.48, 0.0, 0.05, 0.15, 0.08, 0.02]), + np.array([0.14, 0.02, 0.04, 0.30, 0.2, 0.03, 0.17, 0.07, 0.03]), + ], + 3, + ExtensionPolicy(), + np.array([0.1, 0.2]), + ), + ( + [ + np.array([0.2, 0.4, 0.2, 0.2]), + np.array([0.1, 0.3, 0.2, 0.4]), + ], + [ + np.array([0.3, 0.3, 0.1, 0.3]), + np.array([0.5, 0.2, 0.1, 0.2]), + ], + 3, + ExtensionPolicy(collapse_false=True), + np.array([0.1, 0.2]), + ), + ], + ) + def test_accd(self, true_prev, estim_prev, nbcl, extpol, result): + true_prev = [ExtendedPrev(tp, nbcl, extpol=extpol) for tp in true_prev] + estim_prev = [ExtendedPrev(ep, nbcl, extpol=extpol) for ep in estim_prev] + _err = error.accd(true_prev, estim_prev) + assert (np.abs(_err - result) < 1e-15).all() diff --git a/tests/test_evaluation/test_report.py b/tests/test_evaluation/test_report.py new file mode 100644 index 0000000..922bac6 --- /dev/null +++ b/tests/test_evaluation/test_report.py @@ -0,0 +1,425 @@ +import numpy as np +import pytest + +from quacc.evaluation.report import ( + CompReport, + DatasetReport, + EvaluationReport, + _get_shift, +) + + +@pytest.fixture +def empty_er(): + return EvaluationReport("empty") + + +@pytest.fixture +def er_list(): + er1 = EvaluationReport("er1") + er1.append_row(np.array([0.2, 0.8]), **dict(acc=0.9, acc_score=0.1)) + er1.append_row(np.array([0.2, 0.8]), **dict(acc=0.6, acc_score=0.4)) + er1.append_row(np.array([0.3, 0.7]), **dict(acc=0.7, acc_score=0.3)) + er2 = EvaluationReport("er2") + er2.append_row(np.array([0.2, 0.8]), **dict(acc=0.9, acc_score=0.1)) + er2.append_row( + np.array([0.2, 0.8]), **dict(acc=0.6, acc_score=0.4, f1=0.9, f1_score=0.6) + ) + er2.append_row(np.array([0.4, 0.6]), **dict(acc=0.7, acc_score=0.3)) + return [er1, er2] + + +@pytest.fixture +def er_list2(): + er1 = EvaluationReport("er12") + er1.append_row(np.array([0.2, 0.8]), **dict(acc=0.9, acc_score=0.1)) + er1.append_row(np.array([0.2, 0.8]), **dict(acc=0.6, acc_score=0.4)) + er1.append_row(np.array([0.3, 0.7]), **dict(acc=0.7, acc_score=0.3)) + er2 = EvaluationReport("er2") + er2.append_row(np.array([0.2, 0.8]), **dict(acc=0.9, acc_score=0.1)) + er2.append_row( + np.array([0.2, 0.8]), **dict(acc=0.6, acc_score=0.4, f1=0.9, f1_score=0.6) + ) + er2.append_row(np.array([0.4, 0.6]), **dict(acc=0.8, acc_score=0.3)) + return [er1, er2] + + +@pytest.fixture +def er_list3(): + er1 = EvaluationReport("er31") + er1.append_row(np.array([0.2, 0.5, 0.3]), **dict(acc=0.9, acc_score=0.1)) + er1.append_row(np.array([0.2, 0.4, 0.4]), **dict(acc=0.6, acc_score=0.4)) + er1.append_row(np.array([0.3, 0.6, 0.1]), **dict(acc=0.7, acc_score=0.3)) + er2 = EvaluationReport("er32") + er2.append_row(np.array([0.2, 0.5, 0.3]), **dict(acc=0.9, acc_score=0.1)) + er2.append_row( + np.array([0.2, 0.5, 0.3]), **dict(acc=0.6, acc_score=0.4, f1=0.9, f1_score=0.6) + ) + er2.append_row(np.array([0.3, 0.3, 0.4]), **dict(acc=0.8, acc_score=0.3)) + return [er1, er2] + + +@pytest.fixture +def cr_1(er_list): + return CompReport( + er_list, + "cr_1", + train_prev=np.array([0.2, 0.8]), + valid_prev=np.array([0.25, 0.75]), + g_time=0.0, + ) + + +@pytest.fixture +def cr_2(er_list2): + return CompReport( + er_list2, + "cr_2", + train_prev=np.array([0.3, 0.7]), + valid_prev=np.array([0.35, 0.65]), + g_time=0.0, + ) + + +@pytest.fixture +def cr_3(er_list3): + return CompReport( + er_list3, + "cr_3", + train_prev=np.array([0.4, 0.1, 0.5]), + valid_prev=np.array([0.45, 0.25, 0.2]), + g_time=0.0, + ) + + +@pytest.fixture +def cr_4(er_list3): + return CompReport( + er_list3, + "cr_4", + train_prev=np.array([0.5, 0.1, 0.4]), + valid_prev=np.array([0.45, 0.25, 0.2]), + g_time=0.0, + ) + + +@pytest.fixture +def dr_1(cr_1, cr_2): + return DatasetReport("dr_1", [cr_1, cr_2]) + + +@pytest.fixture +def dr_2(cr_3, cr_4): + return DatasetReport("dr_2", [cr_3, cr_4]) + + +@pytest.mark.rep +@pytest.mark.mrep +class TestReport: + @pytest.mark.parametrize( + "cr_name,train_prev,shift", + [ + ( + "cr_1", + np.array([0.2, 0.8]), + np.array([0.2, 0.1, 0.0, 0.0]), + ), + ( + "cr_3", + np.array([0.2, 0.5, 0.3]), + np.array([0.2, 0.2, 0.0, 0.0, 0.1]), + ), + ], + ) + def test_get_shift(self, cr_name, train_prev, shift, request): + cr = request.getfixturevalue(cr_name) + assert ( + _get_shift(cr._data.index.get_level_values(0), train_prev) == shift + ).all() + + +@pytest.mark.rep +@pytest.mark.erep +class TestEvaluationReport: + def test_init(self, empty_er): + assert empty_er.data is None + + @pytest.mark.parametrize( + "rows,index,columns,data", + [ + ( + [ + (np.array([0.2, 0.8]), dict(acc=0.9, acc_score=0.1)), + (np.array([0.2, 0.8]), dict(acc=0.6, acc_score=0.4)), + (np.array([0.3, 0.7]), dict(acc=0.7, acc_score=0.3)), + ], + [((0.2, 0.8), 0), ((0.2, 0.8), 1), ((0.3, 0.7), 0)], + ["acc", "acc_score"], + np.array([[0.9, 0.1], [0.6, 0.4], [0.7, 0.3]]), + ), + ], + ) + def test_append_row(self, empty_er, rows, index, columns, data): + er: EvaluationReport = empty_er + for prev, r in rows: + er.append_row(prev, **r) + assert er.data.index.to_list() == index + assert er.data.columns.to_list() == columns + assert (er.data.to_numpy() == data).all() + + +@pytest.mark.rep +@pytest.mark.crep +class TestCompReport: + @pytest.mark.parametrize( + "train_prev,valid_prev,index,columns", + [ + ( + np.array([0.2, 0.8]), + np.array([0.25, 0.75]), + [ + ((0.4, 0.6), 0), + ((0.3, 0.7), 0), + ((0.2, 0.8), 0), + ((0.2, 0.8), 1), + ], + [ + ("acc", "er1"), + ("acc", "er2"), + ("acc_score", "er1"), + ("acc_score", "er2"), + ("f1", "er2"), + ("f1_score", "er2"), + ], + ) + ], + ) + def test_init(self, er_list, train_prev, valid_prev, index, columns): + cr = CompReport(er_list, "cr", train_prev, valid_prev, g_time=0.0) + assert cr._data.index.to_list() == index + assert cr._data.columns.to_list() == columns + assert (cr.train_prev == train_prev).all() + assert (cr.valid_prev == valid_prev).all() + + @pytest.mark.parametrize( + "cr_name,prev", + [ + ("cr_1", [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]), + ("cr_2", [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]), + ( + "cr_3", + [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)], + ), + ( + "cr_4", + [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)], + ), + ], + ) + def test_prevs(self, cr_name, prev, request): + cr = request.getfixturevalue(cr_name) + assert cr.prevs.tolist() == prev + + def test_join(self, er_list, er_list2): + tp = np.array([0.2, 0.8]) + vp = np.array([0.25, 0.75]) + cr1 = CompReport(er_list, "cr1", train_prev=tp, valid_prev=vp) + cr2 = CompReport(er_list2, "cr2", train_prev=tp, valid_prev=vp) + crj = cr1.join(cr2) + _loc = crj._data.loc[((0.4, 0.6), 0), ("acc", "er2")].to_numpy() + assert (_loc == np.array([0.8])).all() + + @pytest.mark.parametrize( + "cr_name,metric,estimators,columns", + [ + ("cr_1", "acc", None, ["er1", "er2"]), + ("cr_1", "acc", ["er1"], ["er1"]), + ("cr_1", "acc", ["er1", "er2"], ["er1", "er2"]), + ("cr_1", "f1", None, ["er2"]), + ("cr_1", "f1", ["er2"], ["er2"]), + ("cr_3", "acc", None, ["er31", "er32"]), + ("cr_3", "acc", ["er31"], ["er31"]), + ("cr_3", "acc", ["er31", "er32"], ["er31", "er32"]), + ("cr_3", "f1", None, ["er32"]), + ("cr_3", "f1", ["er32"], ["er32"]), + ], + ) + def test_data(self, cr_name, metric, estimators, columns, request): + cr = request.getfixturevalue(cr_name) + _data = cr.data(metric=metric, estimators=estimators) + assert _data.columns.to_list() == columns + assert all(_data.index == cr._data.index) + + # fmt: off + @pytest.mark.parametrize( + "cr_name,metric,estimators,columns,index", + [ + ("cr_1", "acc", None, ["er1", "er2"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]), + ("cr_1", "acc", ["er1"], ["er1"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]), + ("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]), + ("cr_1", "f1", None, ["er2"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]), + ("cr_1", "f1", ["er2"], ["er2"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]), + ("cr_3", "acc", None, ["er31", "er32"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]), + ("cr_3", "acc", ["er31"], ["er31"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]), + ("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]), + ("cr_3", "f1", None, ["er32"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]), + ("cr_3", "f1", ["er32"], ["er32"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]), + ], + ) + def test_shift_data(self, cr_name, metric, estimators, columns, index, request): + cr = request.getfixturevalue(cr_name) + _data = cr.shift_data(metric=metric, estimators=estimators) + assert _data.columns.to_list() == columns + assert _data.index.to_list() == index + + @pytest.mark.parametrize( + "cr_name,metric,estimators,columns,index", + [ + ("cr_1", "acc", None, ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]), + ("cr_1", "acc", ["er1"], ["er1"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]), + ("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]), + ("cr_1", "f1", None, ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]), + ("cr_1", "f1", ["er2"], ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]), + ("cr_3", "acc", None, ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]), + ("cr_3", "acc", ["er31"], ["er31"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]), + ("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]), + ("cr_3", "f1", None, ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]), + ("cr_3", "f1", ["er32"], ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]), + ], + ) + def test_avg_by_prevs(self, cr_name, metric, estimators, columns, index, request): + cr = request.getfixturevalue(cr_name) + _data = cr.avg_by_prevs(metric=metric, estimators=estimators) + assert _data.columns.to_list() == columns + assert _data.index.to_list() == index + + @pytest.mark.parametrize( + "cr_name,metric,estimators,columns,index", + [ + ("cr_1", "acc", None, ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]), + ("cr_1", "acc", ["er1"], ["er1"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]), + ("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]), + ("cr_1", "f1", None, ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]), + ("cr_1", "f1", ["er2"], ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]), + ("cr_3", "acc", None, ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]), + ("cr_3", "acc", ["er31"], ["er31"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]), + ("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]), + ("cr_3", "f1", None, ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]), + ("cr_3", "f1", ["er32"], ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]), + ], + ) + def test_stdev_by_prevs(self, cr_name, metric, estimators, columns, index, request): + cr = request.getfixturevalue(cr_name) + _data = cr.stdev_by_prevs(metric=metric, estimators=estimators) + assert _data.columns.to_list() == columns + assert _data.index.to_list() == index + + @pytest.mark.parametrize( + "cr_name,metric,estimators,columns,index", + [ + ("cr_1", "acc", None, ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]), + ("cr_1", "acc", ["er1"], ["er1"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]), + ("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]), + ("cr_1", "f1", None, ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]), + ("cr_1", "f1", ["er2"], ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]), + ("cr_3", "acc", None, ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]), + ("cr_3", "acc", ["er31"], ["er31"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]), + ("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]), + ("cr_3", "f1", None, ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]), + ("cr_3", "f1", ["er32"], ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]), + ], + ) + def test_train_table(self, cr_name, metric, estimators, columns, index, request): + cr = request.getfixturevalue(cr_name) + _data = cr.train_table(metric=metric, estimators=estimators) + assert _data.columns.to_list() == columns + assert _data.index.to_list() == index + + @pytest.mark.parametrize( + "cr_name,metric,estimators,columns,index", + [ + ("cr_1", "acc", None, ["er1", "er2"], [0.0, 0.1, 0.2, "mean"]), + ("cr_1", "acc", ["er1"], ["er1"], [0.0, 0.1, 0.2, "mean"]), + ("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [0.0, 0.1, 0.2, "mean"]), + ("cr_1", "f1", None, ["er2"], [0.0, 0.1, 0.2, "mean"]), + ("cr_1", "f1", ["er2"], ["er2"], [0.0, 0.1, 0.2, "mean"]), + ("cr_3", "acc", None, ["er31", "er32"], [0.2, 0.3, 0.4, 0.5, "mean"]), + ("cr_3", "acc", ["er31"], ["er31"], [0.2, 0.3, 0.4, 0.5, "mean"]), + ("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [0.2, 0.3, 0.4, 0.5, "mean"]), + ("cr_3", "f1", None, ["er32"], [0.2, 0.3, 0.4, 0.5, "mean"]), + ("cr_3", "f1", ["er32"], ["er32"], [0.2, 0.3, 0.4, 0.5, "mean"]), + ], + ) + def test_shift_table(self, cr_name, metric, estimators, columns, index, request): + cr = request.getfixturevalue(cr_name) + _data = cr.shift_table(metric=metric, estimators=estimators) + assert _data.columns.to_list() == columns + assert _data.index.to_list() == index + # fmt: on + + +@pytest.mark.rep +@pytest.mark.drep +class TestDatasetReport: + # fmt: off + @pytest.mark.parametrize( + "dr_name,metric,estimators,columns,index", + [ + ( + "dr_1", "acc", None, ["er1", "er2", "er12"], + [ + ((0.3, 0.7), (0.4, 0.6), 0), + ((0.3, 0.7), (0.3, 0.7), 0), + ((0.3, 0.7), (0.2, 0.8), 0), + ((0.3, 0.7), (0.2, 0.8), 1), + ((0.2, 0.8), (0.4, 0.6), 0), + ((0.2, 0.8), (0.3, 0.7), 0), + ((0.2, 0.8), (0.2, 0.8), 0), + ((0.2, 0.8), (0.2, 0.8), 1), + ], + ), + ( + "dr_2", "acc", None, ["er31", "er32"], + [ + ((0.5, 0.1, 0.4), (0.3, 0.6, 0.1), 0), + ((0.5, 0.1, 0.4), (0.3, 0.3, 0.4), 0), + ((0.5, 0.1, 0.4), (0.2, 0.5, 0.3), 0), + ((0.5, 0.1, 0.4), (0.2, 0.5, 0.3), 1), + ((0.5, 0.1, 0.4), (0.2, 0.4, 0.4), 0), + ((0.4, 0.1, 0.5), (0.3, 0.6, 0.1), 0), + ((0.4, 0.1, 0.5), (0.3, 0.3, 0.4), 0), + ((0.4, 0.1, 0.5), (0.2, 0.5, 0.3), 0), + ((0.4, 0.1, 0.5), (0.2, 0.5, 0.3), 1), + ((0.4, 0.1, 0.5), (0.2, 0.4, 0.4), 0), + ], + ), + ], + ) + def test_data(self, dr_name, metric, estimators, columns, index, request): + dr = request.getfixturevalue(dr_name) + _data = dr.data(metric=metric, estimators=estimators) + assert _data.columns.to_list() == columns + assert _data.index.to_list() == index + + @pytest.mark.parametrize( + "dr_name,metric,estimators,columns,index", + [ + ( + "dr_1", "acc", None, ["er1", "er2", "er12"], + [(0.0, 0), (0.0, 1), (0.0, 2), (0.1, 0), + (0.1, 1), (0.1, 2), (0.1, 3), (0.2, 0)], + ), + ( + "dr_2", "acc", None, ["er31", "er32"], + [(0.2, 0), (0.2, 1), (0.3, 0), (0.3, 1), (0.4, 0), + (0.4, 1), (0.4, 2), (0.4, 3), (0.5, 0), (0.5, 1)], + ), + ], + ) + def test_shift_data(self, dr_name, metric, estimators, columns, index, request): + dr = request.getfixturevalue(dr_name) + _data = dr.shift_data(metric=metric, estimators=estimators) + print(_data.index.tolist()) + assert _data.columns.to_list() == columns + assert _data.index.to_list() == index + # fmt: off diff --git a/tests/test_method/__pycache__/test_base.cpython-311-pytest-7.4.2.pyc b/tests/test_method/__pycache__/test_base.cpython-311-pytest-7.4.2.pyc deleted file mode 100644 index 526d081a6d2f0586496ca4bb5fd2595c25573c42..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5919 zcmd5ATWlN0agRqH?~V^sFQc*&MQ$2bECg1REX8(gv5I3jN>Uq1Tmxog9L{$qn)31C zy<;q@Bx)2c;E#UzqiEEk1%$M49V7_++0O*+kNzl^so4VxGzd_^s6QxJE{uG2c8^E$ zR8m$EH}US>?Ck8!?Cjj`?2vy5g+v5p=G?6IO97!TNXIGO2J#>a$a_ddDx;$$Lt9UV zRXhyJc#}TH9^sOHI>IM;+6oyl8K7exJ(vk4LlkHAaAr?(4}(1DBvQTCkm|eXLFhwx z?UIbB%-bl+-6P#oQO4?hBbQ1WW?FemOJU70(zz^9_rILZ%DDJuLC%`#g|voe6{Ub> zrFa5ZGqRb(N<$%S_B1*VES4-qqsvHoGg3plvSc#y^|`kj6Xsh;;4fGmYx z3Vi^*HyKz4ymoO*shsLhd3NOBRbgiim{WY8aqC#{!aV#Ud^o#2h1PZQ9r4=E-gq+p z3hdfn->gM$AK5#L?5VHAN5NhA=o_~v#kqro8oGvV!j}Kg+4v|as$mEld#o4EXomUN z;IKTL&KjniRWxHPk2Td8x10DFgqVVEMp;ZE(#9k$e@SdC(DKhjl)XpJc|^_lsKswuS?8gQ*yN#OdU<6Epdy@(5iaeA94yYXjdAVtLUhK zZJ(nRtR>qx-91w?xE|?k|6Vo>4Vxn`jOL5+$*3Xz*uh4QQR0L{1EfJ3 zj~X!~AtIvq8}M!=e5M9-n95f#g0 zwncH;u{_zlB{YO`*bbihyYX4q0AU-Z+JE3MZp|n79O! zKyC?HNzsWX49l07P35A+nntQR!gM@WG- zw^+QhN-Ow-d{MekNb4pTZgF}}k#)lofocMeBxCDal5sIt&{YX*c`T*$+yz;evT{bV z1j&#WG(xt*v-DfRkAV{ygB$v!g&YPqO4?Fhm$R}-ekO!0uc7G+7MIB>mo$tAh+Uk0 z-#6Mg-U~$BPc0^<*?S;v1zMeh?VT6fJBRJ%G-M|1o^43{ph4pMNPGEmRyi{NBiLic z{2Mu}WtVd!XLEW!J#XX+SkdOe18D+GgF-2znHO`ajbD%rEtW4@1Jq59JEg2j-MswL z86Z72?TK|1CbymOUDyFv(ZhKL4NQMF{p$MkYt`x3*85+t_P>5zs0}}NU8wRywNT&B z=kJC_?u15G2dn!J-d#)y@wt)6yD+x8?O5St0TM1wRI`2E7+tG1pH4dRIDZO zurV(m-Xvyhx%JV}``kk^v~4~|&=>ABN+)Qs&?RAAzmua|8)-LKz&S$nkZh6@wbeXUZ3o9qd-Uj0*Rj^XZnaT~$Eq}5gn;oW zyMHdme)^r02j=$r*2Exm*7kbm_6ckISb{&x5PX>8G4I@je{G5Zcq$sOc!`E{NwP$V z+$aTI13V~6eo(TQYCPco!d`p-Z(>*bv%8aj|FLvj0Ak4#h} z6ZMcEf>ST*83z8)*E%CzM$%N8>EBF)k83?T#!C7mzE3^-$kW^s0lUnWES#wA^oaWv(VbZZ4ls&^QocG lRa$Fku;PAdXtLsd{(%w|_fzK%GvMj}2AhPo)nR~|{~x*17OVgO diff --git a/tests/test_method/__pycache__/test_model_selection.cpython-311-pytest-7.4.2.pyc b/tests/test_method/__pycache__/test_model_selection.cpython-311-pytest-7.4.2.pyc deleted file mode 100644 index f9d759899a3c454034ebc854d97ca5e833c6a860..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 554 zcmZuuy-ve05I!e?rd11qjg5t&1ATxBA%rTiKnDs8Q6wwGmWY(Z!A^zB)PZ;C8-Nfm zRApsiD=Jebd`?mrdRFef`?0?-_FJdZ0&T~GbAF`bCnmKy6se;rD)w)<%C7|D z@rcJf%p;KwyLRDHng$7{v}jS92sVxpw<&!b2XWbXhb5l8x(xgNt>RMomqPONSsdPp zIE#EG#xmr7&Q*jiSH{zX=MREa7r&CEYc0JQ$7vu-0WNYFB7C!zds1G( isoc}4QFg7t!X|{w|112{if93ZG&i|#Z2WBK8qQyZ_m0T` diff --git a/tests/test_method/test_base.py b/tests/test_method/test_base.py new file mode 100644 index 0000000..66e50b2 --- /dev/null +++ b/tests/test_method/test_base.py @@ -0,0 +1,100 @@ +import numpy as np +import pytest +import scipy.sparse as sp + +from quacc.data import ExtendedData, ExtensionPolicy +from quacc.method.base import MultiClassAccuracyEstimator + + +@pytest.mark.mcae +class TestMultiClassAccuracyEstimator: + @pytest.mark.parametrize( + "instances,pred_proba,extpol,result", + [ + ( + np.arange(12).reshape((4, 3)), + np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]), + ExtensionPolicy(), + np.array([0.21, 0.39, 0.1, 0.4]), + ), + ( + np.arange(12).reshape((4, 3)), + np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]), + ExtensionPolicy(collapse_false=True), + np.array([0.21, 0.39, 0.5]), + ), + ( + sp.csr_matrix(np.arange(12).reshape((4, 3))), + np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]), + ExtensionPolicy(), + np.array([0.21, 0.39, 0.1, 0.4]), + ), + ( + np.arange(12).reshape((4, 3)), + np.array( + [ + [0.3, 0.2, 0.5], + [0.13, 0.67, 0.2], + [0.21, 0.09, 0.8], + [0.19, 0.1, 0.71], + ] + ), + ExtensionPolicy(), + np.array([0.21, 0.09, 0.1, 0.04, 0.06, 0.11, 0.11, 0.18, 0.1]), + ), + ( + np.arange(12).reshape((4, 3)), + np.array( + [ + [0.3, 0.2, 0.5], + [0.13, 0.67, 0.2], + [0.21, 0.09, 0.8], + [0.19, 0.1, 0.71], + ] + ), + ExtensionPolicy(collapse_false=True), + np.array([0.21, 0.09, 0.1, 0.7]), + ), + ( + sp.csr_matrix(np.arange(12).reshape((4, 3))), + np.array( + [ + [0.3, 0.2, 0.5], + [0.13, 0.67, 0.2], + [0.21, 0.09, 0.8], + [0.19, 0.1, 0.71], + ] + ), + ExtensionPolicy(), + np.array([0.21, 0.09, 0.1, 0.04, 0.06, 0.11, 0.11, 0.18, 0.1]), + ), + ], + ) + def test_estimate(self, monkeypatch, instances, pred_proba, extpol, result): + ed = ExtendedData(instances, pred_proba, pred_proba, extpol) + + class MockQuantifier: + def __init__(self): + self.classes_ = np.arange(result.shape[0]) + + def quantify(self, X): + return result + + def mockinit(self): + self.extpol = extpol + self.quantifier = MockQuantifier() + + def mock_extend_instances(self, instances): + return ed + + monkeypatch.setattr(MultiClassAccuracyEstimator, "__init__", mockinit) + monkeypatch.setattr( + MultiClassAccuracyEstimator, "_extend_instances", mock_extend_instances + ) + mcae = MultiClassAccuracyEstimator() + + ep1 = mcae.estimate(instances) + ep2 = mcae.estimate(ed) + + assert (ep1.flat == ep2.flat).all() + assert (ep1.flat == result).all() diff --git a/tests/test_method/test_base/__pycache__/test_BQAE.cpython-311-pytest-7.4.2.pyc b/tests/test_method/test_base/__pycache__/test_BQAE.cpython-311-pytest-7.4.2.pyc deleted file mode 100644 index e1ad6b881cc43ecd4037e2059cc05b86f9622c01..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5871 zcmd5ATWlN0agUGV-I0{3mr+@XA~&ro76L0umSVfMS=lih#iJJH)3nO2h+2fHs zl~h&4O}u+IJ3BiwJ3F^KJLGS|VTnL_|Ktg+Bogu&cAVm?Bad@{yh~J~aym(Jto3BQ zDGx`ozNDYCNBCrbjR;ABwPIFE2HBWL4`sv2FvGoiS9VWw4@W%YG*NvwiR!=OA>>1N z?UIbB+?yoIKg90YC};J&kxyp~Gn4v(mZq9vWb!$n?mv;qDYSIHsN~GdTt=hQsZ^0F zsnRK6%_?S|rs@h=v!~v9Y z{Pd~Cp-j#&m0U_QMhjF^jWN54k3tkF>SokSG3cyMu=1zG`W!3&NJPC4@ffFH`*gAI zqtl-z($|hj4}Mkn8xEjE`qBB{5B>C?Yl$?Hk4ft|0Aa;Z^;u&5{r9l^H9;|MPys<0 zK?p$9m$s+;Ln3YStt8U++<&nJqP}(DW5pKh0=zD3Yh(nb7O@m>Gc$+;H0JTYAhNwn z%tk~a#6)_%xV`6&DivPk%DmhN$z^W7A-UX$myRZE%lv#Jda54xx0)mz*_DRID%xuB zw$9Ov$1-|>m9h^w8amo>m^}_JA@l6o!+s?IazSULBB?w=C^SRejkzgcH zmde{MtxB*Qyv-T6RKLob1CB131;08gZhyZ+i~;jSXAJn9`-!yL=>96yEfH~~J4@s^ zu%>mEhDx8zYrJU992T*%g!Ey*wMO=8t14S za?84VKY>(nfvW`I72p-&l`26<4xzfw6-RQxmr0x$zU!U^?_1lYoUTwi#{Iy~j8a0Q%IZ`OaC!&V*BL^Ek#)#t%4Ujr% zENa9QN|lnVfj(WTj4&A>*QEigvralRKF(bSi3$4udtvP|u;L5tS;V zwncH)u{^ngCDw)V)DjyPnZ?)dF46K9N@EDd5lkSML@&gNB3m&clDoOYv7h-zk$=Hz!2y{OS%Bn4aC zVhPSFtusFw1tRTZ7NcqQ9*A4PX6ImQ=atsZVS71snQ^;k8xlWgV0<5LEuWuB9hv@j2Z zjXc$Ii}~S;dA*RCHS$H8(q_Q}83d-mpq$mrD|yw%FDnLPFlEeNVug~`&m87BR?ex+ z;fq5TfW&y(j^ov4d<#DJ4f14`BmGlh>uY^ys(oi}iM64ZZi!W4uomw9 z+3bVx@cr=c@<4U}!F%&-;pu94dV>gJ7Xu$p{qg97lNaxwyjTnMS3@r?@7wfxUge}M zg0RWA!IQ&@TQk5>4UJSoleI&~TY67C-f-6zPrO|l0l+qh-Q`=ljMfosG6@9#6AKe- zN<690D|Bt58C!0BZ1fTT1c$cG=Lq_{+KsXank;OI9Te<1^GbJi9b$9gw`b#?H&YmuYX$kEQ$w_VV8RtND}i;PtxV;iC$ z7SvT=&{75K-8|OSq2p_jL^YClChBeL&n-c{dkO3R!}=dDLA@8o`V6f<%|%6N(?81d zTLi#Hk_$qdCIB|y=Gsu%5IyB=;ApocIB60c0jf11+s7nkryG#q#Herd11KL@cab`U5J2fPxr+4txz!BrC+0h?K;^c7)2*f&b8d07Cqs z$}18QTTz)hVLNSM=tX(=-rdFblDxKB4bbt~JEA)}e;Ba{Z%(FNO-4X~pcDdRvy-{8 zgCO&QH3SZv0&$*!xGxUCyS}Rk)(N_Tw)bUZKb}`@LUS?Q&}0PEo*OuJ-ckNF74o8Y zc+zjX%Evg5GK#TkV4N{hriM2$ekh}~>b$|!oHsj*yWx$XT!iP0)BK6;-m>~4{X8vCF1@`%t gme$#&CKDGSH2&}K&#Pe*K&ZY7zD?t2Mc3~91<4$V;Q#;t diff --git a/tests/test_method/test_base/test_BQAE.py b/tests/test_method/test_base/test_BQAE.py deleted file mode 100644 index 426b08f..0000000 --- a/tests/test_method/test_base/test_BQAE.py +++ /dev/null @@ -1,66 +0,0 @@ -import numpy as np -import pytest -import scipy.sparse as sp -from sklearn.linear_model import LogisticRegression - -from quacc.method.base import BinaryQuantifierAccuracyEstimator - - -class TestBQAE: - @pytest.mark.parametrize( - "instances,preds0,preds1,result", - [ - ( - np.asarray( - [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]] - ), - np.asarray([0.3, 0.7]), - np.asarray([0.4, 0.6]), - np.asarray([0.15, 0.2, 0.35, 0.3]), - ), - ( - sp.csr_matrix( - [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]] - ), - np.asarray([0.3, 0.7]), - np.asarray([0.4, 0.6]), - np.asarray([0.15, 0.2, 0.35, 0.3]), - ), - ( - np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - np.asarray([0.3, 0.7]), - np.asarray([0.4, 0.6]), - np.asarray([0.0, 0.4, 0.0, 0.6]), - ), - ( - sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - np.asarray([0.3, 0.7]), - np.asarray([0.4, 0.6]), - np.asarray([0.0, 0.4, 0.0, 0.6]), - ), - ( - np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - np.asarray([0.3, 0.7]), - np.asarray([0.4, 0.6]), - np.asarray([0.3, 0.0, 0.7, 0.0]), - ), - ( - sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - np.asarray([0.3, 0.7]), - np.asarray([0.4, 0.6]), - np.asarray([0.3, 0.0, 0.7, 0.0]), - ), - ], - ) - def test_estimate_ndarray(self, mocker, instances, preds0, preds1, result): - estimator = BinaryQuantifierAccuracyEstimator(LogisticRegression()) - estimator.n_classes = 4 - with mocker.patch.object(estimator.q_model_0, "quantify"), mocker.patch.object( - estimator.q_model_1, "quantify" - ): - estimator.q_model_0.quantify.return_value = preds0 - estimator.q_model_1.quantify.return_value = preds1 - assert np.array_equal( - estimator.estimate(instances, ext=True), - result, - ) diff --git a/tests/test_method/test_base/test_MCAE.py b/tests/test_method/test_base/test_MCAE.py deleted file mode 100644 index a6cae5e..0000000 --- a/tests/test_method/test_base/test_MCAE.py +++ /dev/null @@ -1,2 +0,0 @@ -class TestMCAE: - pass