forked from moreo/QuaPy
Merge pull request #12 from pglez82/protocols
changing app to use prevalence_linspace function with smooth limits
This commit is contained in:
commit
1742b75504
|
@ -223,6 +223,7 @@ def cross_generate_predictions(
|
|||
|
||||
# fit the learner on all data
|
||||
learner.fit(*data.Xy)
|
||||
y = data.y
|
||||
classes = data.classes_
|
||||
else:
|
||||
learner, val_data = _training_helper(
|
||||
|
|
|
@ -132,15 +132,17 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
|||
:param n_prevalences: the number of equidistant prevalence points to extract from the [0,1] interval for the
|
||||
grid (default is 21)
|
||||
:param repeats: number of copies for each valid prevalence vector (default is 10)
|
||||
:param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1
|
||||
:param random_state: allows replicating samples across runs (default None)
|
||||
"""
|
||||
|
||||
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_state=None, return_type='sample_prev'):
|
||||
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, smooth_limits_epsilon=0, random_state=None, return_type='sample_prev'):
|
||||
super(APP, self).__init__(random_state)
|
||||
self.data = data
|
||||
self.sample_size = sample_size
|
||||
self.n_prevalences = n_prevalences
|
||||
self.repeats = repeats
|
||||
self.smooth_limits_epsilon = smooth_limits_epsilon
|
||||
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||
|
||||
def prevalence_grid(self):
|
||||
|
@ -159,7 +161,7 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
|||
in the grid multiplied by `repeat`
|
||||
"""
|
||||
dimensions = self.data.n_classes
|
||||
s = np.linspace(0., 1., self.n_prevalences, endpoint=True)
|
||||
s = F.prevalence_linspace(self.n_prevalences, repeats=1, smooth_limits_epsilon=self.smooth_limits_epsilon)
|
||||
s = [s] * (dimensions - 1)
|
||||
prevs = [p for p in itertools.product(*s, repeat=1) if (sum(p) <= 1.0)]
|
||||
prevs = np.asarray(prevs).reshape(len(prevs), -1)
|
||||
|
|
Loading…
Reference in New Issue