dealing with unit tests

Alejandro Moreo Fernandez 2025-04-25 13:52:05 +02:00
parent 960ca5076e
commit aac133817b
3 changed files with 16 additions and 8 deletions

View File

@@ -67,8 +67,14 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
             assert val_split > 1, \
                 (f'when {val_split=} is indicated as an integer, it represents the number of folds in a kFCV '
                  f'and must thus be >1')
-            assert fit_classifier, (f'when {val_split=} is indicated as an integer (the number of folds for kFCV) '
-                                    f'the parameter {fit_classifier=} must be True')
+            if val_split==5 and not fit_classifier:
+                print(f'Warning: {val_split=} will be ignored when the classifier is already trained '
+                      f'({fit_classifier=}). Parameter {self.val_split=} will be set to None. Set {val_split=} '
+                      f'to None to avoid this warning.')
+                self.val_split=None
+            if val_split!=5:
+                assert fit_classifier, (f'Parameter {val_split=} has been modified, but {fit_classifier=} '
+                                        f'indicates the classifier should not be retrained.')
         elif isinstance(val_split, float):
             assert 0 < val_split < 1, \
                 (f'when {val_split=} is indicated as a float, it represents the fraction of training instances '
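
This hunk relaxes the former hard assertion (an integer val_split demanded fit_classifier=True) into a two-case rule: the default value 5 combined with an already-trained classifier is downgraded to None with a warning, whereas an explicitly modified value still requires retraining. A standalone sketch of that rule, using a hypothetical check_val_split helper (not QuaPy API) and assuming 5 is the default:

    def check_val_split(val_split, fit_classifier, default=5):
        # returns the (possibly downgraded) val_split, mirroring the branches above
        assert val_split > 1, 'an integer val_split is the number of kFCV folds and must be >1'
        if val_split == default and not fit_classifier:
            print('Warning: the default val_split is ignored for a pre-trained classifier; using None')
            return None
        if val_split != default:
            assert fit_classifier, 'val_split was modified, but fit_classifier=False forbids retraining'
        return val_split

    assert check_val_split(5, fit_classifier=False) is None   # default downgraded to None, with a warning
    assert check_val_split(3, fit_classifier=True) == 3       # explicit value kept; retraining allowed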
@@ -174,7 +180,9 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
         elif self.val_split is None:
             if self.fit_classifier:
                 self.classifier.fit(X, y)
-            predictions, labels = None, None
+                predictions, labels = None, None
+            else:
+                predictions, labels = self.classify(X), y
         else:
             raise ValueError(f'unexpected type for {self.val_split=}')
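
This hunk fixes the pre-trained-classifier path: previously, (predictions, labels) ended up as (None, None) whenever val_split was None, leaving the aggregation stage nothing to fit on when the classifier arrived already trained; now the else branch reuses the trained classifier on the training data. A hedged end-to-end sketch of the fixed path, taking ACC as a representative aggregative method and assuming it exposes the fit_classifier and val_split constructor parameters shown in the hunks above:

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from quapy.method.aggregative import ACC

    X, y = make_classification(n_samples=500, random_state=0)
    clf = LogisticRegression().fit(X, y)                # classifier trained beforehand
    q = ACC(clf, fit_classifier=False, val_split=None)  # do not retrain
    q.fit(X, y)                                         # takes the val_split-is-None branch above
    print(q.predict(X))                                 # estimated class prevalences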

View File

@@ -17,7 +17,7 @@ class TestDatasets(unittest.TestCase):
     def _check_dataset(self, dataset):
         q = self.new_quantifier()
         print(f'testing method {q} in {dataset.name}...', end='')
-        q.fit(dataset.training)
+        q.fit(*dataset.training.Xy)
         estim_prevalences = q.predict(dataset.test.instances)
         self.assertTrue(F.check_prevalence_vector(estim_prevalences))
         print(f'[done]')
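
The dataset tests adapt to the companion signature change: fit() now takes a feature matrix and a label vector rather than a LabelledCollection, so the collection is unpacked through its Xy property (the same one-line change recurs in the two hunks below). The starred call is equivalent to unpacking by hand:

    X, y = dataset.training.Xy   # Xy yields the (instances, labels) pair
    q.fit(X, y)                  # replaces the former q.fit(dataset.training)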
@@ -89,7 +89,7 @@ class TestDatasets(unittest.TestCase):
         n_classes = train.n_classes
         train = train.sampling(100, *F.uniform_prevalence(n_classes))
         q = self.new_quantifier()
-        q.fit(train)
+        q.fit(*train.Xy)
         self._check_samples(gen_val, q, max_samples_test=5)
         self._check_samples(gen_test, q, max_samples_test=5)
@@ -102,7 +102,7 @@ class TestDatasets(unittest.TestCase):
         tfidf = TfidfVectorizer()
         train.instances = tfidf.fit_transform(train.instances)
         q = self.new_quantifier()
-        q.fit(train)
+        q.fit(*train.Xy)
         self._check_samples(gen_val, q, max_samples_test=5, vectorizer=tfidf)
         self._check_samples(gen_test, q, max_samples_test=5, vectorizer=tfidf)

View File

@@ -48,9 +48,9 @@ class TestMethods(unittest.TestCase):
                 print(f'skipping the test of binary model {model.__name__} on multiclass dataset {dataset.name}')
                 continue
-            q = model(learner)
+            q = model(learner, fit_classifier=False)
             print('testing', q)
-            q.fit(dataset.training, fit_classifier=False)
+            q.fit(*dataset.training.Xy)
             estim_prevalences = q.predict(dataset.test.X)
             self.assertTrue(check_prevalence_vector(estim_prevalences))
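
The method tests likewise reflect the relocated fit_classifier flag: it is now fixed at construction time instead of being passed to fit(), which receives the unpacked (X, y) pair. A hedged before/after sketch, with model and learner as in the loop above:

    # old API (before this commit):
    #   q = model(learner)
    #   q.fit(dataset.training, fit_classifier=False)
    # new API (after this commit):
    q = model(learner, fit_classifier=False)   # flag fixed at construction
    q.fit(*dataset.training.Xy)                # fit() takes features and labels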