tests added, passing
This commit is contained in:
parent
b239b2e38a
commit
db5064dbaf
|
@ -1,130 +1,33 @@
|
|||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import scipy.sparse as sp
|
||||
|
||||
from quacc.data import (
|
||||
ExtBinPrev,
|
||||
ExtendedCollection,
|
||||
ExtendedData,
|
||||
ExtendedLabels,
|
||||
ExtendedPrev,
|
||||
ExtensionPolicy,
|
||||
ExtMulPrev,
|
||||
_split_index_by_pred,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.ext
|
||||
@pytest.mark.extpol
|
||||
class TestExtendedPolicy:
|
||||
@pytest.mark.parametrize(
|
||||
"extpol,nbcl,result",
|
||||
[
|
||||
(ExtensionPolicy(), 2, np.array([0, 1, 2, 3])),
|
||||
(ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2])),
|
||||
(ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
|
||||
(ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2, 3])),
|
||||
],
|
||||
)
|
||||
def test_qclasses(self, extpol, nbcl, result):
|
||||
assert (result == extpol.qclasses(nbcl)).all()
|
||||
@pytest.fixture
|
||||
def nd_1():
|
||||
return np.arange(12).reshape((4, 3))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extpol,nbcl,result",
|
||||
[
|
||||
(ExtensionPolicy(), 2, np.array([0, 1, 2, 3])),
|
||||
(ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2, 3])),
|
||||
(ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
3,
|
||||
np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_eclasses(self, extpol, nbcl, result):
|
||||
assert (result == extpol.eclasses(nbcl)).all()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extpol,nbcl,result",
|
||||
[
|
||||
(
|
||||
ExtensionPolicy(),
|
||||
2,
|
||||
(
|
||||
np.array([0, 0, 1, 1]),
|
||||
np.array([0, 1, 0, 1]),
|
||||
),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
2,
|
||||
(
|
||||
np.array([0, 1, 0]),
|
||||
np.array([0, 1, 1]),
|
||||
),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(),
|
||||
3,
|
||||
(
|
||||
np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]),
|
||||
np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]),
|
||||
),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
3,
|
||||
(
|
||||
np.array([0, 1, 2, 0]),
|
||||
np.array([0, 1, 2, 1]),
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_matrix_idx(self, extpol, nbcl, result):
|
||||
_midx = extpol.matrix_idx(nbcl)
|
||||
assert len(_midx) == len(result)
|
||||
assert all((idx == r).all() for idx, r in zip(_midx, result))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extpol,nbcl,true,pred,result",
|
||||
[
|
||||
(
|
||||
ExtensionPolicy(),
|
||||
2,
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
np.array([1, 0, 0, 1, 1, 0]),
|
||||
np.array([3, 0, 2, 3, 1, 0]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
2,
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
np.array([1, 0, 0, 1, 1, 0]),
|
||||
np.array([1, 0, 2, 1, 2, 0]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(),
|
||||
3,
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||
np.array([4, 6, 0, 3, 1, 7, 2, 5, 8]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
3,
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||
np.array([1, 3, 0, 3, 3, 3, 3, 3, 2]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_ext_lbl(self, extpol, nbcl, true, pred, result):
|
||||
vfun = extpol.ext_lbl(nbcl)
|
||||
assert (vfun(true, pred) == result).all()
|
||||
@pytest.fixture
|
||||
def csr_1(nd_1):
|
||||
return sp.csr_matrix(nd_1)
|
||||
|
||||
|
||||
@pytest.mark.ext
|
||||
@pytest.mark.extd
|
||||
class TestExtendedData:
|
||||
class TestData:
|
||||
@pytest.mark.parametrize(
|
||||
"pred_proba,result",
|
||||
[
|
||||
|
@ -153,17 +56,219 @@ class TestExtendedData:
|
|||
),
|
||||
],
|
||||
)
|
||||
def test__split_index_by_pred(self, monkeypatch, pred_proba, result):
|
||||
def mockinit(self, pred_proba):
|
||||
self.pred_proba_ = pred_proba
|
||||
|
||||
monkeypatch.setattr(ExtendedData, "__init__", mockinit)
|
||||
ed = ExtendedData(pred_proba)
|
||||
_split_index = ed._ExtendedData__split_index_by_pred()
|
||||
def test_split_index_by_pred(self, pred_proba, result):
|
||||
_split_index = _split_index_by_pred(pred_proba)
|
||||
assert len(_split_index) == len(result)
|
||||
assert all((a == b).all() for (a, b) in zip(_split_index, result))
|
||||
|
||||
|
||||
@pytest.mark.ext
|
||||
@pytest.mark.extpol
|
||||
class TestExtendedPolicy:
|
||||
# fmt: off
|
||||
@pytest.mark.parametrize(
|
||||
"extpol,nbcl,result",
|
||||
[
|
||||
(ExtensionPolicy(), 2, np.array([0, 1, 2, 3])),
|
||||
(ExtensionPolicy(group_false=True), 2, np.array([0, 1, 2, 3])),
|
||||
(ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2])),
|
||||
(ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
|
||||
(ExtensionPolicy(group_false=True), 3, np.array([0, 1, 2, 3, 4, 5])),
|
||||
(ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2, 3])),
|
||||
],
|
||||
)
|
||||
def test_qclasses(self, extpol, nbcl, result):
|
||||
assert (result == extpol.qclasses(nbcl)).all()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extpol,nbcl,result",
|
||||
[
|
||||
(ExtensionPolicy(), 2, np.array([0, 1, 2, 3])),
|
||||
(ExtensionPolicy(group_false=True), 2, np.array([0, 1, 2, 3])),
|
||||
(ExtensionPolicy(collapse_false=True), 2, np.array([0, 1, 2, 3])),
|
||||
(ExtensionPolicy(), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
|
||||
(ExtensionPolicy(group_false=True), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
|
||||
(ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
|
||||
],
|
||||
)
|
||||
def test_eclasses(self, extpol, nbcl, result):
|
||||
assert (result == extpol.eclasses(nbcl)).all()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extpol,nbcl,result",
|
||||
[
|
||||
(ExtensionPolicy(), 2, np.array([0, 1])),
|
||||
(ExtensionPolicy(group_false=True), 2, np.array([0, 1])),
|
||||
(ExtensionPolicy(collapse_false=True), 2, np.array([0, 1])),
|
||||
(ExtensionPolicy(), 3, np.array([0, 1, 2])),
|
||||
(ExtensionPolicy(group_false=True), 3, np.array([0, 1])),
|
||||
(ExtensionPolicy(collapse_false=True), 3, np.array([0, 1, 2])),
|
||||
],
|
||||
)
|
||||
def test_tfp_classes(self, extpol, nbcl, result):
|
||||
assert (result == extpol.tfp_classes(nbcl)).all()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extpol,nbcl,result",
|
||||
[
|
||||
(
|
||||
ExtensionPolicy(), 2,
|
||||
(np.array([0, 0, 1, 1]), np.array([0, 1, 0, 1])),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(group_false=True), 2,
|
||||
(np.array([0, 1, 1, 0]), np.array([0, 1, 0, 1])),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True), 2,
|
||||
(np.array([0, 1, 0]), np.array([0, 1, 1])),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(), 3,
|
||||
(np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]), np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(group_false=True), 3,
|
||||
(np.array([0, 1, 2, 1, 2, 0]), np.array([0, 1, 2, 0, 1, 2])),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True), 3,
|
||||
(np.array([0, 1, 2, 0]), np.array([0, 1, 2, 1])),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_matrix_idx(self, extpol, nbcl, result):
|
||||
_midx = extpol.matrix_idx(nbcl)
|
||||
assert len(_midx) == len(result)
|
||||
assert all((idx == r).all() for idx, r in zip(_midx, result))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extpol,nbcl,true,pred,result",
|
||||
[
|
||||
(
|
||||
ExtensionPolicy(), 2,
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
np.array([1, 0, 0, 1, 1, 0]),
|
||||
np.array([3, 0, 2, 3, 1, 0]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(group_false=True), 2,
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
np.array([1, 0, 0, 1, 1, 0]),
|
||||
np.array([1, 0, 2, 1, 3, 0]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True), 2,
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
np.array([1, 0, 0, 1, 1, 0]),
|
||||
np.array([1, 0, 2, 1, 2, 0]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(), 3,
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||
np.array([4, 6, 0, 3, 1, 7, 2, 5, 8]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(group_false=True), 3,
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||
np.array([1, 3, 0, 3, 4, 4, 5, 5, 2]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True), 3,
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||
np.array([1, 3, 0, 3, 3, 3, 3, 3, 2]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_ext_lbl(self, extpol, nbcl, true, pred, result):
|
||||
vfun = extpol.ext_lbl(nbcl)
|
||||
assert (vfun(true, pred) == result).all()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extpol,nbcl,true,pred,result",
|
||||
[
|
||||
(
|
||||
ExtensionPolicy(), 2,
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
np.array([1, 0, 0, 1, 1, 0]),
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(group_false=True), 2,
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
np.array([1, 0, 0, 1, 1, 0]),
|
||||
np.array([0, 0, 1, 0, 1, 0]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True), 2,
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
np.array([1, 0, 0, 1, 1, 0]),
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(), 3,
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(group_false=True), 3,
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||
np.array([0, 1, 0, 1, 1, 1, 1, 1, 0]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True), 3,
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_true_lbl_from_pred(self, extpol, nbcl, true, pred, result):
|
||||
vfun = extpol.true_lbl_from_pred(nbcl)
|
||||
assert (vfun(true, pred) == result).all()
|
||||
# fmt: on
|
||||
|
||||
|
||||
@pytest.mark.ext
|
||||
@pytest.mark.extd
|
||||
class TestExtendedData:
|
||||
@pytest.mark.parametrize(
|
||||
"instances_name,indexes,result",
|
||||
[
|
||||
(
|
||||
"nd_1",
|
||||
[np.array([0, 2]), np.array([1, 3])],
|
||||
[
|
||||
np.array([[0, 1, 2], [6, 7, 8]]),
|
||||
np.array([[3, 4, 5], [9, 10, 11]]),
|
||||
],
|
||||
),
|
||||
(
|
||||
"nd_1",
|
||||
[np.array([0]), np.array([1, 3]), np.array([2])],
|
||||
[
|
||||
np.array([[0, 1, 2]]),
|
||||
np.array([[3, 4, 5], [9, 10, 11]]),
|
||||
np.array([[6, 7, 8]]),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_split_by_pred(self, instances_name, indexes, result, monkeypatch, request):
|
||||
def mockinit(self):
|
||||
self.instances = request.getfixturevalue(instances_name)
|
||||
|
||||
monkeypatch.setattr(ExtendedData, "__init__", mockinit)
|
||||
d = ExtendedData()
|
||||
split = d.split_by_pred(indexes)
|
||||
assert all([(s == r).all() for s, r in zip(split, result)])
|
||||
|
||||
|
||||
@pytest.mark.ext
|
||||
@pytest.mark.extl
|
||||
class TestExtendedLabels:
|
||||
|
@ -177,6 +282,13 @@ class TestExtendedLabels:
|
|||
ExtensionPolicy(),
|
||||
np.array([3, 1, 0, 2, 3]),
|
||||
),
|
||||
(
|
||||
np.array([1, 0, 0, 1, 1]),
|
||||
np.array([1, 1, 0, 0, 1]),
|
||||
2,
|
||||
ExtensionPolicy(group_false=True),
|
||||
np.array([1, 3, 0, 2, 1]),
|
||||
),
|
||||
(
|
||||
np.array([1, 0, 0, 1, 1]),
|
||||
np.array([1, 1, 0, 0, 1]),
|
||||
|
@ -184,92 +296,128 @@ class TestExtendedLabels:
|
|||
ExtensionPolicy(collapse_false=True),
|
||||
np.array([1, 2, 0, 2, 1]),
|
||||
),
|
||||
(
|
||||
np.array([1, 0, 0, 1, 0, 1, 2, 2, 2]),
|
||||
np.array([1, 1, 0, 0, 2, 2, 2, 0, 1]),
|
||||
3,
|
||||
ExtensionPolicy(),
|
||||
np.array([4, 1, 0, 3, 2, 5, 8, 6, 7]),
|
||||
),
|
||||
(
|
||||
np.array([1, 0, 0, 1, 0, 1, 2, 2, 2]),
|
||||
np.array([1, 1, 0, 0, 2, 2, 2, 0, 1]),
|
||||
3,
|
||||
ExtensionPolicy(group_false=True),
|
||||
np.array([1, 4, 0, 3, 5, 5, 2, 3, 4]),
|
||||
),
|
||||
(
|
||||
np.array([1, 0, 0, 1, 0, 1, 2, 2, 2]),
|
||||
np.array([1, 1, 0, 0, 2, 2, 2, 0, 1]),
|
||||
3,
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
np.array([1, 3, 0, 3, 3, 3, 2, 3, 3]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_y(self, true, pred, nbcl, extpol, result):
|
||||
el = ExtendedLabels(true, pred, nbcl, extpol)
|
||||
assert (el.y == result).all()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extpol,nbcl,indexes,true,pred,result,rcls",
|
||||
[
|
||||
(
|
||||
ExtensionPolicy(),
|
||||
2,
|
||||
[np.array([1, 2, 5]), np.array([0, 3, 4])],
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
np.array([1, 0, 0, 1, 1, 0]),
|
||||
[np.array([0, 1, 0]), np.array([1, 1, 0])],
|
||||
np.array([0, 1]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(group_false=True),
|
||||
2,
|
||||
[np.array([1, 2, 5]), np.array([0, 3, 4])],
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
np.array([1, 0, 0, 1, 1, 0]),
|
||||
[np.array([0, 1, 0]), np.array([0, 0, 1])],
|
||||
np.array([0, 1]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
2,
|
||||
[np.array([1, 2, 5]), np.array([0, 3, 4])],
|
||||
np.array([1, 0, 1, 1, 0, 0]),
|
||||
np.array([1, 0, 0, 1, 1, 0]),
|
||||
[np.array([0, 1, 0]), np.array([1, 1, 0])],
|
||||
np.array([0, 1]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(),
|
||||
3,
|
||||
[np.array([1, 2, 3]), np.array([0, 4, 5]), np.array([6, 7, 8])],
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||
[np.array([2, 0, 1]), np.array([1, 0, 2]), np.array([0, 1, 2])],
|
||||
np.array([0, 1, 2]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(group_false=True),
|
||||
3,
|
||||
[np.array([1, 2, 3]), np.array([0, 4, 5]), np.array([6, 7, 8])],
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||
[np.array([1, 0, 1]), np.array([0, 1, 1]), np.array([1, 1, 0])],
|
||||
np.array([0, 1]),
|
||||
),
|
||||
(
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
3,
|
||||
[np.array([1, 2, 3]), np.array([0, 4, 5]), np.array([6, 7, 8])],
|
||||
np.array([1, 2, 0, 1, 0, 2, 0, 1, 2]),
|
||||
np.array([1, 0, 0, 0, 1, 1, 2, 2, 2]),
|
||||
[np.array([2, 0, 1]), np.array([1, 0, 2]), np.array([0, 1, 2])],
|
||||
np.array([0, 1, 2]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_split_by_pred(self, extpol, nbcl, indexes, true, pred, result, rcls):
|
||||
el = ExtendedLabels(true, pred, nbcl, extpol)
|
||||
labels, cls = el.split_by_pred(indexes)
|
||||
assert (cls == rcls).all()
|
||||
assert all([(lbl == r).all() for lbl, r in zip(labels, result)])
|
||||
|
||||
|
||||
@pytest.mark.ext
|
||||
@pytest.mark.extp
|
||||
class TestExtendedPrev:
|
||||
@pytest.mark.parametrize(
|
||||
"flat,nbcl,extpol,q_classes,result",
|
||||
[
|
||||
(
|
||||
np.array([0.2, 0, 0.8, 0]),
|
||||
2,
|
||||
ExtensionPolicy(),
|
||||
[0, 1, 2, 3],
|
||||
np.array([0.2, 0, 0.8, 0]),
|
||||
),
|
||||
(
|
||||
np.array([0.2, 0.8]),
|
||||
2,
|
||||
ExtensionPolicy(),
|
||||
[0, 3],
|
||||
np.array([0.2, 0, 0, 0.8]),
|
||||
),
|
||||
(
|
||||
np.array([0.2, 0.8]),
|
||||
2,
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
[0, 2],
|
||||
np.array([0.2, 0, 0.8]),
|
||||
),
|
||||
(
|
||||
np.array([0.1, 0.1, 0.6, 0.2]),
|
||||
3,
|
||||
ExtensionPolicy(),
|
||||
[0, 1, 3, 5],
|
||||
np.array([0.1, 0.1, 0, 0.6, 0, 0.2, 0, 0, 0]),
|
||||
),
|
||||
(
|
||||
np.array([0.1, 0.1, 0.6]),
|
||||
3,
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
[0, 1, 2],
|
||||
np.array([0.1, 0.1, 0.6, 0]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test__check_q_classes(self, monkeypatch, flat, nbcl, extpol, q_classes, result):
|
||||
def mockinit(self, flat, nbcl, extpol):
|
||||
self.flat = flat
|
||||
self.nbcl = nbcl
|
||||
self.extpol = extpol
|
||||
|
||||
monkeypatch.setattr(ExtendedPrev, "__init__", mockinit)
|
||||
ep = ExtendedPrev(flat, nbcl, extpol)
|
||||
ep._ExtendedPrev__check_q_classes(q_classes)
|
||||
assert (ep.flat == result).all()
|
||||
|
||||
# fmt: off
|
||||
@pytest.mark.parametrize(
|
||||
"flat,nbcl,extpol,result",
|
||||
[
|
||||
(
|
||||
np.array([0.05, 0.1, 0.6, 0.25]),
|
||||
2,
|
||||
ExtensionPolicy(),
|
||||
np.array([0.05, 0.1, 0.6, 0.25]), 2, ExtensionPolicy(),
|
||||
np.array([[0.05, 0.1], [0.6, 0.25]]),
|
||||
),
|
||||
(
|
||||
np.array([0.05, 0.1, 0.85]),
|
||||
2,
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
np.array([0.05, 0.1, 0.6, 0.25]), 2, ExtensionPolicy(group_false=True),
|
||||
np.array([[0.05, 0.25], [0.6, 0.1]]),
|
||||
),
|
||||
(
|
||||
np.array([0.05, 0.1, 0.85]), 2, ExtensionPolicy(collapse_false=True),
|
||||
np.array([[0.05, 0.85], [0, 0.1]]),
|
||||
),
|
||||
(
|
||||
np.array([0.05, 0.1, 0.2, 0.15, 0.04, 0.06, 0.15, 0.14, 0.1]),
|
||||
3,
|
||||
ExtensionPolicy(),
|
||||
np.array([0.05, 0.1, 0.2, 0.15, 0.04, 0.06, 0.15, 0.14, 0.1]), 3, ExtensionPolicy(),
|
||||
np.array([[0.05, 0.1, 0.2], [0.15, 0.04, 0.06], [0.15, 0.14, 0.1]]),
|
||||
),
|
||||
(
|
||||
np.array([0.05, 0.2, 0.65, 0.1]),
|
||||
3,
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
np.array([0.15, 0.2, 0.15, 0.1, 0.15, 0.25]), 3, ExtensionPolicy(group_false=True),
|
||||
np.array([[0.15, 0.0, 0.25], [0.1, 0.2, 0.0], [0.0, 0.15, 0.15]]),
|
||||
),
|
||||
(
|
||||
np.array([0.05, 0.2, 0.65, 0.1]), 3, ExtensionPolicy(collapse_false=True),
|
||||
np.array([[0.05, 0.1, 0], [0, 0.2, 0], [0, 0, 0.65]]),
|
||||
),
|
||||
],
|
||||
|
@ -285,3 +433,130 @@ class TestExtendedPrev:
|
|||
_matrix = ep._ExtendedPrev__build_matrix()
|
||||
assert _matrix.shape == result.shape
|
||||
assert (_matrix == result).all()
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
@pytest.mark.ext
|
||||
@pytest.mark.extp
|
||||
class TestExtMulPrev:
|
||||
# fmt: off
|
||||
@pytest.mark.parametrize(
|
||||
"flat,nbcl,extpol,q_classes,result",
|
||||
[
|
||||
(np.array([0.2, 0, 0.8, 0]), 2, ExtensionPolicy(), [0, 1, 2, 3], np.array([0.2, 0, 0.8, 0])),
|
||||
(np.array([0.2, 0.8]), 2, ExtensionPolicy(), [0, 3], np.array([0.2, 0, 0, 0.8])),
|
||||
(np.array([0.2, 0.8]), 2, ExtensionPolicy(group_false=True), [0, 3], np.array([0.2, 0, 0, 0.8])),
|
||||
(np.array([0.2, 0.8]), 2, ExtensionPolicy(collapse_false=True), [0, 2], np.array([0.2, 0, 0.8])),
|
||||
(np.array([0.1, 0.1, 0.6, 0.2]), 3, ExtensionPolicy(), [0, 1, 3, 5], np.array([0.1, 0.1, 0, 0.6, 0, 0.2, 0, 0, 0])),
|
||||
(np.array([0.1, 0.1, 0.6, 0.2]), 3, ExtensionPolicy(group_false=True), [0, 1, 3, 5], np.array([0.1, 0.1, 0, 0.6, 0, 0.2])),
|
||||
(np.array([0.1, 0.1, 0.6]), 3, ExtensionPolicy(collapse_false=True), [0, 1, 2], np.array([0.1, 0.1, 0.6, 0])),
|
||||
],
|
||||
)
|
||||
def test__check_q_classes(self, monkeypatch, flat, nbcl, extpol, q_classes, result):
|
||||
def mockinit(self, nbcl, extpol):
|
||||
self.nbcl = nbcl
|
||||
self.extpol = extpol
|
||||
|
||||
monkeypatch.setattr(ExtMulPrev, "__init__", mockinit)
|
||||
ep = ExtMulPrev(nbcl, extpol)
|
||||
_flat = ep._ExtMulPrev__check_q_classes(q_classes, flat)
|
||||
assert (_flat == result).all()
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
@pytest.mark.ext
|
||||
@pytest.mark.extp
|
||||
class TestExtBinPrev:
|
||||
# fmt: off
|
||||
@pytest.mark.parametrize(
|
||||
"flat,nbcl,extpol,q_classes,result",
|
||||
[
|
||||
([np.array([0.2, 0]), np.array([0.8, 0])], 2, ExtensionPolicy(), [[0, 1], [0, 1]], np.array([[0.2, 0], [0.8, 0]])),
|
||||
([np.array([0.2]), np.array([0.8])], 2, ExtensionPolicy(), [[0], [1]], np.array([[0.2, 0], [0, 0.8]])),
|
||||
([np.array([0.2]), np.array([0.8])], 2, ExtensionPolicy(group_false=True), [[0], [1]], np.array([[0.2, 0], [0, 0.8]])),
|
||||
([np.array([0.2]), np.array([0.8])], 2, ExtensionPolicy(collapse_false=True), [[0], [1]], np.array([[0.2, 0], [0, 0.8]])),
|
||||
([np.array([0.1, 0.1]), np.array([0.6]), np.array([0.2])], 3, ExtensionPolicy(), [[0, 1], [0], [2]], np.array([[0.1, 0.1, 0], [0.6, 0, 0], [0, 0, 0.2]])),
|
||||
([np.array([0.1, 0.1]), np.array([0.6]), np.array([0.2])], 3, ExtensionPolicy(group_false=True), [[0, 1], [0], [1]], np.array([[0.1, 0.1], [0.6, 0], [0, 0.2]])),
|
||||
([np.array([0.1, 0.1]), np.array([0.6]), np.array([0.2])], 3, ExtensionPolicy(collapse_false=True), [[0, 1], [0], [2]], np.array([[0.1, 0.1, 0], [0.6, 0, 0], [0, 0, 0.2]])),
|
||||
],
|
||||
)
|
||||
def test__check_q_classes(self, monkeypatch, flat, nbcl, extpol, q_classes, result):
|
||||
def mockinit(self, nbcl, extpol):
|
||||
self.nbcl = nbcl
|
||||
self.extpol = extpol
|
||||
|
||||
monkeypatch.setattr(ExtBinPrev, "__init__", mockinit)
|
||||
ep = ExtBinPrev(nbcl, extpol)
|
||||
_flat = ep._ExtBinPrev__check_q_classes(q_classes, flat)
|
||||
assert (_flat == result).all()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"flat,result",
|
||||
[
|
||||
(np.array([[0.2, 0], [0.8, 0]]), np.array([0.2, 0.8, 0, 0])),
|
||||
(np.array([[0.2, 0], [0, 0.8]]), np.array([0.2, 0, 0, 0.8])),
|
||||
(np.array([[0.1, 0.1, 0], [0.6, 0, 0], [0, 0, 0.2]]), np.array([0.1, 0.6, 0, 0.1, 0, 0, 0, 0, 0.2])),
|
||||
(np.array([[0.1, 0.1], [0.6, 0], [0, 0.2]]), np.array([0.1, 0.6, 0, 0.1, 0, 0.2])),
|
||||
],
|
||||
)
|
||||
def test__build_flat(self, monkeypatch, flat, result):
|
||||
def mockinit(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(ExtBinPrev, "__init__", mockinit)
|
||||
ep = ExtBinPrev()
|
||||
_flat = ep._ExtBinPrev__build_flat(flat)
|
||||
assert (_flat == result).all()
|
||||
# fmt: on
|
||||
|
||||
|
||||
@pytest.mark.ext
|
||||
@pytest.mark.extc
|
||||
class TestExtendedCollection:
|
||||
@pytest.mark.parametrize(
|
||||
"instances_name,labels,pred_proba,extpol,result",
|
||||
[
|
||||
(
|
||||
"nd_1",
|
||||
np.array([0, 1, 1, 0]),
|
||||
np.array([[0.2, 0.8], [0.3, 0.7], [0.9, 0.1], [0.45, 0.55]]),
|
||||
ExtensionPolicy(),
|
||||
np.array([0, 0.5, 0.25, 0.25]),
|
||||
),
|
||||
(
|
||||
"nd_1",
|
||||
np.array([0, 1, 1, 0]),
|
||||
np.array([[0.2, 0.8], [0.3, 0.7], [0.9, 0.1], [0.45, 0.55]]),
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
np.array([0, 0.25, 0.75]),
|
||||
),
|
||||
(
|
||||
"csr_1",
|
||||
np.array([0, 1, 1, 0]),
|
||||
np.array([[0.2, 0.8], [0.3, 0.7], [0.9, 0.1], [0.45, 0.55]]),
|
||||
ExtensionPolicy(),
|
||||
np.array([0, 0.5, 0.25, 0.25]),
|
||||
),
|
||||
(
|
||||
"csr_1",
|
||||
np.array([0, 1, 1, 0]),
|
||||
np.array([[0.2, 0.8], [0.3, 0.7], [0.9, 0.1], [0.45, 0.55]]),
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
np.array([0, 0.25, 0.75]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_prevalence(
|
||||
self, instances_name, labels, pred_proba, extpol, result, request
|
||||
):
|
||||
instances = request.getfixturevalue(instances_name)
|
||||
ec = ExtendedCollection(
|
||||
instances=instances,
|
||||
labels=labels,
|
||||
pred_proba=pred_proba,
|
||||
ext=pred_proba,
|
||||
extpol=extpol,
|
||||
)
|
||||
assert (ec.prevalence() == result).all()
|
||||
|
|
|
@ -1,3 +1,135 @@
|
|||
import os
|
||||
from contextlib import redirect_stderr
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from quacc.dataset import Dataset
|
||||
|
||||
|
||||
@pytest.mark.dataset
|
||||
class TestDataset:
|
||||
pass
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize(
|
||||
"name,target,prevalence",
|
||||
[
|
||||
("spambase", None, [0.5, 0.5]),
|
||||
("imdb", None, [0.5, 0.5]),
|
||||
("rcv1", "CCAT", [0.5, 0.5]),
|
||||
("cifar10", "dog", [0.5, 0.5]),
|
||||
("twitter_gasp", None, [0.33, 0.33, 0.33]),
|
||||
],
|
||||
)
|
||||
def test__resample_all_train(self, name, target, prevalence, monkeypatch):
|
||||
def mockinit(self):
|
||||
self._name = name
|
||||
self._target = target
|
||||
self.all_train, self.test = self.alltrain_test(self._name, self._target)
|
||||
|
||||
monkeypatch.setattr(Dataset, "__init__", mockinit)
|
||||
with open(os.devnull, "w") as dn:
|
||||
with redirect_stderr(dn):
|
||||
d = Dataset()
|
||||
d._Dataset__resample_all_train()
|
||||
assert (
|
||||
np.around(d.all_train.prevalence(), decimals=2).tolist()
|
||||
== prevalence
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ncl, prevs,result",
|
||||
[
|
||||
(2, None, None),
|
||||
(2, [], None),
|
||||
(2, [[0.2, 0.1], [0.3, 0.2]], None),
|
||||
(2, [[0.2, 0.8], [0.3, 0.7]], [[0.2, 0.8], [0.3, 0.7]]),
|
||||
(2, [1.0, 2.0, 3.0], None),
|
||||
(2, [1, 2, 3], None),
|
||||
(2, [[1, 2], [2, 3], [3, 4]], None),
|
||||
(2, ["abc", "def"], None),
|
||||
(3, [[0.2, 0.3], [0.4, 0.1], [0.5, 0.2]], None),
|
||||
(3, [[0.2, 0.3, 0.2], [0.4, 0.1], [0.5, 0.6]], None),
|
||||
(2, [[0.2, 0.3, 0.1], [0.1, 0.5, 0.3]], None),
|
||||
(3, [[0.2, 0.3, 0.1], [0.1, 0.5, 0.3]], None),
|
||||
(3, [[0.2, 0.8], [0.1, 0.5]], None),
|
||||
(2, [[0.2, 0.9], [0.1, 0.5]], None),
|
||||
(2, 10, None),
|
||||
(2, [[0.2, 0.8], [0.5, 0.5]], [[0.2, 0.8], [0.5, 0.5]]),
|
||||
(3, [[0.2, 0.6], [0.3, 0.5]], None),
|
||||
],
|
||||
)
|
||||
def test__check_prevs(self, ncl, prevs, result, monkeypatch):
|
||||
class MockLabelledCollection:
|
||||
def __init__(self):
|
||||
self.n_classes = ncl
|
||||
|
||||
def mockinit(self):
|
||||
self.all_train = MockLabelledCollection()
|
||||
self.prevs = None
|
||||
|
||||
monkeypatch.setattr(Dataset, "__init__", mockinit)
|
||||
d = Dataset()
|
||||
d._Dataset__check_prevs(prevs)
|
||||
_prevs = d.prevs if d.prevs is None else d.prevs.tolist()
|
||||
assert _prevs == result
|
||||
|
||||
# fmt: off
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ncl,nprevs,built,result",
|
||||
[
|
||||
(2, 3, None, [[0.25, 0.75], [0.5, 0.5], [0.75, 0.25]]),
|
||||
(2, 3, np.array([[0.8, 0.2], [0.6, 0.4], [0.4, 0.6]]), [[0.8, 0.2], [0.6, 0.4], [0.4, 0.6]]),
|
||||
(2, 3, np.array([[0.75, 0.25], [0.5, 0.5], [0.25, 0.75]]), [[0.75, 0.25], [0.5, 0.5], [0.25, 0.75]]),
|
||||
(3, 3, None, [[0.25, 0.25, 0.5], [0.25, 0.5, 0.25], [0.5, 0.25, 0.25]]),
|
||||
(
|
||||
3, 4, None,
|
||||
[[0.2, 0.2, 0.6], [0.2, 0.4, 0.4], [0.2, 0.6, 0.2], [0.4, 0.2, 0.4], [0.4, 0.4, 0.2], [0.6, 0.2, 0.2]],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test__build_prevs(self, ncl, nprevs, built, result, monkeypatch):
|
||||
class MockLabelledCollection:
|
||||
def __init__(self):
|
||||
self.n_classes = ncl
|
||||
|
||||
def mockinit(self):
|
||||
self.all_train = MockLabelledCollection()
|
||||
self.prevs = built
|
||||
self._n_prevs = nprevs
|
||||
|
||||
monkeypatch.setattr(Dataset, "__init__", mockinit)
|
||||
d = Dataset()
|
||||
_prevs = d._Dataset__build_prevs().tolist()
|
||||
assert _prevs == result
|
||||
|
||||
# fmt: on
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ncl,prevs,atsize",
|
||||
[
|
||||
(2, np.array([[0.2, 0.8], [0.9, 0.1]]), 55),
|
||||
(3, np.array([[0.2, 0.7, 0.1], [0.9, 0.05, 0.05]]), 37),
|
||||
],
|
||||
)
|
||||
def test_get(self, ncl, prevs, atsize, monkeypatch):
|
||||
class MockLabelledCollection:
|
||||
def __init__(self):
|
||||
self.n_classes = ncl
|
||||
|
||||
def __len__(self):
|
||||
return 100
|
||||
|
||||
def mockinit(self):
|
||||
self.prevs = prevs
|
||||
self.all_train = MockLabelledCollection()
|
||||
|
||||
def mock_build_sample(self, p, at_size):
|
||||
return at_size
|
||||
|
||||
monkeypatch.setattr(Dataset, "__init__", mockinit)
|
||||
monkeypatch.setattr(Dataset, "_Dataset__build_sample", mock_build_sample)
|
||||
d = Dataset()
|
||||
_get = d.get()
|
||||
assert all(s == atsize for s in _get)
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from quacc import error
|
||||
from quacc.data import ExtendedPrev, ExtensionPolicy
|
||||
|
||||
|
||||
@pytest.mark.err
|
||||
class TestError:
|
||||
@pytest.mark.parametrize(
|
||||
"prev,result",
|
||||
[
|
||||
(np.array([[1, 4], [4, 4]]), 0.5),
|
||||
(np.array([[6, 2, 4], [2, 4, 2], [4, 2, 6]]), 0.5),
|
||||
],
|
||||
)
|
||||
def test_f1(self, prev, result):
|
||||
ep = ExtendedPrev(prev.flatten(), prev.shape[0], extpol=ExtensionPolicy())
|
||||
assert error.f1(prev) == result
|
||||
assert error.f1(ep) == result
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prev,result",
|
||||
[
|
||||
(np.array([[4, 4], [4, 4]]), 0.5),
|
||||
(np.array([[2, 4, 2], [2, 2, 4], [4, 2, 2]]), 0.25),
|
||||
],
|
||||
)
|
||||
def test_acc(self, prev, result):
|
||||
ep = ExtendedPrev(prev.flatten(), prev.shape[0], extpol=ExtensionPolicy())
|
||||
assert error.acc(prev) == result
|
||||
assert error.acc(ep) == result
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"true_prev,estim_prev,nbcl,extpol,result",
|
||||
[
|
||||
(
|
||||
[
|
||||
np.array([0.2, 0.4, 0.1, 0.3]),
|
||||
np.array([0.1, 0.5, 0.1, 0.3]),
|
||||
],
|
||||
[
|
||||
np.array([0.3, 0.4, 0.2, 0.1]),
|
||||
np.array([0.5, 0.3, 0.1, 0.1]),
|
||||
],
|
||||
2,
|
||||
ExtensionPolicy(),
|
||||
np.array([0.1, 0.2]),
|
||||
),
|
||||
(
|
||||
[
|
||||
np.array([0.2, 0.4, 0.4]),
|
||||
np.array([0.1, 0.5, 0.4]),
|
||||
],
|
||||
[
|
||||
np.array([0.3, 0.4, 0.3]),
|
||||
np.array([0.5, 0.3, 0.2]),
|
||||
],
|
||||
2,
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
np.array([0.1, 0.2]),
|
||||
),
|
||||
(
|
||||
[
|
||||
np.array([0.02, 0.04, 0.16, 0.38, 0.1, 0.05, 0.15, 0.08, 0.02]),
|
||||
np.array([0.04, 0.02, 0.14, 0.40, 0.1, 0.03, 0.17, 0.07, 0.03]),
|
||||
],
|
||||
[
|
||||
np.array([0.02, 0.04, 0.16, 0.48, 0.0, 0.05, 0.15, 0.08, 0.02]),
|
||||
np.array([0.14, 0.02, 0.04, 0.30, 0.2, 0.03, 0.17, 0.07, 0.03]),
|
||||
],
|
||||
3,
|
||||
ExtensionPolicy(),
|
||||
np.array([0.1, 0.2]),
|
||||
),
|
||||
(
|
||||
[
|
||||
np.array([0.2, 0.4, 0.2, 0.2]),
|
||||
np.array([0.1, 0.3, 0.2, 0.4]),
|
||||
],
|
||||
[
|
||||
np.array([0.3, 0.3, 0.1, 0.3]),
|
||||
np.array([0.5, 0.2, 0.1, 0.2]),
|
||||
],
|
||||
3,
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
np.array([0.1, 0.2]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_accd(self, true_prev, estim_prev, nbcl, extpol, result):
|
||||
true_prev = [ExtendedPrev(tp, nbcl, extpol=extpol) for tp in true_prev]
|
||||
estim_prev = [ExtendedPrev(ep, nbcl, extpol=extpol) for ep in estim_prev]
|
||||
_err = error.accd(true_prev, estim_prev)
|
||||
assert (np.abs(_err - result) < 1e-15).all()
|
|
@ -0,0 +1,425 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from quacc.evaluation.report import (
|
||||
CompReport,
|
||||
DatasetReport,
|
||||
EvaluationReport,
|
||||
_get_shift,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_er():
|
||||
return EvaluationReport("empty")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def er_list():
|
||||
er1 = EvaluationReport("er1")
|
||||
er1.append_row(np.array([0.2, 0.8]), **dict(acc=0.9, acc_score=0.1))
|
||||
er1.append_row(np.array([0.2, 0.8]), **dict(acc=0.6, acc_score=0.4))
|
||||
er1.append_row(np.array([0.3, 0.7]), **dict(acc=0.7, acc_score=0.3))
|
||||
er2 = EvaluationReport("er2")
|
||||
er2.append_row(np.array([0.2, 0.8]), **dict(acc=0.9, acc_score=0.1))
|
||||
er2.append_row(
|
||||
np.array([0.2, 0.8]), **dict(acc=0.6, acc_score=0.4, f1=0.9, f1_score=0.6)
|
||||
)
|
||||
er2.append_row(np.array([0.4, 0.6]), **dict(acc=0.7, acc_score=0.3))
|
||||
return [er1, er2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def er_list2():
|
||||
er1 = EvaluationReport("er12")
|
||||
er1.append_row(np.array([0.2, 0.8]), **dict(acc=0.9, acc_score=0.1))
|
||||
er1.append_row(np.array([0.2, 0.8]), **dict(acc=0.6, acc_score=0.4))
|
||||
er1.append_row(np.array([0.3, 0.7]), **dict(acc=0.7, acc_score=0.3))
|
||||
er2 = EvaluationReport("er2")
|
||||
er2.append_row(np.array([0.2, 0.8]), **dict(acc=0.9, acc_score=0.1))
|
||||
er2.append_row(
|
||||
np.array([0.2, 0.8]), **dict(acc=0.6, acc_score=0.4, f1=0.9, f1_score=0.6)
|
||||
)
|
||||
er2.append_row(np.array([0.4, 0.6]), **dict(acc=0.8, acc_score=0.3))
|
||||
return [er1, er2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def er_list3():
|
||||
er1 = EvaluationReport("er31")
|
||||
er1.append_row(np.array([0.2, 0.5, 0.3]), **dict(acc=0.9, acc_score=0.1))
|
||||
er1.append_row(np.array([0.2, 0.4, 0.4]), **dict(acc=0.6, acc_score=0.4))
|
||||
er1.append_row(np.array([0.3, 0.6, 0.1]), **dict(acc=0.7, acc_score=0.3))
|
||||
er2 = EvaluationReport("er32")
|
||||
er2.append_row(np.array([0.2, 0.5, 0.3]), **dict(acc=0.9, acc_score=0.1))
|
||||
er2.append_row(
|
||||
np.array([0.2, 0.5, 0.3]), **dict(acc=0.6, acc_score=0.4, f1=0.9, f1_score=0.6)
|
||||
)
|
||||
er2.append_row(np.array([0.3, 0.3, 0.4]), **dict(acc=0.8, acc_score=0.3))
|
||||
return [er1, er2]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cr_1(er_list):
|
||||
return CompReport(
|
||||
er_list,
|
||||
"cr_1",
|
||||
train_prev=np.array([0.2, 0.8]),
|
||||
valid_prev=np.array([0.25, 0.75]),
|
||||
g_time=0.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cr_2(er_list2):
|
||||
return CompReport(
|
||||
er_list2,
|
||||
"cr_2",
|
||||
train_prev=np.array([0.3, 0.7]),
|
||||
valid_prev=np.array([0.35, 0.65]),
|
||||
g_time=0.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cr_3(er_list3):
|
||||
return CompReport(
|
||||
er_list3,
|
||||
"cr_3",
|
||||
train_prev=np.array([0.4, 0.1, 0.5]),
|
||||
valid_prev=np.array([0.45, 0.25, 0.2]),
|
||||
g_time=0.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cr_4(er_list3):
|
||||
return CompReport(
|
||||
er_list3,
|
||||
"cr_4",
|
||||
train_prev=np.array([0.5, 0.1, 0.4]),
|
||||
valid_prev=np.array([0.45, 0.25, 0.2]),
|
||||
g_time=0.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dr_1(cr_1, cr_2):
|
||||
return DatasetReport("dr_1", [cr_1, cr_2])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dr_2(cr_3, cr_4):
|
||||
return DatasetReport("dr_2", [cr_3, cr_4])
|
||||
|
||||
|
||||
@pytest.mark.rep
|
||||
@pytest.mark.mrep
|
||||
class TestReport:
|
||||
@pytest.mark.parametrize(
|
||||
"cr_name,train_prev,shift",
|
||||
[
|
||||
(
|
||||
"cr_1",
|
||||
np.array([0.2, 0.8]),
|
||||
np.array([0.2, 0.1, 0.0, 0.0]),
|
||||
),
|
||||
(
|
||||
"cr_3",
|
||||
np.array([0.2, 0.5, 0.3]),
|
||||
np.array([0.2, 0.2, 0.0, 0.0, 0.1]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_shift(self, cr_name, train_prev, shift, request):
|
||||
cr = request.getfixturevalue(cr_name)
|
||||
assert (
|
||||
_get_shift(cr._data.index.get_level_values(0), train_prev) == shift
|
||||
).all()
|
||||
|
||||
|
||||
@pytest.mark.rep
|
||||
@pytest.mark.erep
|
||||
class TestEvaluationReport:
|
||||
def test_init(self, empty_er):
|
||||
assert empty_er.data is None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"rows,index,columns,data",
|
||||
[
|
||||
(
|
||||
[
|
||||
(np.array([0.2, 0.8]), dict(acc=0.9, acc_score=0.1)),
|
||||
(np.array([0.2, 0.8]), dict(acc=0.6, acc_score=0.4)),
|
||||
(np.array([0.3, 0.7]), dict(acc=0.7, acc_score=0.3)),
|
||||
],
|
||||
[((0.2, 0.8), 0), ((0.2, 0.8), 1), ((0.3, 0.7), 0)],
|
||||
["acc", "acc_score"],
|
||||
np.array([[0.9, 0.1], [0.6, 0.4], [0.7, 0.3]]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_append_row(self, empty_er, rows, index, columns, data):
|
||||
er: EvaluationReport = empty_er
|
||||
for prev, r in rows:
|
||||
er.append_row(prev, **r)
|
||||
assert er.data.index.to_list() == index
|
||||
assert er.data.columns.to_list() == columns
|
||||
assert (er.data.to_numpy() == data).all()
|
||||
|
||||
|
||||
@pytest.mark.rep
|
||||
@pytest.mark.crep
|
||||
class TestCompReport:
|
||||
@pytest.mark.parametrize(
|
||||
"train_prev,valid_prev,index,columns",
|
||||
[
|
||||
(
|
||||
np.array([0.2, 0.8]),
|
||||
np.array([0.25, 0.75]),
|
||||
[
|
||||
((0.4, 0.6), 0),
|
||||
((0.3, 0.7), 0),
|
||||
((0.2, 0.8), 0),
|
||||
((0.2, 0.8), 1),
|
||||
],
|
||||
[
|
||||
("acc", "er1"),
|
||||
("acc", "er2"),
|
||||
("acc_score", "er1"),
|
||||
("acc_score", "er2"),
|
||||
("f1", "er2"),
|
||||
("f1_score", "er2"),
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_init(self, er_list, train_prev, valid_prev, index, columns):
|
||||
cr = CompReport(er_list, "cr", train_prev, valid_prev, g_time=0.0)
|
||||
assert cr._data.index.to_list() == index
|
||||
assert cr._data.columns.to_list() == columns
|
||||
assert (cr.train_prev == train_prev).all()
|
||||
assert (cr.valid_prev == valid_prev).all()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cr_name,prev",
|
||||
[
|
||||
("cr_1", [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
|
||||
("cr_2", [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
|
||||
(
|
||||
"cr_3",
|
||||
[(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)],
|
||||
),
|
||||
(
|
||||
"cr_4",
|
||||
[(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_prevs(self, cr_name, prev, request):
|
||||
cr = request.getfixturevalue(cr_name)
|
||||
assert cr.prevs.tolist() == prev
|
||||
|
||||
def test_join(self, er_list, er_list2):
|
||||
tp = np.array([0.2, 0.8])
|
||||
vp = np.array([0.25, 0.75])
|
||||
cr1 = CompReport(er_list, "cr1", train_prev=tp, valid_prev=vp)
|
||||
cr2 = CompReport(er_list2, "cr2", train_prev=tp, valid_prev=vp)
|
||||
crj = cr1.join(cr2)
|
||||
_loc = crj._data.loc[((0.4, 0.6), 0), ("acc", "er2")].to_numpy()
|
||||
assert (_loc == np.array([0.8])).all()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cr_name,metric,estimators,columns",
|
||||
[
|
||||
("cr_1", "acc", None, ["er1", "er2"]),
|
||||
("cr_1", "acc", ["er1"], ["er1"]),
|
||||
("cr_1", "acc", ["er1", "er2"], ["er1", "er2"]),
|
||||
("cr_1", "f1", None, ["er2"]),
|
||||
("cr_1", "f1", ["er2"], ["er2"]),
|
||||
("cr_3", "acc", None, ["er31", "er32"]),
|
||||
("cr_3", "acc", ["er31"], ["er31"]),
|
||||
("cr_3", "acc", ["er31", "er32"], ["er31", "er32"]),
|
||||
("cr_3", "f1", None, ["er32"]),
|
||||
("cr_3", "f1", ["er32"], ["er32"]),
|
||||
],
|
||||
)
|
||||
def test_data(self, cr_name, metric, estimators, columns, request):
|
||||
cr = request.getfixturevalue(cr_name)
|
||||
_data = cr.data(metric=metric, estimators=estimators)
|
||||
assert _data.columns.to_list() == columns
|
||||
assert all(_data.index == cr._data.index)
|
||||
|
||||
# fmt: off
|
||||
@pytest.mark.parametrize(
|
||||
"cr_name,metric,estimators,columns,index",
|
||||
[
|
||||
("cr_1", "acc", None, ["er1", "er2"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]),
|
||||
("cr_1", "acc", ["er1"], ["er1"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]),
|
||||
("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]),
|
||||
("cr_1", "f1", None, ["er2"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]),
|
||||
("cr_1", "f1", ["er2"], ["er2"], [(0.0, 0), (0.0, 1), (0.1, 0), (0.2, 0)]),
|
||||
("cr_3", "acc", None, ["er31", "er32"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]),
|
||||
("cr_3", "acc", ["er31"], ["er31"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]),
|
||||
("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]),
|
||||
("cr_3", "f1", None, ["er32"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]),
|
||||
("cr_3", "f1", ["er32"], ["er32"], [(0.2, 0), (0.3, 0), (0.4, 0), (0.4, 1), (0.5,0)]),
|
||||
],
|
||||
)
|
||||
def test_shift_data(self, cr_name, metric, estimators, columns, index, request):
|
||||
cr = request.getfixturevalue(cr_name)
|
||||
_data = cr.shift_data(metric=metric, estimators=estimators)
|
||||
assert _data.columns.to_list() == columns
|
||||
assert _data.index.to_list() == index
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cr_name,metric,estimators,columns,index",
|
||||
[
|
||||
("cr_1", "acc", None, ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
|
||||
("cr_1", "acc", ["er1"], ["er1"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
|
||||
("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
|
||||
("cr_1", "f1", None, ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
|
||||
("cr_1", "f1", ["er2"], ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
|
||||
("cr_3", "acc", None, ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
|
||||
("cr_3", "acc", ["er31"], ["er31"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
|
||||
("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
|
||||
("cr_3", "f1", None, ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
|
||||
("cr_3", "f1", ["er32"], ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
|
||||
],
|
||||
)
|
||||
def test_avg_by_prevs(self, cr_name, metric, estimators, columns, index, request):
|
||||
cr = request.getfixturevalue(cr_name)
|
||||
_data = cr.avg_by_prevs(metric=metric, estimators=estimators)
|
||||
assert _data.columns.to_list() == columns
|
||||
assert _data.index.to_list() == index
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cr_name,metric,estimators,columns,index",
|
||||
[
|
||||
("cr_1", "acc", None, ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
|
||||
("cr_1", "acc", ["er1"], ["er1"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
|
||||
("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
|
||||
("cr_1", "f1", None, ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
|
||||
("cr_1", "f1", ["er2"], ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8)]),
|
||||
("cr_3", "acc", None, ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
|
||||
("cr_3", "acc", ["er31"], ["er31"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
|
||||
("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
|
||||
("cr_3", "f1", None, ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
|
||||
("cr_3", "f1", ["er32"], ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4)]),
|
||||
],
|
||||
)
|
||||
def test_stdev_by_prevs(self, cr_name, metric, estimators, columns, index, request):
|
||||
cr = request.getfixturevalue(cr_name)
|
||||
_data = cr.stdev_by_prevs(metric=metric, estimators=estimators)
|
||||
assert _data.columns.to_list() == columns
|
||||
assert _data.index.to_list() == index
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cr_name,metric,estimators,columns,index",
|
||||
[
|
||||
("cr_1", "acc", None, ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]),
|
||||
("cr_1", "acc", ["er1"], ["er1"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]),
|
||||
("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]),
|
||||
("cr_1", "f1", None, ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]),
|
||||
("cr_1", "f1", ["er2"], ["er2"], [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), "mean"]),
|
||||
("cr_3", "acc", None, ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]),
|
||||
("cr_3", "acc", ["er31"], ["er31"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]),
|
||||
("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]),
|
||||
("cr_3", "f1", None, ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]),
|
||||
("cr_3", "f1", ["er32"], ["er32"], [(0.3, 0.6, 0.1), (0.3, 0.3, 0.4), (0.2, 0.5, 0.3), (0.2, 0.4, 0.4), "mean"]),
|
||||
],
|
||||
)
|
||||
def test_train_table(self, cr_name, metric, estimators, columns, index, request):
|
||||
cr = request.getfixturevalue(cr_name)
|
||||
_data = cr.train_table(metric=metric, estimators=estimators)
|
||||
assert _data.columns.to_list() == columns
|
||||
assert _data.index.to_list() == index
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cr_name,metric,estimators,columns,index",
|
||||
[
|
||||
("cr_1", "acc", None, ["er1", "er2"], [0.0, 0.1, 0.2, "mean"]),
|
||||
("cr_1", "acc", ["er1"], ["er1"], [0.0, 0.1, 0.2, "mean"]),
|
||||
("cr_1", "acc", ["er1", "er2"], ["er1", "er2"], [0.0, 0.1, 0.2, "mean"]),
|
||||
("cr_1", "f1", None, ["er2"], [0.0, 0.1, 0.2, "mean"]),
|
||||
("cr_1", "f1", ["er2"], ["er2"], [0.0, 0.1, 0.2, "mean"]),
|
||||
("cr_3", "acc", None, ["er31", "er32"], [0.2, 0.3, 0.4, 0.5, "mean"]),
|
||||
("cr_3", "acc", ["er31"], ["er31"], [0.2, 0.3, 0.4, 0.5, "mean"]),
|
||||
("cr_3", "acc", ["er31", "er32"], ["er31", "er32"], [0.2, 0.3, 0.4, 0.5, "mean"]),
|
||||
("cr_3", "f1", None, ["er32"], [0.2, 0.3, 0.4, 0.5, "mean"]),
|
||||
("cr_3", "f1", ["er32"], ["er32"], [0.2, 0.3, 0.4, 0.5, "mean"]),
|
||||
],
|
||||
)
|
||||
def test_shift_table(self, cr_name, metric, estimators, columns, index, request):
|
||||
cr = request.getfixturevalue(cr_name)
|
||||
_data = cr.shift_table(metric=metric, estimators=estimators)
|
||||
assert _data.columns.to_list() == columns
|
||||
assert _data.index.to_list() == index
|
||||
# fmt: on
|
||||
|
||||
|
||||
@pytest.mark.rep
|
||||
@pytest.mark.drep
|
||||
class TestDatasetReport:
|
||||
# fmt: off
|
||||
@pytest.mark.parametrize(
|
||||
"dr_name,metric,estimators,columns,index",
|
||||
[
|
||||
(
|
||||
"dr_1", "acc", None, ["er1", "er2", "er12"],
|
||||
[
|
||||
((0.3, 0.7), (0.4, 0.6), 0),
|
||||
((0.3, 0.7), (0.3, 0.7), 0),
|
||||
((0.3, 0.7), (0.2, 0.8), 0),
|
||||
((0.3, 0.7), (0.2, 0.8), 1),
|
||||
((0.2, 0.8), (0.4, 0.6), 0),
|
||||
((0.2, 0.8), (0.3, 0.7), 0),
|
||||
((0.2, 0.8), (0.2, 0.8), 0),
|
||||
((0.2, 0.8), (0.2, 0.8), 1),
|
||||
],
|
||||
),
|
||||
(
|
||||
"dr_2", "acc", None, ["er31", "er32"],
|
||||
[
|
||||
((0.5, 0.1, 0.4), (0.3, 0.6, 0.1), 0),
|
||||
((0.5, 0.1, 0.4), (0.3, 0.3, 0.4), 0),
|
||||
((0.5, 0.1, 0.4), (0.2, 0.5, 0.3), 0),
|
||||
((0.5, 0.1, 0.4), (0.2, 0.5, 0.3), 1),
|
||||
((0.5, 0.1, 0.4), (0.2, 0.4, 0.4), 0),
|
||||
((0.4, 0.1, 0.5), (0.3, 0.6, 0.1), 0),
|
||||
((0.4, 0.1, 0.5), (0.3, 0.3, 0.4), 0),
|
||||
((0.4, 0.1, 0.5), (0.2, 0.5, 0.3), 0),
|
||||
((0.4, 0.1, 0.5), (0.2, 0.5, 0.3), 1),
|
||||
((0.4, 0.1, 0.5), (0.2, 0.4, 0.4), 0),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_data(self, dr_name, metric, estimators, columns, index, request):
|
||||
dr = request.getfixturevalue(dr_name)
|
||||
_data = dr.data(metric=metric, estimators=estimators)
|
||||
assert _data.columns.to_list() == columns
|
||||
assert _data.index.to_list() == index
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dr_name,metric,estimators,columns,index",
|
||||
[
|
||||
(
|
||||
"dr_1", "acc", None, ["er1", "er2", "er12"],
|
||||
[(0.0, 0), (0.0, 1), (0.0, 2), (0.1, 0),
|
||||
(0.1, 1), (0.1, 2), (0.1, 3), (0.2, 0)],
|
||||
),
|
||||
(
|
||||
"dr_2", "acc", None, ["er31", "er32"],
|
||||
[(0.2, 0), (0.2, 1), (0.3, 0), (0.3, 1), (0.4, 0),
|
||||
(0.4, 1), (0.4, 2), (0.4, 3), (0.5, 0), (0.5, 1)],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_shift_data(self, dr_name, metric, estimators, columns, index, request):
|
||||
dr = request.getfixturevalue(dr_name)
|
||||
_data = dr.shift_data(metric=metric, estimators=estimators)
|
||||
print(_data.index.tolist())
|
||||
assert _data.columns.to_list() == columns
|
||||
assert _data.index.to_list() == index
|
||||
# fmt: off
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,100 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import scipy.sparse as sp
|
||||
|
||||
from quacc.data import ExtendedData, ExtensionPolicy
|
||||
from quacc.method.base import MultiClassAccuracyEstimator
|
||||
|
||||
|
||||
@pytest.mark.mcae
|
||||
class TestMultiClassAccuracyEstimator:
|
||||
@pytest.mark.parametrize(
|
||||
"instances,pred_proba,extpol,result",
|
||||
[
|
||||
(
|
||||
np.arange(12).reshape((4, 3)),
|
||||
np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]),
|
||||
ExtensionPolicy(),
|
||||
np.array([0.21, 0.39, 0.1, 0.4]),
|
||||
),
|
||||
(
|
||||
np.arange(12).reshape((4, 3)),
|
||||
np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]),
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
np.array([0.21, 0.39, 0.5]),
|
||||
),
|
||||
(
|
||||
sp.csr_matrix(np.arange(12).reshape((4, 3))),
|
||||
np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]),
|
||||
ExtensionPolicy(),
|
||||
np.array([0.21, 0.39, 0.1, 0.4]),
|
||||
),
|
||||
(
|
||||
np.arange(12).reshape((4, 3)),
|
||||
np.array(
|
||||
[
|
||||
[0.3, 0.2, 0.5],
|
||||
[0.13, 0.67, 0.2],
|
||||
[0.21, 0.09, 0.8],
|
||||
[0.19, 0.1, 0.71],
|
||||
]
|
||||
),
|
||||
ExtensionPolicy(),
|
||||
np.array([0.21, 0.09, 0.1, 0.04, 0.06, 0.11, 0.11, 0.18, 0.1]),
|
||||
),
|
||||
(
|
||||
np.arange(12).reshape((4, 3)),
|
||||
np.array(
|
||||
[
|
||||
[0.3, 0.2, 0.5],
|
||||
[0.13, 0.67, 0.2],
|
||||
[0.21, 0.09, 0.8],
|
||||
[0.19, 0.1, 0.71],
|
||||
]
|
||||
),
|
||||
ExtensionPolicy(collapse_false=True),
|
||||
np.array([0.21, 0.09, 0.1, 0.7]),
|
||||
),
|
||||
(
|
||||
sp.csr_matrix(np.arange(12).reshape((4, 3))),
|
||||
np.array(
|
||||
[
|
||||
[0.3, 0.2, 0.5],
|
||||
[0.13, 0.67, 0.2],
|
||||
[0.21, 0.09, 0.8],
|
||||
[0.19, 0.1, 0.71],
|
||||
]
|
||||
),
|
||||
ExtensionPolicy(),
|
||||
np.array([0.21, 0.09, 0.1, 0.04, 0.06, 0.11, 0.11, 0.18, 0.1]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_estimate(self, monkeypatch, instances, pred_proba, extpol, result):
|
||||
ed = ExtendedData(instances, pred_proba, pred_proba, extpol)
|
||||
|
||||
class MockQuantifier:
|
||||
def __init__(self):
|
||||
self.classes_ = np.arange(result.shape[0])
|
||||
|
||||
def quantify(self, X):
|
||||
return result
|
||||
|
||||
def mockinit(self):
|
||||
self.extpol = extpol
|
||||
self.quantifier = MockQuantifier()
|
||||
|
||||
def mock_extend_instances(self, instances):
|
||||
return ed
|
||||
|
||||
monkeypatch.setattr(MultiClassAccuracyEstimator, "__init__", mockinit)
|
||||
monkeypatch.setattr(
|
||||
MultiClassAccuracyEstimator, "_extend_instances", mock_extend_instances
|
||||
)
|
||||
mcae = MultiClassAccuracyEstimator()
|
||||
|
||||
ep1 = mcae.estimate(instances)
|
||||
ep2 = mcae.estimate(ed)
|
||||
|
||||
assert (ep1.flat == ep2.flat).all()
|
||||
assert (ep1.flat == result).all()
|
Binary file not shown.
Binary file not shown.
|
@ -1,66 +0,0 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import scipy.sparse as sp
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
from quacc.method.base import BinaryQuantifierAccuracyEstimator
|
||||
|
||||
|
||||
class TestBQAE:
|
||||
@pytest.mark.parametrize(
|
||||
"instances,preds0,preds1,result",
|
||||
[
|
||||
(
|
||||
np.asarray(
|
||||
[[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
|
||||
),
|
||||
np.asarray([0.3, 0.7]),
|
||||
np.asarray([0.4, 0.6]),
|
||||
np.asarray([0.15, 0.2, 0.35, 0.3]),
|
||||
),
|
||||
(
|
||||
sp.csr_matrix(
|
||||
[[0, 0.3, 0.7], [1, 0.54, 0.46], [2, 0.28, 0.72], [3, 0.6, 0.4]]
|
||||
),
|
||||
np.asarray([0.3, 0.7]),
|
||||
np.asarray([0.4, 0.6]),
|
||||
np.asarray([0.15, 0.2, 0.35, 0.3]),
|
||||
),
|
||||
(
|
||||
np.asarray([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
||||
np.asarray([0.3, 0.7]),
|
||||
np.asarray([0.4, 0.6]),
|
||||
np.asarray([0.0, 0.4, 0.0, 0.6]),
|
||||
),
|
||||
(
|
||||
sp.csr_matrix([[0, 0.3, 0.7], [2, 0.28, 0.72]]),
|
||||
np.asarray([0.3, 0.7]),
|
||||
np.asarray([0.4, 0.6]),
|
||||
np.asarray([0.0, 0.4, 0.0, 0.6]),
|
||||
),
|
||||
(
|
||||
np.asarray([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
||||
np.asarray([0.3, 0.7]),
|
||||
np.asarray([0.4, 0.6]),
|
||||
np.asarray([0.3, 0.0, 0.7, 0.0]),
|
||||
),
|
||||
(
|
||||
sp.csr_matrix([[1, 0.54, 0.46], [3, 0.6, 0.4]]),
|
||||
np.asarray([0.3, 0.7]),
|
||||
np.asarray([0.4, 0.6]),
|
||||
np.asarray([0.3, 0.0, 0.7, 0.0]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_estimate_ndarray(self, mocker, instances, preds0, preds1, result):
|
||||
estimator = BinaryQuantifierAccuracyEstimator(LogisticRegression())
|
||||
estimator.n_classes = 4
|
||||
with mocker.patch.object(estimator.q_model_0, "quantify"), mocker.patch.object(
|
||||
estimator.q_model_1, "quantify"
|
||||
):
|
||||
estimator.q_model_0.quantify.return_value = preds0
|
||||
estimator.q_model_1.quantify.return_value = preds1
|
||||
assert np.array_equal(
|
||||
estimator.estimate(instances, ext=True),
|
||||
result,
|
||||
)
|
|
@ -1,2 +0,0 @@
|
|||
class TestMCAE:
|
||||
pass
|
Loading…
Reference in New Issue