Typing fixed, gitignore updated

This commit is contained in:
Lorenzo Volpi 2023-07-28 01:47:44 +02:00
parent 469dcb5898
commit dfd8d11e8f
3 changed files with 21 additions and 18 deletions

3
.gitignore vendored
View File

@ -2,4 +2,5 @@
quavenv/*
*.pdf
quacc/__pycache__/*
tests/__pycache__/*
tests/__pycache__/*
.coverage

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import List, Optional, Self
import numpy as np
import math
@ -44,7 +44,7 @@ class ExtendedCollection(LabelledCollection):
):
super().__init__(instances, labels, classes=classes)
def split_by_pred(self):
def split_by_pred(self) -> List[Self]:
_ncl = int(math.sqrt(self.n_classes))
_indexes = ExtendedCollection._split_index_by_pred(_ncl, self.instances)
if isinstance(self.instances, np.ndarray):
@ -128,7 +128,9 @@ class ExtendedCollection(LabelledCollection):
return n_x
@classmethod
def extend_collection(cls, base: LabelledCollection, pred_proba: np.ndarray) -> Any:
def extend_collection(
cls, base: LabelledCollection, pred_proba: np.ndarray
) -> Self:
n_classes = base.n_classes
# n_X = [ X | predicted probs. ]

View File

@ -8,17 +8,17 @@ from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from quacc.data import ExtendedCollection as EC
from quacc.data import ExtendedCollection
class AccuracyEstimator:
def extend(self, base: LabelledCollection, pred_proba=None) -> EC:
def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
if not pred_proba:
pred_proba = self.c_model.predict_proba(base.X)
return EC.extend_collection(base, pred_proba)
return ExtendedCollection.extend_collection(base, pred_proba)
@abstractmethod
def fit(self, train: LabelledCollection | EC):
def fit(self, train: LabelledCollection | ExtendedCollection):
...
@abstractmethod
@ -32,7 +32,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
self.q_model = SLD(LogisticRegression())
self.e_train = None
def fit(self, train: LabelledCollection | EC):
def fit(self, train: LabelledCollection | ExtendedCollection):
# check if model is fit
# self.model.fit(*train.Xy)
if isinstance(train, LabelledCollection):
@ -40,7 +40,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
self.c_model, *train.Xy, method="predict_proba"
)
self.e_train = EC.extend_collection(train, pred_prob_train)
self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
else:
self.e_train = train
@ -49,7 +49,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
def estimate(self, instances, ext=False):
if not ext:
pred_prob = self.c_model.predict_proba(instances)
e_inst = EC.extend_instances(instances, pred_prob)
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
else:
e_inst = instances
@ -71,9 +71,9 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
self.c_model = c_model
self.q_model_0 = SLD(LogisticRegression())
self.q_model_1 = SLD(LogisticRegression())
self.e_train: EC = None
self.e_train = None
def fit(self, train: LabelledCollection | EC):
def fit(self, train: LabelledCollection | ExtendedCollection):
# check if model is fit
# self.model.fit(*train.Xy)
if isinstance(train, LabelledCollection):
@ -81,9 +81,9 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
self.c_model, *train.Xy, method="predict_proba"
)
self.e_train = EC.extend_collection(train, pred_prob_train)
else:
self.e_train = train
self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
elif isinstance(train, ExtendedCollection):
self.e_train = train
self.n_classes = self.e_train.n_classes
[e_train_0, e_train_1] = self.e_train.split_by_pred()
@ -95,12 +95,12 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
# TODO: test
if not ext:
pred_prob = self.c_model.predict_proba(instances)
e_inst = EC.extend_instances(instances, pred_prob)
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
else:
e_inst = instances
_ncl = int(math.sqrt(self.n_classes))
s_inst, norms = EC.split_inst_by_pred(_ncl, e_inst)
s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
[estim_prev_0, estim_prev_1] = [
self._quantify_helper(inst, norm, q_model)
for (inst, norm, q_model) in zip(