From de48da638aaa820e1aebed86e7bd86da673ab05c Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Fri, 8 Dec 2023 00:44:02 +0100 Subject: [PATCH] added tests for data.py --- tests/test_data.py | 392 ++++++++++++++++++++++++++++----------------- 1 file changed, 247 insertions(+), 145 deletions(-) diff --git a/tests/test_data.py b/tests/test_data.py index c5a383f..69124a5 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -2,184 +2,286 @@ import numpy as np import pytest import scipy.sparse as sp -from quacc.data import ExtendedCollection +from quacc.data import ( + ExtendedCollection, + ExtendedData, + ExtendedLabels, + ExtendedPrev, + ExtensionPolicy, +) -class TestExtendedCollection: +@pytest.mark.ext +@pytest.mark.extpol +class TestExtendedPolicy: @pytest.mark.parametrize( - "instances,result", + "extpol,nbcl,result", [ + (ExtensionPolicy(), 2, np.array([0, 1, 2, 3])), + (ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2])), + (ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])), + (ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2, 3])), + ], + ) + def test_qclasses(self, extpol, nbcl, result): + assert (result == extpol.qclasses(nbcl)).all() + + @pytest.mark.parametrize( + "extpol,nbcl,result", + [ + (ExtensionPolicy(), 2, np.array([0, 1, 2, 3])), + (ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2, 3])), + (ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])), ( - np.asarray( - [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]] - ), - [np.asarray([1, 3]), np.asarray([0, 2])], - ), - ( - sp.csr_matrix( - [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]] - ), - [np.asarray([1, 3]), np.asarray([0, 2])], - ), - ( - np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - [np.asarray([], dtype=int), np.asarray([0, 1])], - ), - ( - sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - [np.asarray([], dtype=int), np.asarray([0, 1])], - ), - ( - np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - [np.asarray([0, 1]), np.asarray([], dtype=int)], - ), - ( - sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - [np.asarray([0, 1]), np.asarray([], dtype=int)], + ExtensionPolicy(collapse_false=True), + 3, + np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]), ), ], ) - def test__split_index_by_pred(self, instances, result): - ncl = 2 - assert all( - np.array_equal(a, b) - for (a, b) in zip( - ExtendedCollection._split_index_by_pred(ncl, instances), - result, - ) - ) + def test_eclasses(self, extpol, nbcl, result): + assert (result == extpol.eclasses(nbcl)).all() @pytest.mark.parametrize( - "instances,s_inst,norms", + "extpol,nbcl,result", [ ( - np.asarray( - [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]] + ExtensionPolicy(), + 2, + ( + np.array([0, 0, 1, 1]), + np.array([0, 1, 0, 1]), ), - [ - np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - ], - [0.5, 0.5], ), ( - sp.csr_matrix( - [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]] + ExtensionPolicy(collapse_false=True), + 2, + ( + np.array([0, 1, 0]), + np.array([0, 1, 1]), ), - [ - sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - ], - [0.5, 0.5], ), ( - np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - [ - np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - np.asarray([], dtype=int), - ], - [1.0, 0.0], + ExtensionPolicy(), + 3, + ( + np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]), + np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]), + ), ), ( - sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - [ - sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - sp.csr_matrix([], dtype=int), - ], - [1.0, 0.0], - ), - ( - np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - [ - np.asarray([], dtype=int), - np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - ], - [0.0, 1.0], - ), - ( - sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - [ - sp.csr_matrix([], dtype=int), - sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - ], - [0.0, 1.0], + ExtensionPolicy(collapse_false=True), + 3, + ( + np.array([0, 1, 2, 0]), + np.array([0, 1, 2, 1]), + ), ), ], ) - def test_split_inst_by_pred(self, instances, s_inst, norms): - ncl = 2 - _s_inst, _norms = ExtendedCollection.split_inst_by_pred(ncl, instances) - if isinstance(s_inst, np.ndarray): - assert all(np.array_equal(a, b) for (a, b) in zip(_s_inst, s_inst)) - if isinstance(s_inst, sp.csr_matrix): - assert all((a != b).nnz == 0 for (a, b) in zip(_s_inst, s_inst)) - assert all(a == b for (a, b) in zip(_norms, norms)) + def test_matrix_idx(self, extpol, nbcl, result): + _midx = extpol.matrix_idx(nbcl) + assert len(_midx) == len(result) + assert all((idx == r).all() for idx, r in zip(_midx, result)) @pytest.mark.parametrize( - "instances,labels,inst0,lbl0,inst1,lbl1", + "extpol,nbcl,true,pred,result", [ ( - np.asarray( - [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]] - ), - np.asarray([3, 0, 1, 2]), - np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - np.asarray([0, 1]), - np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - np.asarray([1, 0]), + ExtensionPolicy(), + 2, + np.array([1, 0, 1, 1, 0, 0]), + np.array([1, 0, 0, 1, 1, 0]), + np.array([3, 0, 2, 3, 1, 0]), ), ( - sp.csr_matrix( - [[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]] - ), - np.asarray([3, 0, 1, 2]), - sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - np.asarray([0, 1]), - sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - np.asarray([1, 0]), + ExtensionPolicy(collapse_false=True), + 2, + np.array([1, 0, 1, 1, 0, 0]), + np.array([1, 0, 0, 1, 1, 0]), + np.array([1, 0, 2, 1, 2, 0]), ), ( - np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - np.asarray([3, 1]), - np.asarray([], dtype=int), - np.asarray([], dtype=int), - np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - np.asarray([1, 0]), + ExtensionPolicy(), + 3, + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), + np.array([4, 6, 0, 3, 1, 7, 2, 5, 8]), ), ( - sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - np.asarray([3, 1]), - sp.csr_matrix(np.empty((0, 0), dtype=int)), - np.asarray([], dtype=int), - sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]), - np.asarray([1, 0]), - ), - ( - np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - np.asarray([0, 2]), - np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - np.asarray([0, 1]), - np.asarray([], dtype=int), - np.asarray([], dtype=int), - ), - ( - sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - np.asarray([0, 2]), - sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]), - np.asarray([0, 1]), - sp.csr_matrix(np.empty((0, 0), dtype=int)), - np.asarray([], dtype=int), + ExtensionPolicy(collapse_false=True), + 3, + np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]), + np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]), + np.array([1, 3, 0, 3, 3, 3, 3, 3, 2]), ), ], ) - def test_split_by_pred(self, instances, labels, inst0, lbl0, inst1, lbl1): - ec = ExtendedCollection(instances, labels, classes=range(0, 4)) - [ec0, ec1] = ec.split_by_pred() - if isinstance(instances, np.ndarray): - assert np.array_equal(ec0.X, inst0) - assert np.array_equal(ec1.X, inst1) - if isinstance(instances, sp.csr_matrix): - assert (ec0.X != inst0).nnz == 0 - assert (ec1.X != inst1).nnz == 0 - assert np.array_equal(ec0.y, lbl0) - assert np.array_equal(ec1.y, lbl1) + def test_ext_lbl(self, extpol, nbcl, true, pred, result): + vfun = extpol.ext_lbl(nbcl) + assert (vfun(true, pred) == result).all() + + +@pytest.mark.ext +@pytest.mark.extd +class TestExtendedData: + @pytest.mark.parametrize( + "pred_proba,result", + [ + ( + np.array([[0.3, 0.7], [0.54, 0.46], [0.28, 0.72], [0.6, 0.4]]), + [np.array([1, 3]), np.array([0, 2])], + ), + ( + np.array([[0.3, 0.7], [0.28, 0.72]]), + [np.array([]), np.array([0, 1])], + ), + ( + np.array([[0.54, 0.46], [0.6, 0.4]]), + [np.array([0, 1]), np.array([])], + ), + ( + np.array( + [ + [0.25, 0.4, 0.35], + [0.24, 0.3, 0.46], + [0.61, 0.28, 0.11], + [0.4, 0.1, 0.5], + ] + ), + [np.array([2]), np.array([0]), np.array([1, 3])], + ), + ], + ) + def test__split_index_by_pred(self, monkeypatch, pred_proba, result): + def mockinit(self, pred_proba): + self.pred_proba_ = pred_proba + + monkeypatch.setattr(ExtendedData, "__init__", mockinit) + ed = ExtendedData(pred_proba) + _split_index = ed._ExtendedData__split_index_by_pred() + assert len(_split_index) == len(result) + assert all((a == b).all() for (a, b) in zip(_split_index, result)) + + +@pytest.mark.ext +@pytest.mark.extl +class TestExtendedLabels: + @pytest.mark.parametrize( + "true,pred,nbcl,extpol,result", + [ + ( + np.array([1, 0, 0, 1, 1]), + np.array([1, 1, 0, 0, 1]), + 2, + ExtensionPolicy(), + np.array([3, 1, 0, 2, 3]), + ), + ( + np.array([1, 0, 0, 1, 1]), + np.array([1, 1, 0, 0, 1]), + 2, + ExtensionPolicy(collapse_false=True), + np.array([1, 2, 0, 2, 1]), + ), + ], + ) + def test_y(self, true, pred, nbcl, extpol, result): + el = ExtendedLabels(true, pred, nbcl, extpol) + assert (el.y == result).all() + + +@pytest.mark.ext +@pytest.mark.extp +class TestExtendedPrev: + @pytest.mark.parametrize( + "flat,nbcl,extpol,q_classes,result", + [ + ( + np.array([0.2, 0, 0.8, 0]), + 2, + ExtensionPolicy(), + [0, 1, 2, 3], + np.array([0.2, 0, 0.8, 0]), + ), + ( + np.array([0.2, 0.8]), + 2, + ExtensionPolicy(), + [0, 3], + np.array([0.2, 0, 0, 0.8]), + ), + ( + np.array([0.2, 0.8]), + 2, + ExtensionPolicy(collapse_false=True), + [0, 2], + np.array([0.2, 0, 0.8]), + ), + ( + np.array([0.1, 0.1, 0.6, 0.2]), + 3, + ExtensionPolicy(), + [0, 1, 3, 5], + np.array([0.1, 0.1, 0, 0.6, 0, 0.2, 0, 0, 0]), + ), + ( + np.array([0.1, 0.1, 0.6]), + 3, + ExtensionPolicy(collapse_false=True), + [0, 1, 2], + np.array([0.1, 0.1, 0.6, 0]), + ), + ], + ) + def test__check_q_classes(self, monkeypatch, flat, nbcl, extpol, q_classes, result): + def mockinit(self, flat, nbcl, extpol): + self.flat = flat + self.nbcl = nbcl + self.extpol = extpol + + monkeypatch.setattr(ExtendedPrev, "__init__", mockinit) + ep = ExtendedPrev(flat, nbcl, extpol) + ep._ExtendedPrev__check_q_classes(q_classes) + assert (ep.flat == result).all() + + @pytest.mark.parametrize( + "flat,nbcl,extpol,result", + [ + ( + np.array([0.05, 0.1, 0.6, 0.25]), + 2, + ExtensionPolicy(), + np.array([[0.05, 0.1], [0.6, 0.25]]), + ), + ( + np.array([0.05, 0.1, 0.85]), + 2, + ExtensionPolicy(collapse_false=True), + np.array([[0.05, 0.85], [0, 0.1]]), + ), + ( + np.array([0.05, 0.1, 0.2, 0.15, 0.04, 0.06, 0.15, 0.14, 0.1]), + 3, + ExtensionPolicy(), + np.array([[0.05, 0.1, 0.2], [0.15, 0.04, 0.06], [0.15, 0.14, 0.1]]), + ), + ( + np.array([0.05, 0.2, 0.65, 0.1]), + 3, + ExtensionPolicy(collapse_false=True), + np.array([[0.05, 0.1, 0], [0, 0.2, 0], [0, 0, 0.65]]), + ), + ], + ) + def test__build_matrix(self, monkeypatch, flat, nbcl, extpol, result): + def mockinit(self, flat, nbcl, extpol): + self.flat = flat + self.nbcl = nbcl + self.extpol = extpol + + monkeypatch.setattr(ExtendedPrev, "__init__", mockinit) + ep = ExtendedPrev(flat, nbcl, extpol) + _matrix = ep._ExtendedPrev__build_matrix() + assert _matrix.shape == result.shape + assert (_matrix == result).all()