gs params updated, methods added

This commit is contained in:
Lorenzo Volpi 2023-11-16 01:35:49 +01:00
parent 4423221ba9
commit 81b92157a5
1 changed files with 29 additions and 1 deletions

View File

@ -17,7 +17,7 @@ _sld_param_grid = {
"q__classifier__C": np.logspace(-3, 3, 7),
"q__classifier__class_weight": [None, "balanced"],
"q__recalib": [None, "bcts"],
"confidence": [["max_conf", "entropy"]],
"confidence": [["max_conf"], ["entropy"], ["max_conf", "entropy"]],
}
_pacc_param_grid = {
"q__classifier__C": np.logspace(-3, 3, 7),
@ -151,6 +151,20 @@ def mulmc_sld(c_model, validation, protocol) -> EvaluationReport:
)
@method
def mul3wmc_sld(c_model, validation, protocol) -> EvaluationReport:
est = MCAE(
c_model,
SLD(LogisticRegression()),
confidence="max_conf",
collapse_false=True,
).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def binne_sld(c_model, validation, protocol) -> EvaluationReport:
est = BQAE(
@ -177,6 +191,20 @@ def mulne_sld(c_model, validation, protocol) -> EvaluationReport:
)
@method
def mul3wne_sld(c_model, validation, protocol) -> EvaluationReport:
est = MCAE(
c_model,
SLD(LogisticRegression()),
confidence="entropy",
collapse_false=True,
).fit(validation)
return evaluation_report(
estimator=est,
protocol=protocol,
)
@method
def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport:
v_train, v_val = validation.split_stratified(0.6, random_state=0)