This commit is contained in:
Alejandro Moreo Fernandez 2025-10-01 10:26:57 +02:00
parent 636e33318f
commit edbc8bc201
1 changed files with 5 additions and 5 deletions

View File

@ -161,7 +161,9 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
assert self.fit_classifier, f'{self.__class__}: unexpected value for {self.fit_classifier=}'
num_folds = self.val_split
n_jobs = self.n_jobs if hasattr(self, 'n_jobs') else qp._get_njobs(None)
predictions = cross_val_predict(self.classifier, X, y, cv=num_folds, n_jobs=n_jobs, method=self._classifier_method())
predictions = cross_val_predict(
self.classifier, X, y, cv=num_folds, n_jobs=n_jobs, method=self._classifier_method()
)
labels = y
self.classifier.fit(X, y)
elif isinstance(self.val_split, float):
@ -756,8 +758,8 @@ class EMQ(AggregativeSoftQuantifier):
if self.val_split is not None:
if self.exact_train_prev and self.calib is None:
raise RuntimeWarning(f'The parameter {self.val_split=} was specified for EMQ, while the parameters '
f'{self.exact_train_prev=} and {self.calib=}. This has no effect and causes an unnecessary '
f'overload.')
f'{self.exact_train_prev=} and {self.calib=}. This has no effect and causes an '
f'unnecessary overload.')
else:
if self.calib is not None:
print(f'[warning] The parameter {self.calib=} requires the val_split be different from None. '
@ -784,8 +786,6 @@ class EMQ(AggregativeSoftQuantifier):
def _fit_calibration(self, calibrator, P, y):
n_classes = len(self.classes_)
print(y, 'Y')
print(y.dtype, 'DTYPE')
if not np.issubdtype(y.dtype, np.number):
y = np.searchsorted(self.classes_, y)