From 81b92157a51477cbfea35c8fa9e5158678f19b14 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Thu, 16 Nov 2023 01:35:49 +0100 Subject: [PATCH] gs params updated, methods added --- quacc/evaluation/method.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/quacc/evaluation/method.py b/quacc/evaluation/method.py index e812b5c..ba33c14 100644 --- a/quacc/evaluation/method.py +++ b/quacc/evaluation/method.py @@ -17,7 +17,7 @@ _sld_param_grid = { "q__classifier__C": np.logspace(-3, 3, 7), "q__classifier__class_weight": [None, "balanced"], "q__recalib": [None, "bcts"], - "confidence": [["max_conf", "entropy"]], + "confidence": [["max_conf"], ["entropy"], ["max_conf", "entropy"]], } _pacc_param_grid = { "q__classifier__C": np.logspace(-3, 3, 7), @@ -151,6 +151,20 @@ def mulmc_sld(c_model, validation, protocol) -> EvaluationReport: ) +@method +def mul3wmc_sld(c_model, validation, protocol) -> EvaluationReport: + est = MCAE( + c_model, + SLD(LogisticRegression()), + confidence="max_conf", + collapse_false=True, + ).fit(validation) + return evaluation_report( + estimator=est, + protocol=protocol, + ) + + @method def binne_sld(c_model, validation, protocol) -> EvaluationReport: est = BQAE( @@ -177,6 +191,20 @@ def mulne_sld(c_model, validation, protocol) -> EvaluationReport: ) +@method +def mul3wne_sld(c_model, validation, protocol) -> EvaluationReport: + est = MCAE( + c_model, + SLD(LogisticRegression()), + confidence="entropy", + collapse_false=True, + ).fit(validation) + return evaluation_report( + estimator=est, + protocol=protocol, + ) + + @method def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport: v_train, v_val = validation.split_stratified(0.6, random_state=0)