clean
This commit is contained in:
parent
636e33318f
commit
edbc8bc201
|
|
@ -161,7 +161,9 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
|
|||
assert self.fit_classifier, f'{self.__class__}: unexpected value for {self.fit_classifier=}'
|
||||
num_folds = self.val_split
|
||||
n_jobs = self.n_jobs if hasattr(self, 'n_jobs') else qp._get_njobs(None)
|
||||
predictions = cross_val_predict(self.classifier, X, y, cv=num_folds, n_jobs=n_jobs, method=self._classifier_method())
|
||||
predictions = cross_val_predict(
|
||||
self.classifier, X, y, cv=num_folds, n_jobs=n_jobs, method=self._classifier_method()
|
||||
)
|
||||
labels = y
|
||||
self.classifier.fit(X, y)
|
||||
elif isinstance(self.val_split, float):
|
||||
|
|
@ -756,8 +758,8 @@ class EMQ(AggregativeSoftQuantifier):
|
|||
if self.val_split is not None:
|
||||
if self.exact_train_prev and self.calib is None:
|
||||
raise RuntimeWarning(f'The parameter {self.val_split=} was specified for EMQ, while the parameters '
|
||||
f'{self.exact_train_prev=} and {self.calib=}. This has no effect and causes an unnecessary '
|
||||
f'overload.')
|
||||
f'{self.exact_train_prev=} and {self.calib=}. This has no effect and causes an '
|
||||
f'unnecessary overload.')
|
||||
else:
|
||||
if self.calib is not None:
|
||||
print(f'[warning] The parameter {self.calib=} requires the val_split be different from None. '
|
||||
|
|
@ -784,8 +786,6 @@ class EMQ(AggregativeSoftQuantifier):
|
|||
def _fit_calibration(self, calibrator, P, y):
|
||||
n_classes = len(self.classes_)
|
||||
|
||||
print(y, 'Y')
|
||||
print(y.dtype, 'DTYPE')
|
||||
if not np.issubdtype(y.dtype, np.number):
|
||||
y = np.searchsorted(self.classes_, y)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue