# QuAcc/tests/test_method/test_base.py
#
# (source listing: 101 lines, 3.3 KiB, Python)

import numpy as np
import pytest
import scipy.sparse as sp
from quacc.data import ExtendedData, ExtensionPolicy
from quacc.method.base import MultiClassAccuracyEstimator
@pytest.mark.mcae
class TestMultiClassAccuracyEstimator:
    """Tests for ``MultiClassAccuracyEstimator.estimate``."""

    # NOTE(review): several expected vectors below sum to 1.1 rather than 1.
    # This is harmless here because the stub quantifier returns them verbatim,
    # but confirm whether they were meant to be normalized prevalences.
    @pytest.mark.parametrize(
        "instances,pred_proba,extpol,result",
        [
            # binary posteriors, dense instances, default policy
            (
                np.arange(12).reshape((4, 3)),
                np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]),
                ExtensionPolicy(),
                np.array([0.21, 0.39, 0.1, 0.4]),
            ),
            # binary posteriors, dense instances, collapse_false=True
            (
                np.arange(12).reshape((4, 3)),
                np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]),
                ExtensionPolicy(collapse_false=True),
                np.array([0.21, 0.39, 0.5]),
            ),
            # binary posteriors, sparse (CSR) instances, default policy
            (
                sp.csr_matrix(np.arange(12).reshape((4, 3))),
                np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]),
                ExtensionPolicy(),
                np.array([0.21, 0.39, 0.1, 0.4]),
            ),
            # three-class posteriors, dense instances, default policy
            (
                np.arange(12).reshape((4, 3)),
                np.array(
                    [
                        [0.3, 0.2, 0.5],
                        [0.13, 0.67, 0.2],
                        [0.21, 0.09, 0.8],
                        [0.19, 0.1, 0.71],
                    ]
                ),
                ExtensionPolicy(),
                np.array([0.21, 0.09, 0.1, 0.04, 0.06, 0.11, 0.11, 0.18, 0.1]),
            ),
            # three-class posteriors, dense instances, collapse_false=True
            (
                np.arange(12).reshape((4, 3)),
                np.array(
                    [
                        [0.3, 0.2, 0.5],
                        [0.13, 0.67, 0.2],
                        [0.21, 0.09, 0.8],
                        [0.19, 0.1, 0.71],
                    ]
                ),
                ExtensionPolicy(collapse_false=True),
                np.array([0.21, 0.09, 0.1, 0.7]),
            ),
            # three-class posteriors, sparse (CSR) instances, default policy
            (
                sp.csr_matrix(np.arange(12).reshape((4, 3))),
                np.array(
                    [
                        [0.3, 0.2, 0.5],
                        [0.13, 0.67, 0.2],
                        [0.21, 0.09, 0.8],
                        [0.19, 0.1, 0.71],
                    ]
                ),
                ExtensionPolicy(),
                np.array([0.21, 0.09, 0.1, 0.04, 0.06, 0.11, 0.11, 0.18, 0.1]),
            ),
        ],
    )
    def test_estimate(self, monkeypatch, instances, pred_proba, extpol, result):
        """Check that ``estimate`` yields the quantifier output for raw or extended input.

        ``__init__`` and ``_extend_instances`` are monkeypatched so the test
        isolates ``estimate``. The quantifier is a stub whose ``quantify``
        returns ``result`` unchanged. ``estimate`` is called once with the raw
        ``instances`` and once with a prebuilt ``ExtendedData``. Both flattened
        estimates must match each other and ``result``.
        """
        ed = ExtendedData(instances, pred_proba, pred_proba, extpol)

        class MockQuantifier:
            # Stub quantifier: one class label per entry of ``result``, and
            # ``quantify`` ignores its input and returns ``result``.
            def __init__(self):
                self.classes_ = np.arange(result.shape[0])

            def quantify(self, X):
                return result

        def mockinit(self):
            # Replacement constructor: sets only ``extpol`` and ``quantifier``,
            # presumably the attributes ``estimate`` reads — confirm if the
            # estimator grows new state.
            self.extpol = extpol
            self.quantifier = MockQuantifier()

        def mock_extend_instances(self, instances):
            # Whatever instances are passed, return the precomputed ``ed``.
            return ed

        monkeypatch.setattr(MultiClassAccuracyEstimator, "__init__", mockinit)
        monkeypatch.setattr(
            MultiClassAccuracyEstimator, "_extend_instances", mock_extend_instances
        )

        mcae = MultiClassAccuracyEstimator()
        ep1 = mcae.estimate(instances)
        ep2 = mcae.estimate(ed)

        # np.testing helpers reject shape mismatches and report the offending
        # entries on failure, unlike ``(a == b).all()``.
        np.testing.assert_array_equal(np.asarray(ep1.flat), np.asarray(ep2.flat))
        np.testing.assert_allclose(np.asarray(ep1.flat), result)