import numpy as np import pytest import scipy.sparse as sp from quacc.data import ExtendedData, ExtensionPolicy from quacc.method.base import MultiClassAccuracyEstimator @pytest.mark.mcae class TestMultiClassAccuracyEstimator: @pytest.mark.parametrize( "instances,pred_proba,extpol,result", [ ( np.arange(12).reshape((4, 3)), np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]), ExtensionPolicy(), np.array([0.21, 0.39, 0.1, 0.4]), ), ( np.arange(12).reshape((4, 3)), np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]), ExtensionPolicy(collapse_false=True), np.array([0.21, 0.39, 0.5]), ), ( sp.csr_matrix(np.arange(12).reshape((4, 3))), np.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8], [0.9, 0.1]]), ExtensionPolicy(), np.array([0.21, 0.39, 0.1, 0.4]), ), ( np.arange(12).reshape((4, 3)), np.array( [ [0.3, 0.2, 0.5], [0.13, 0.67, 0.2], [0.21, 0.09, 0.8], [0.19, 0.1, 0.71], ] ), ExtensionPolicy(), np.array([0.21, 0.09, 0.1, 0.04, 0.06, 0.11, 0.11, 0.18, 0.1]), ), ( np.arange(12).reshape((4, 3)), np.array( [ [0.3, 0.2, 0.5], [0.13, 0.67, 0.2], [0.21, 0.09, 0.8], [0.19, 0.1, 0.71], ] ), ExtensionPolicy(collapse_false=True), np.array([0.21, 0.09, 0.1, 0.7]), ), ( sp.csr_matrix(np.arange(12).reshape((4, 3))), np.array( [ [0.3, 0.2, 0.5], [0.13, 0.67, 0.2], [0.21, 0.09, 0.8], [0.19, 0.1, 0.71], ] ), ExtensionPolicy(), np.array([0.21, 0.09, 0.1, 0.04, 0.06, 0.11, 0.11, 0.18, 0.1]), ), ], ) def test_estimate(self, monkeypatch, instances, pred_proba, extpol, result): ed = ExtendedData(instances, pred_proba, pred_proba, extpol) class MockQuantifier: def __init__(self): self.classes_ = np.arange(result.shape[0]) def quantify(self, X): return result def mockinit(self): self.extpol = extpol self.quantifier = MockQuantifier() def mock_extend_instances(self, instances): return ed monkeypatch.setattr(MultiClassAccuracyEstimator, "__init__", mockinit) monkeypatch.setattr( MultiClassAccuracyEstimator, "_extend_instances", mock_extend_instances ) mcae = MultiClassAccuracyEstimator() ep1 = mcae.estimate(instances) ep2 = mcae.estimate(ed) assert (ep1.flat == ep2.flat).all() assert (ep1.flat == result).all()