"""Tests for the metric functions in ``quacc.error`` (``f1``, ``acc``, ``accd``)."""

import numpy as np
import pytest

from quacc import error
from quacc.data import ExtendedPrev, ExtensionPolicy


def _to_extended(prev):
    """Wrap a square ndarray as an ``ExtendedPrev`` with a default ``ExtensionPolicy``.

    The matrix is flattened and its row count is passed as the second
    (class-count) argument, mirroring how the tests build the wrapped form.
    """
    return ExtendedPrev(prev.flatten(), prev.shape[0], extpol=ExtensionPolicy())


@pytest.mark.err
class TestError:
    """Check ``error.f1``, ``error.acc`` and ``error.accd`` against known values.

    ``f1`` and ``acc`` are exercised both on a raw ndarray and on the same
    data wrapped in an ``ExtendedPrev``; both forms must give the same result.
    """

    @pytest.mark.parametrize(
        "prev,result",
        [
            (np.array([[1, 4], [4, 4]]), 0.5),
            (np.array([[6, 2, 4], [2, 4, 2], [4, 2, 6]]), 0.5),
        ],
    )
    def test_f1(self, prev, result):
        """``error.f1`` gives the expected score for ndarray and ExtendedPrev input."""
        ep = _to_extended(prev)
        # approx rather than ==: metric results are floats and should not
        # be compared bit-for-bit.
        assert error.f1(prev) == pytest.approx(result)
        assert error.f1(ep) == pytest.approx(result)

    @pytest.mark.parametrize(
        "prev,result",
        [
            (np.array([[4, 4], [4, 4]]), 0.5),
            (np.array([[2, 4, 2], [2, 2, 4], [4, 2, 2]]), 0.25),
        ],
    )
    def test_acc(self, prev, result):
        """``error.acc`` gives the expected score for ndarray and ExtendedPrev input."""
        ep = _to_extended(prev)
        assert error.acc(prev) == pytest.approx(result)
        assert error.acc(ep) == pytest.approx(result)

    @pytest.mark.parametrize(
        "true_prev,estim_prev,nbcl,extpol,result",
        [
            (
                [
                    np.array([0.2, 0.4, 0.1, 0.3]),
                    np.array([0.1, 0.5, 0.1, 0.3]),
                ],
                [
                    np.array([0.3, 0.4, 0.2, 0.1]),
                    np.array([0.5, 0.3, 0.1, 0.1]),
                ],
                2,
                ExtensionPolicy(),
                np.array([0.1, 0.2]),
            ),
            (
                [
                    np.array([0.2, 0.4, 0.4]),
                    np.array([0.1, 0.5, 0.4]),
                ],
                [
                    np.array([0.3, 0.4, 0.3]),
                    np.array([0.5, 0.3, 0.2]),
                ],
                2,
                ExtensionPolicy(collapse_false=True),
                np.array([0.1, 0.2]),
            ),
            (
                [
                    np.array([0.02, 0.04, 0.16, 0.38, 0.1, 0.05, 0.15, 0.08, 0.02]),
                    np.array([0.04, 0.02, 0.14, 0.40, 0.1, 0.03, 0.17, 0.07, 0.03]),
                ],
                [
                    np.array([0.02, 0.04, 0.16, 0.48, 0.0, 0.05, 0.15, 0.08, 0.02]),
                    np.array([0.14, 0.02, 0.04, 0.30, 0.2, 0.03, 0.17, 0.07, 0.03]),
                ],
                3,
                ExtensionPolicy(),
                np.array([0.1, 0.2]),
            ),
            (
                [
                    np.array([0.2, 0.4, 0.2, 0.2]),
                    np.array([0.1, 0.3, 0.2, 0.4]),
                ],
                [
                    np.array([0.3, 0.3, 0.1, 0.3]),
                    np.array([0.5, 0.2, 0.1, 0.2]),
                ],
                3,
                ExtensionPolicy(collapse_false=True),
                np.array([0.1, 0.2]),
            ),
        ],
    )
    def test_accd(self, true_prev, estim_prev, nbcl, extpol, result):
        """``error.accd`` over lists of true/estimated prevalences matches ``result``.

        Each case supplies two true and two estimated flat prevalence vectors
        and two expected values, presumably one per (true, estimated) pair.
        """
        true_prev = [ExtendedPrev(tp, nbcl, extpol=extpol) for tp in true_prev]
        estim_prev = [ExtendedPrev(ep, nbcl, extpol=extpol) for ep in estim_prev]
        _err = error.accd(true_prev, estim_prev)
        # Same absolute tolerance as the original hand-rolled check. Unlike
        # that check, assert_allclose also rejects a shape mismatch and
        # reports the offending elements on failure.
        np.testing.assert_allclose(_err, result, rtol=0, atol=1e-15)