Typing fixed, gitignore updated
This commit is contained in:
parent
469dcb5898
commit
dfd8d11e8f
|
@ -3,3 +3,4 @@ quavenv/*
|
||||||
*.pdf
|
*.pdf
|
||||||
quacc/__pycache__/*
|
quacc/__pycache__/*
|
||||||
tests/__pycache__/*
|
tests/__pycache__/*
|
||||||
|
.coverage
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, List, Optional
|
from typing import List, Optional, Self
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
|
@ -44,7 +44,7 @@ class ExtendedCollection(LabelledCollection):
|
||||||
):
|
):
|
||||||
super().__init__(instances, labels, classes=classes)
|
super().__init__(instances, labels, classes=classes)
|
||||||
|
|
||||||
def split_by_pred(self):
|
def split_by_pred(self) -> List[Self]:
|
||||||
_ncl = int(math.sqrt(self.n_classes))
|
_ncl = int(math.sqrt(self.n_classes))
|
||||||
_indexes = ExtendedCollection._split_index_by_pred(_ncl, self.instances)
|
_indexes = ExtendedCollection._split_index_by_pred(_ncl, self.instances)
|
||||||
if isinstance(self.instances, np.ndarray):
|
if isinstance(self.instances, np.ndarray):
|
||||||
|
@ -128,7 +128,9 @@ class ExtendedCollection(LabelledCollection):
|
||||||
return n_x
|
return n_x
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extend_collection(cls, base: LabelledCollection, pred_proba: np.ndarray) -> Any:
|
def extend_collection(
|
||||||
|
cls, base: LabelledCollection, pred_proba: np.ndarray
|
||||||
|
) -> Self:
|
||||||
n_classes = base.n_classes
|
n_classes = base.n_classes
|
||||||
|
|
||||||
# n_X = [ X | predicted probs. ]
|
# n_X = [ X | predicted probs. ]
|
||||||
|
|
|
@ -8,17 +8,17 @@ from sklearn.base import BaseEstimator
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
from sklearn.model_selection import cross_val_predict
|
from sklearn.model_selection import cross_val_predict
|
||||||
|
|
||||||
from quacc.data import ExtendedCollection as EC
|
from quacc.data import ExtendedCollection
|
||||||
|
|
||||||
|
|
||||||
class AccuracyEstimator:
|
class AccuracyEstimator:
|
||||||
def extend(self, base: LabelledCollection, pred_proba=None) -> EC:
|
def extend(self, base: LabelledCollection, pred_proba=None) -> ExtendedCollection:
|
||||||
if not pred_proba:
|
if not pred_proba:
|
||||||
pred_proba = self.c_model.predict_proba(base.X)
|
pred_proba = self.c_model.predict_proba(base.X)
|
||||||
return EC.extend_collection(base, pred_proba)
|
return ExtendedCollection.extend_collection(base, pred_proba)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def fit(self, train: LabelledCollection | EC):
|
def fit(self, train: LabelledCollection | ExtendedCollection):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -32,7 +32,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
|
||||||
self.q_model = SLD(LogisticRegression())
|
self.q_model = SLD(LogisticRegression())
|
||||||
self.e_train = None
|
self.e_train = None
|
||||||
|
|
||||||
def fit(self, train: LabelledCollection | EC):
|
def fit(self, train: LabelledCollection | ExtendedCollection):
|
||||||
# check if model is fit
|
# check if model is fit
|
||||||
# self.model.fit(*train.Xy)
|
# self.model.fit(*train.Xy)
|
||||||
if isinstance(train, LabelledCollection):
|
if isinstance(train, LabelledCollection):
|
||||||
|
@ -40,7 +40,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
|
||||||
self.c_model, *train.Xy, method="predict_proba"
|
self.c_model, *train.Xy, method="predict_proba"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.e_train = EC.extend_collection(train, pred_prob_train)
|
self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
|
||||||
else:
|
else:
|
||||||
self.e_train = train
|
self.e_train = train
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ class MulticlassAccuracyEstimator(AccuracyEstimator):
|
||||||
def estimate(self, instances, ext=False):
|
def estimate(self, instances, ext=False):
|
||||||
if not ext:
|
if not ext:
|
||||||
pred_prob = self.c_model.predict_proba(instances)
|
pred_prob = self.c_model.predict_proba(instances)
|
||||||
e_inst = EC.extend_instances(instances, pred_prob)
|
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
|
||||||
else:
|
else:
|
||||||
e_inst = instances
|
e_inst = instances
|
||||||
|
|
||||||
|
@ -71,9 +71,9 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
|
||||||
self.c_model = c_model
|
self.c_model = c_model
|
||||||
self.q_model_0 = SLD(LogisticRegression())
|
self.q_model_0 = SLD(LogisticRegression())
|
||||||
self.q_model_1 = SLD(LogisticRegression())
|
self.q_model_1 = SLD(LogisticRegression())
|
||||||
self.e_train: EC = None
|
self.e_train = None
|
||||||
|
|
||||||
def fit(self, train: LabelledCollection | EC):
|
def fit(self, train: LabelledCollection | ExtendedCollection):
|
||||||
# check if model is fit
|
# check if model is fit
|
||||||
# self.model.fit(*train.Xy)
|
# self.model.fit(*train.Xy)
|
||||||
if isinstance(train, LabelledCollection):
|
if isinstance(train, LabelledCollection):
|
||||||
|
@ -81,8 +81,8 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
|
||||||
self.c_model, *train.Xy, method="predict_proba"
|
self.c_model, *train.Xy, method="predict_proba"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.e_train = EC.extend_collection(train, pred_prob_train)
|
self.e_train = ExtendedCollection.extend_collection(train, pred_prob_train)
|
||||||
else:
|
elif isinstance(train, ExtendedCollection):
|
||||||
self.e_train = train
|
self.e_train = train
|
||||||
|
|
||||||
self.n_classes = self.e_train.n_classes
|
self.n_classes = self.e_train.n_classes
|
||||||
|
@ -95,12 +95,12 @@ class BinaryQuantifierAccuracyEstimator(AccuracyEstimator):
|
||||||
# TODO: test
|
# TODO: test
|
||||||
if not ext:
|
if not ext:
|
||||||
pred_prob = self.c_model.predict_proba(instances)
|
pred_prob = self.c_model.predict_proba(instances)
|
||||||
e_inst = EC.extend_instances(instances, pred_prob)
|
e_inst = ExtendedCollection.extend_instances(instances, pred_prob)
|
||||||
else:
|
else:
|
||||||
e_inst = instances
|
e_inst = instances
|
||||||
|
|
||||||
_ncl = int(math.sqrt(self.n_classes))
|
_ncl = int(math.sqrt(self.n_classes))
|
||||||
s_inst, norms = EC.split_inst_by_pred(_ncl, e_inst)
|
s_inst, norms = ExtendedCollection.split_inst_by_pred(_ncl, e_inst)
|
||||||
[estim_prev_0, estim_prev_1] = [
|
[estim_prev_0, estim_prev_1] = [
|
||||||
self._quantify_helper(inst, norm, q_model)
|
self._quantify_helper(inst, norm, q_model)
|
||||||
for (inst, norm, q_model) in zip(
|
for (inst, norm, q_model) in zip(
|
||||||
|
|
Loading…
Reference in New Issue