dealing with unit tests
parent 960ca5076e
commit aac133817b
@@ -67,8 +67,14 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
             assert val_split > 1, \
                 (f'when {val_split=} is indicated as an integer, it represents the number of folds in a kFCV '
                  f'and must thus be >1')
-            assert fit_classifier, (f'when {val_split=} is indicated as an integer (the number of folds for kFCV) '
-                                    f'the parameter {fit_classifier=} must be True')
+            if val_split==5 and not fit_classifier:
+                print(f'Warning: {val_split=} will be ignored when the classifier is already trained '
+                      f'({fit_classifier=}). Parameter {self.val_split=} will be set to None. Set {val_split=} '
+                      f'to None to avoid this warning.')
+                self.val_split=None
+            if val_split!=5:
+                assert fit_classifier, (f'Parameter {val_split=} has been modified, but {fit_classifier=} '
+                                        f'indicates the classifier should not be retrained.')
         elif isinstance(val_split, float):
             assert 0 < val_split < 1, \
                 (f'when {val_split=} is indicated as a float, it represents the fraction of training instances '
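For reference, a minimal standalone sketch of the validation rules this hunk introduces, assuming (as the special-casing of val_split==5 suggests) that 5 is the default fold count in the quantifier's signature; the function name is illustrative, not QuaPy's actual code:

def check_val_split(val_split=5, fit_classifier=True):
    if isinstance(val_split, int):
        # an integer val_split denotes the number of folds in a kFCV
        assert val_split > 1, 'an integer val_split denotes kFCV folds and must be >1'
        if val_split == 5 and not fit_classifier:
            # the default fold count is silently dropped for pre-trained classifiers
            print('Warning: val_split ignored; classifier already trained')
            return None
        if val_split != 5:
            # a non-default fold count only makes sense if the classifier is retrained
            assert fit_classifier, 'custom val_split requires fit_classifier=True'
    elif isinstance(val_split, float):
        # a float val_split is a held-out fraction of the training instances
        assert 0 < val_split < 1, 'a float val_split must lie in (0,1)'
    return val_split

check_val_split(val_split=10, fit_classifier=True)   # ok: 10-fold CV, retraining
check_val_split(val_split=5, fit_classifier=False)   # warns, returns None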
@@ -174,7 +180,9 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
         elif self.val_split is None:
             if self.fit_classifier:
                 self.classifier.fit(X, y)
-            predictions, labels = None, None
+                predictions, labels = None, None
+            else:
+                predictions, labels = self.classify(X), y
         else:
             raise ValueError(f'unexpected type for {self.val_split=}')
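The effect of the new branch, condensed into a runnable sketch with a scikit-learn stand-in (assuming classify() in the real class amounts to classifier.predict): when the classifier is trained here, nothing is precomputed for the aggregation step; when it arrives pre-trained, the training data is classified and reused.

import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
y = rng.integers(0, 2, size=100)

# pre-trained externally, i.e. the fit_classifier=False scenario
classifier = LogisticRegression().fit(X, y)

fit_classifier = False
if fit_classifier:
    classifier.fit(X, y)
    predictions, labels = None, None                 # nothing precomputed for aggregation
else:
    predictions, labels = classifier.predict(X), y   # reuse the training data as-is

print(predictions[:5], labels[:5])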
@@ -17,7 +17,7 @@ class TestDatasets(unittest.TestCase):
     def _check_dataset(self, dataset):
         q = self.new_quantifier()
         print(f'testing method {q} in {dataset.name}...', end='')
-        q.fit(dataset.training)
+        q.fit(*dataset.training.Xy)
         estim_prevalences = q.predict(dataset.test.instances)
         self.assertTrue(F.check_prevalence_vector(estim_prevalences))
         print(f'[done]')
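The migrated calling convention, sketched end to end (assumes quapy and scikit-learn are installed; ACC and the fetched dataset are stand-ins for the quantifier and data under test). The .Xy property unpacks a LabelledCollection into (instances, labels), matching the new fit(X, y) signature:

import quapy as qp
from quapy.method.aggregative import ACC
from sklearn.linear_model import LogisticRegression

dataset = qp.datasets.fetch_reviews('kindle', tfidf=True, min_df=5)
q = ACC(LogisticRegression())

# old call:  q.fit(dataset.training)        # took a LabelledCollection
q.fit(*dataset.training.Xy)                 # new call: fit(X, y)

estim_prevalences = q.predict(dataset.test.X)
print(estim_prevalences)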
@@ -89,7 +89,7 @@ class TestDatasets(unittest.TestCase):
         n_classes = train.n_classes
         train = train.sampling(100, *F.uniform_prevalence(n_classes))
         q = self.new_quantifier()
-        q.fit(train)
+        q.fit(*train.Xy)
         self._check_samples(gen_val, q, max_samples_test=5)
         self._check_samples(gen_test, q, max_samples_test=5)
@@ -102,7 +102,7 @@ class TestDatasets(unittest.TestCase):
         tfidf = TfidfVectorizer()
         train.instances = tfidf.fit_transform(train.instances)
         q = self.new_quantifier()
-        q.fit(train)
+        q.fit(*train.Xy)
         self._check_samples(gen_val, q, max_samples_test=5, vectorizer=tfidf)
         self._check_samples(gen_test, q, max_samples_test=5, vectorizer=tfidf)
@@ -48,9 +48,9 @@ class TestMethods(unittest.TestCase):
                 print(f'skipping the test of binary model {model.__name__} on multiclass dataset {dataset.name}')
                 continue

-            q = model(learner)
+            q = model(learner, fit_classifier=False)
             print('testing', q)
-            q.fit(dataset.training, fit_classifier=False)
+            q.fit(*dataset.training.Xy)
             estim_prevalences = q.predict(dataset.test.X)
             self.assertTrue(check_prevalence_vector(estim_prevalences))
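The same migration seen from the constructor side: fit_classifier moves out of fit() and into the model's constructor, as the hunk above shows for the tested models. A sketch, with PACC standing in for `model` and a learner assumed to have been trained elsewhere before being wrapped:

from quapy.method.aggregative import PACC
from sklearn.linear_model import LogisticRegression

learner = LogisticRegression()           # imagine this was fit beforehand

# old API:
#   q = PACC(learner)
#   q.fit(dataset.training, fit_classifier=False)
q = PACC(learner, fit_classifier=False)  # new API: flag set at construction
# q.fit(*dataset.training.Xy)            # fit() now receives only X and y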