clean
This commit is contained in:
parent
636e33318f
commit
edbc8bc201
|
|
@ -161,7 +161,9 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
|
||||||
assert self.fit_classifier, f'{self.__class__}: unexpected value for {self.fit_classifier=}'
|
assert self.fit_classifier, f'{self.__class__}: unexpected value for {self.fit_classifier=}'
|
||||||
num_folds = self.val_split
|
num_folds = self.val_split
|
||||||
n_jobs = self.n_jobs if hasattr(self, 'n_jobs') else qp._get_njobs(None)
|
n_jobs = self.n_jobs if hasattr(self, 'n_jobs') else qp._get_njobs(None)
|
||||||
predictions = cross_val_predict(self.classifier, X, y, cv=num_folds, n_jobs=n_jobs, method=self._classifier_method())
|
predictions = cross_val_predict(
|
||||||
|
self.classifier, X, y, cv=num_folds, n_jobs=n_jobs, method=self._classifier_method()
|
||||||
|
)
|
||||||
labels = y
|
labels = y
|
||||||
self.classifier.fit(X, y)
|
self.classifier.fit(X, y)
|
||||||
elif isinstance(self.val_split, float):
|
elif isinstance(self.val_split, float):
|
||||||
|
|
@ -756,8 +758,8 @@ class EMQ(AggregativeSoftQuantifier):
|
||||||
if self.val_split is not None:
|
if self.val_split is not None:
|
||||||
if self.exact_train_prev and self.calib is None:
|
if self.exact_train_prev and self.calib is None:
|
||||||
raise RuntimeWarning(f'The parameter {self.val_split=} was specified for EMQ, while the parameters '
|
raise RuntimeWarning(f'The parameter {self.val_split=} was specified for EMQ, while the parameters '
|
||||||
f'{self.exact_train_prev=} and {self.calib=}. This has no effect and causes an unnecessary '
|
f'{self.exact_train_prev=} and {self.calib=}. This has no effect and causes an '
|
||||||
f'overload.')
|
f'unnecessary overload.')
|
||||||
else:
|
else:
|
||||||
if self.calib is not None:
|
if self.calib is not None:
|
||||||
print(f'[warning] The parameter {self.calib=} requires the val_split be different from None. '
|
print(f'[warning] The parameter {self.calib=} requires the val_split be different from None. '
|
||||||
|
|
@ -784,8 +786,6 @@ class EMQ(AggregativeSoftQuantifier):
|
||||||
def _fit_calibration(self, calibrator, P, y):
|
def _fit_calibration(self, calibrator, P, y):
|
||||||
n_classes = len(self.classes_)
|
n_classes = len(self.classes_)
|
||||||
|
|
||||||
print(y, 'Y')
|
|
||||||
print(y.dtype, 'DTYPE')
|
|
||||||
if not np.issubdtype(y.dtype, np.number):
|
if not np.issubdtype(y.dtype, np.number):
|
||||||
y = np.searchsorted(self.classes_, y)
|
y = np.searchsorted(self.classes_, y)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue