From bfe4b8b51a42ce29bba812697f22aed451a9ec58 Mon Sep 17 00:00:00 2001
From: Alejandro Moreo <alejandro.moreo@isti.cnr.it>
Date: Fri, 3 Jun 2022 13:51:22 +0200
Subject: [PATCH] updating properties of labelled collection

---
 quapy/data/base.py | 38 ++++++++++++++++++++++++++++++++++++++
 quapy/protocol.py  | 17 ++++++++++++++---
 2 files changed, 52 insertions(+), 3 deletions(-)

diff --git a/quapy/data/base.py b/quapy/data/base.py
index c555692..b22a71f 100644
--- a/quapy/data/base.py
+++ b/quapy/data/base.py
@@ -63,6 +63,7 @@ class LabelledCollection:
         """
         return self.instances.shape[0]
 
+    @property
     def prevalence(self):
         """
         Returns the prevalence, or relative frequency, of the classes of interest.
@@ -248,6 +249,43 @@ class LabelledCollection:
         """
         return self.instances, self.labels
 
+    @property
+    def Xp(self):
+        """
+        Gets the instances and the true prevalence. This is useful when implementing evaluation protocols.
+
+        :return: a tuple `(instances, prevalence)` from this collection
+        """
+        return self.instances, self.prevalence  # `prevalence` is a property: read it, do not call it
+
+    @property
+    def X(self):
+        """
+        Shorthand alias of `self.instances`.
+
+        :return: the same object held in `self.instances` (no copy is made)
+        """
+        return self.instances
+
+    @property
+    def y(self):
+        """
+        Shorthand alias of `self.labels`.
+
+        :return: the same object held in `self.labels` (no copy is made)
+        """
+        return self.labels
+
+    @property
+    def p(self):
+        """
+        An alias to self.prevalence
+
+        :return: self.prevalence
+        """
+        return self.prevalence  # `prevalence` is a property: read it, do not call it
+
+
     def stats(self, show=True):
         """
         Returns (and eventually prints) a dictionary with some stats of this collection. E.g.,:
diff --git a/quapy/protocol.py b/quapy/protocol.py
index f539830..c55c3ef 100644
--- a/quapy/protocol.py
+++ b/quapy/protocol.py
@@ -84,14 +84,16 @@ class AbstractStochasticSeededProtocol(AbstractProtocol):
             if self.random_seed is not None:
                 stack.enter_context(qp.util.temp_seed(self.random_seed))
             for params in self.samples_parameters():
-                yield self.collator_fn(self.sample(params))
+                yield self.collator(self.sample(params))
 
-    def set_collator(self, collator_fn):
-        self.collator_fn = collator_fn
+    def collator(self, sample, *args):  # default collator: hands the sample back untouched
+        return sample  # extra positional *args are accepted but ignored here
 
 
 class OnLabelledCollectionProtocol:
 
+    RETURN_TYPES = ['sample_prev', 'labelled_collection']
+
     def get_labelled_collection(self):
         return self.data
 
@@ -106,6 +108,15 @@ class OnLabelledCollectionProtocol:
             new = deepcopy(self)
             return new.on_preclassified_instances(pre_classifications, in_place=True)
 
+    @classmethod
+    def get_collator(cls, return_type='sample_prev'):
+        """Return the function that converts a sampled collection into the requested `return_type`."""
+        assert return_type in cls.RETURN_TYPES, \
+            f'unknown return type passed as argument; valid ones are {cls.RETURN_TYPES}'
+        if return_type == 'labelled_collection':
+            return lambda collection: collection
+        return (lambda collection: collection.Xp) if return_type == 'sample_prev' else None
+
 
 class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
     """