From dfd8d11e8f5ce23da12154efca05f90e88fa330e Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Fri, 28 Jul 2023 01:47:44 +0200 Subject: [PATCH] Typing fixed, gitignore updated --- .gitignore | 3 ++- quacc/data.py | 8 +++++--- quacc/estimator.py | 28 ++++++++++++++-------------- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index 5766485..3dbe099 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ quavenv/* *.pdf quacc/__pycache__/* -tests/__pycache__/* \ No newline at end of file +tests/__pycache__/* +.coverage \ No newline at end of file diff --git a/quacc/data.py b/quacc/data.py index fd1e3c3..b05105b 100644 --- a/quacc/data.py +++ b/quacc/data.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import List, Optional, Self import numpy as np import math @@ -44,7 +44,7 @@ class ExtendedCollection(LabelledCollection): ): super().__init__(instances, labels, classes=classes) - def split_by_pred(self): + def split_by_pred(self) -> List[Self]: _ncl = int(math.sqrt(self.n_classes)) _indexes = ExtendedCollection._split_index_by_pred(_ncl, self.instances) if isinstance(self.instances, np.ndarray): @@ -128,7 +128,9 @@ class ExtendedCollection(LabelledCollection): return n_x @classmethod - def extend_collection(cls, base: LabelledCollection, pred_proba: np.ndarray) -> Any: + def extend_collection( + cls, base: LabelledCollection, pred_proba: np.ndarray + ) -> Self: n_classes = base.n_classes # n_X = [ X | predicted probs. ] diff --git a/quacc/estimator.py b/quacc/estimator.py index 2fccfe1..4516b6d 100644 --- a/quacc/estimator.py +++ b/quacc/estimator.py @@ -8,17 +8,17 @@ from sklearn.base import BaseEstimator from sklearn.linear_model import LogisticRegression from sklearn.model_selection import cross_val_predict -from quacc.data import ExtendedCollection as EC +from quacc.data import ExtendedCollection class AccuracyEstimator: - def extend(self, base: LabelledCollection, pred_proba=None) -> EC: + def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection: if not pred_proba: pred_proba = self.c_model.predict_proba(base.X) - return EC.extend_collection(base, pred_proba) + return ExtendedCollection.extend_collection(base, pred_proba) @abstractmethod - def fit(self, train: LabelledCollection | EC): + def fit(self, train: LabelledCollection | ExtendedCollection): ... @abstractmethod @@ -32,7 +32,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator): self.q_model = SLD(LogisticRegression()) self.e_train = None - def fit(self, train: LabelledCollection | EC): + def fit(self, train: LabelledCollection | ExtendedCollection): # check if model is fit # self.model.fit(*train.Xy) if isinstance(train, LabelledCollection): @@ -40,7 +40,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator): self.c_model, *train.Xy, method="predict_proba" ) - self.e_train = EC.extend_collection(train, pred_prob_train) + self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train) else: self.e_train = train @@ -49,7 +49,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator): def estimate(self, instances, ext=False): if not ext: pred_prob = self.c_model.predict_proba(instances) - e_inst = EC.extend_instances(instances, pred_prob) + e_inst = ExtendedCollection.extend_instances(instances, pred_prob) else: e_inst = instances @@ -71,9 +71,9 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator): self.c_model = c_model self.q_model_0 = SLD(LogisticRegression()) self.q_model_1 = SLD(LogisticRegression()) - self.e_train: EC = None + self.e_train = None - def fit(self, train: LabelledCollection | EC): + def fit(self, train: LabelledCollection | ExtendedCollection): # check if model is fit # self.model.fit(*train.Xy) if isinstance(train, LabelledCollection): @@ -81,9 +81,9 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator): self.c_model, *train.Xy, method="predict_proba" ) - self.e_train = EC.extend_collection(train, pred_prob_train) - else: - self.e_train = train + self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train) + elif isinstance(train, ExtendedCollection): + self.e_train = train self.n_classes = self.e_train.n_classes [e_train_0, e_train_1] = self.e_train.split_by_pred() @@ -95,12 +95,12 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator): # TODO: test if not ext: pred_prob = self.c_model.predict_proba(instances) - e_inst = EC.extend_instances(instances, pred_prob) + e_inst = ExtendedCollection.extend_instances(instances, pred_prob) else: e_inst = instances _ncl = int(math.sqrt(self.n_classes)) - s_inst, norms = EC.split_inst_by_pred(_ncl, e_inst) + s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst) [estim_prev_0, estim_prev_1] = [ self._quantify_helper(inst, norm, q_model) for (inst, norm, q_model) in zip(