# QuAcc/tests/test_error.py
#
# 96 lines
# 3.1 KiB
# Python

import numpy as np
import pytest
from quacc import error
from quacc.data import ExtendedPrev, ExtensionPolicy
@pytest.mark.err
class TestError:
    """Tests for ``error.f1``, ``error.acc`` and ``error.accd``.

    ``f1`` and ``acc`` are exercised both on a plain square ``numpy`` array
    and on an ``ExtendedPrev`` built from the same, flattened, values; both
    inputs are expected to yield the same score.  ``accd`` is exercised on
    lists of ``ExtendedPrev`` objects built with different class counts and
    ``ExtensionPolicy`` settings.

    All numeric comparisons share one absolute tolerance (``_ATOL``) so that
    floating-point noise is treated the same way in every test, and so that a
    failure reports the offending values instead of a bare ``assert False``.
    """

    # Absolute tolerance applied to every float comparison in this class.
    # rtol is always forced to 0 so this is the only slack allowed.
    _ATOL = 1e-15

    @pytest.mark.parametrize(
        "prev,result",
        [
            (np.array([[1, 4], [4, 4]]), 0.5),
            (np.array([[6, 2, 4], [2, 4, 2], [4, 2, 6]]), 0.5),
        ],
    )
    def test_f1(self, prev, result):
        """Check ``error.f1`` on a raw matrix and on its ``ExtendedPrev`` form.

        Args:
            prev: square integer array passed directly to ``error.f1`` and
                also flattened into an ``ExtendedPrev``.
            result: expected score for both representations.
        """
        ep = ExtendedPrev(prev.flatten(), prev.shape[0], extpol=ExtensionPolicy())
        np.testing.assert_allclose(error.f1(prev), result, rtol=0, atol=self._ATOL)
        np.testing.assert_allclose(error.f1(ep), result, rtol=0, atol=self._ATOL)

    @pytest.mark.parametrize(
        "prev,result",
        [
            (np.array([[4, 4], [4, 4]]), 0.5),
            (np.array([[2, 4, 2], [2, 2, 4], [4, 2, 2]]), 0.25),
        ],
    )
    def test_acc(self, prev, result):
        """Check ``error.acc`` on a raw matrix and on its ``ExtendedPrev`` form.

        Args:
            prev: square integer array passed directly to ``error.acc`` and
                also flattened into an ``ExtendedPrev``.
            result: expected score for both representations.
        """
        ep = ExtendedPrev(prev.flatten(), prev.shape[0], extpol=ExtensionPolicy())
        np.testing.assert_allclose(error.acc(prev), result, rtol=0, atol=self._ATOL)
        np.testing.assert_allclose(error.acc(ep), result, rtol=0, atol=self._ATOL)

    @pytest.mark.parametrize(
        "true_prev,estim_prev,nbcl,extpol,result",
        [
            (
                [
                    np.array([0.2, 0.4, 0.1, 0.3]),
                    np.array([0.1, 0.5, 0.1, 0.3]),
                ],
                [
                    np.array([0.3, 0.4, 0.2, 0.1]),
                    np.array([0.5, 0.3, 0.1, 0.1]),
                ],
                2,
                ExtensionPolicy(),
                np.array([0.1, 0.2]),
            ),
            (
                [
                    np.array([0.2, 0.4, 0.4]),
                    np.array([0.1, 0.5, 0.4]),
                ],
                [
                    np.array([0.3, 0.4, 0.3]),
                    np.array([0.5, 0.3, 0.2]),
                ],
                2,
                ExtensionPolicy(collapse_false=True),
                np.array([0.1, 0.2]),
            ),
            (
                [
                    np.array([0.02, 0.04, 0.16, 0.38, 0.1, 0.05, 0.15, 0.08, 0.02]),
                    np.array([0.04, 0.02, 0.14, 0.40, 0.1, 0.03, 0.17, 0.07, 0.03]),
                ],
                [
                    np.array([0.02, 0.04, 0.16, 0.48, 0.0, 0.05, 0.15, 0.08, 0.02]),
                    np.array([0.14, 0.02, 0.04, 0.30, 0.2, 0.03, 0.17, 0.07, 0.03]),
                ],
                3,
                ExtensionPolicy(),
                np.array([0.1, 0.2]),
            ),
            (
                [
                    np.array([0.2, 0.4, 0.2, 0.2]),
                    np.array([0.1, 0.3, 0.2, 0.4]),
                ],
                [
                    np.array([0.3, 0.3, 0.1, 0.3]),
                    np.array([0.5, 0.2, 0.1, 0.2]),
                ],
                3,
                ExtensionPolicy(collapse_false=True),
                np.array([0.1, 0.2]),
            ),
        ],
    )
    def test_accd(self, true_prev, estim_prev, nbcl, extpol, result):
        """Check ``error.accd`` element-wise against the expected values.

        Args:
            true_prev: list of flat arrays wrapped as the "true" prevalences.
            estim_prev: list of flat arrays wrapped as the "estimated"
                prevalences, paired positionally with ``true_prev``.
            nbcl: class count handed to ``ExtendedPrev``.
            extpol: ``ExtensionPolicy`` handed to ``ExtendedPrev``.
            result: expected output of ``error.accd``, one entry per pair.
        """
        true_ext = [ExtendedPrev(tp, nbcl, extpol=extpol) for tp in true_prev]
        estim_ext = [ExtendedPrev(ep, nbcl, extpol=extpol) for ep in estim_prev]
        _err = error.accd(true_ext, estim_ext)
        # Same 1e-15 absolute bound as before, but with a diagnostic message
        # listing the mismatching elements when the check fails.
        np.testing.assert_allclose(_err, result, rtol=0, atol=self._ATOL)