Typing fixed, gitignore updated
This commit is contained in:
parent
469dcb5898
commit
dfd8d11e8f
|
@ -2,4 +2,5 @@
|
|||
quavenv/*
|
||||
*.pdf
|
||||
quacc/__pycache__/*
|
||||
tests/__pycache__/*
|
||||
tests/__pycache__/*
|
||||
.coverage
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, List, Optional
|
||||
from typing import List, Optional, Self
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
|
@ -44,7 +44,7 @@ class ExtendedCollection(LabelledCollection):
|
|||
):
|
||||
super().__init__(instances, labels, classes=classes)
|
||||
|
||||
def split_by_pred(self):
|
||||
def split_by_pred(self) -> List[Self]:
|
||||
_ncl = int(math.sqrt(self.n_classes))
|
||||
_indexes = ExtendedCollection._split_index_by_pred(_ncl, self.instances)
|
||||
if isinstance(self.instances, np.ndarray):
|
||||
|
@ -128,7 +128,9 @@ class ExtendedCollection(LabelledCollection):
|
|||
return n_x
|
||||
|
||||
@classmethod
|
||||
def extend_collection(cls, base: LabelledCollection, pred_proba: np.ndarray) -> Any:
|
||||
def extend_collection(
|
||||
cls, base: LabelledCollection, pred_proba: np.ndarray
|
||||
) -> Self:
|
||||
n_classes = base.n_classes
|
||||
|
||||
# n_X = [ X | predicted probs. ]
|
||||
|
|
|
@ -8,17 +8,17 @@ from sklearn.base import BaseEstimator
|
|||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.model_selection import cross_val_predict
|
||||
|
||||
from quacc.data import ExtendedCollection as EC
|
||||
from quacc.data import ExtendedCollection
|
||||
|
||||
|
||||
class AccuracyEstimator:
|
||||
def extend(self, base: LabelledCollection, pred_proba=None) -> EC:
|
||||
def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
|
||||
if not pred_proba:
|
||||
pred_proba = self.c_model.predict_proba(base.X)
|
||||
return EC.extend_collection(base, pred_proba)
|
||||
return ExtendedCollection.extend_collection(base, pred_proba)
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, train: LabelledCollection | EC):
|
||||
def fit(self, train: LabelledCollection | ExtendedCollection):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
|
@ -32,7 +32,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
|
|||
self.q_model = SLD(LogisticRegression())
|
||||
self.e_train = None
|
||||
|
||||
def fit(self, train: LabelledCollection | EC):
|
||||
def fit(self, train: LabelledCollection | ExtendedCollection):
|
||||
# check if model is fit
|
||||
# self.model.fit(*train.Xy)
|
||||
if isinstance(train, LabelledCollection):
|
||||
|
@ -40,7 +40,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
|
|||
self.c_model, *train.Xy, method="predict_proba"
|
||||
)
|
||||
|
||||
self.e_train = EC.extend_collection(train, pred_prob_train)
|
||||
self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
|
||||
else:
|
||||
self.e_train = train
|
||||
|
||||
|
@ -49,7 +49,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
|
|||
def estimate(self, instances, ext=False):
|
||||
if not ext:
|
||||
pred_prob = self.c_model.predict_proba(instances)
|
||||
e_inst = EC.extend_instances(instances, pred_prob)
|
||||
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
|
||||
else:
|
||||
e_inst = instances
|
||||
|
||||
|
@ -71,9 +71,9 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
|
|||
self.c_model = c_model
|
||||
self.q_model_0 = SLD(LogisticRegression())
|
||||
self.q_model_1 = SLD(LogisticRegression())
|
||||
self.e_train: EC = None
|
||||
self.e_train = None
|
||||
|
||||
def fit(self, train: LabelledCollection | EC):
|
||||
def fit(self, train: LabelledCollection | ExtendedCollection):
|
||||
# check if model is fit
|
||||
# self.model.fit(*train.Xy)
|
||||
if isinstance(train, LabelledCollection):
|
||||
|
@ -81,9 +81,9 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
|
|||
self.c_model, *train.Xy, method="predict_proba"
|
||||
)
|
||||
|
||||
self.e_train = EC.extend_collection(train, pred_prob_train)
|
||||
else:
|
||||
self.e_train = train
|
||||
self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
|
||||
elif isinstance(train, ExtendedCollection):
|
||||
self.e_train = train
|
||||
|
||||
self.n_classes = self.e_train.n_classes
|
||||
[e_train_0, e_train_1] = self.e_train.split_by_pred()
|
||||
|
@ -95,12 +95,12 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
|
|||
# TODO: test
|
||||
if not ext:
|
||||
pred_prob = self.c_model.predict_proba(instances)
|
||||
e_inst = EC.extend_instances(instances, pred_prob)
|
||||
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
|
||||
else:
|
||||
e_inst = instances
|
||||
|
||||
_ncl = int(math.sqrt(self.n_classes))
|
||||
s_inst, norms = EC.split_inst_by_pred(_ncl, e_inst)
|
||||
s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
|
||||
[estim_prev_0, estim_prev_1] = [
|
||||
self._quantify_helper(inst, norm, q_model)
|
||||
for (inst, norm, q_model) in zip(
|
||||
|
|
Loading…
Reference in New Issue