diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py index 73816c5..cda6294 100644 --- a/quapy/method/aggregative.py +++ b/quapy/method/aggregative.py @@ -67,8 +67,14 @@ class AggregativeQuantifier(BaseQuantifier, ABC): assert val_split > 1, \ (f'when {val_split=} is indicated as an integer, it represents the number of folds in a kFCV ' f'and must thus be >1') - assert fit_classifier, (f'when {val_split=} is indicated as an integer (the number of folds for kFCV) ' - f'the parameter {fit_classifier=} must be True') + if val_split==5 and not fit_classifier: + print(f'Warning: {val_split=} will be ignored when the classifier is already trained ' + f'({fit_classifier=}). Parameter {self.val_split=} will be set to None. Set {val_split=} ' + f'to None to avoid this warning.') + self.val_split=None + if val_split!=5: + assert fit_classifier, (f'Parameter {val_split=} has been modified, but {fit_classifier=} ' + f'indicates the classifier should not be retrained.') elif isinstance(val_split, float): assert 0 < val_split < 1, \ (f'when {val_split=} is indicated as a float, it represents the fraction of training instances ' @@ -174,7 +180,9 @@ class AggregativeQuantifier(BaseQuantifier, ABC): elif self.val_split is None: if self.fit_classifier: self.classifier.fit(X, y) - predictions, labels = None, None + predictions, labels = None, None + else: + predictions, labels = self.classify(X), y else: raise ValueError(f'unexpected type for {self.val_split=}') diff --git a/quapy/tests/test_datasets.py b/quapy/tests/test_datasets.py index 6903aba..63c6ef8 100644 --- a/quapy/tests/test_datasets.py +++ b/quapy/tests/test_datasets.py @@ -17,7 +17,7 @@ class TestDatasets(unittest.TestCase): def _check_dataset(self, dataset): q = self.new_quantifier() print(f'testing method {q} in {dataset.name}...', end='') - q.fit(dataset.training) + q.fit(*dataset.training.Xy) estim_prevalences = q.predict(dataset.test.instances) self.assertTrue(F.check_prevalence_vector(estim_prevalences)) print(f'[done]') @@ -89,7 +89,7 @@ class TestDatasets(unittest.TestCase): n_classes = train.n_classes train = train.sampling(100, *F.uniform_prevalence(n_classes)) q = self.new_quantifier() - q.fit(train) + q.fit(*train.Xy) self._check_samples(gen_val, q, max_samples_test=5) self._check_samples(gen_test, q, max_samples_test=5) @@ -102,7 +102,7 @@ class TestDatasets(unittest.TestCase): tfidf = TfidfVectorizer() train.instances = tfidf.fit_transform(train.instances) q = self.new_quantifier() - q.fit(train) + q.fit(*train.Xy) self._check_samples(gen_val, q, max_samples_test=5, vectorizer=tfidf) self._check_samples(gen_test, q, max_samples_test=5, vectorizer=tfidf) diff --git a/quapy/tests/test_methods.py b/quapy/tests/test_methods.py index c938c80..71753c8 100644 --- a/quapy/tests/test_methods.py +++ b/quapy/tests/test_methods.py @@ -48,9 +48,9 @@ class TestMethods(unittest.TestCase): print(f'skipping the test of binary model {model.__name__} on multiclass dataset {dataset.name}') continue - q = model(learner) + q = model(learner, fit_classifier=False) print('testing', q) - q.fit(dataset.training, fit_classifier=False) + q.fit(*dataset.training.Xy) estim_prevalences = q.predict(dataset.test.X) self.assertTrue(check_prevalence_vector(estim_prevalences))