From ff00de18cbf0396c8d1ad2f12bb96bb3fe12fb2b Mon Sep 17 00:00:00 2001
From: Alejandro Moreo <alejandro.moreo@isti.cnr.it>
Date: Fri, 19 Jan 2024 18:24:38 +0100
Subject: [PATCH] updating documentation a bit

---
 quapy/method/aggregative.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py
index c3e24db..b7b7409 100644
--- a/quapy/method/aggregative.py
+++ b/quapy/method/aggregative.py
@@ -1141,6 +1141,8 @@ class ThresholdOptimization(BinaryAggregativeQuantifier):
         return candidates
 
     def aggregate_with_threshold(self, classif_predictions, tprs, fprs, thresholds):
+        # This function performs the adjusted count for given tpr, fpr, and threshold.
+        # Note that, due to broadcasting, tprs, fprs, and thresholds could be arrays of length > 1
         prevs_estims = np.mean(classif_predictions[:, None] >= thresholds, axis=0)
         prevs_estims = (prevs_estims - fprs) / (tprs - fprs)
         prevs_estims = F.as_binary_prevalence(prevs_estims, clip_if_necessary=True)
@@ -1164,8 +1166,8 @@ class ThresholdOptimization(BinaryAggregativeQuantifier):
         return FP / (FP + TN)
 
     def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
-        # the standard behavior is to keep the best threshold only
         decision_scores, y = classif_predictions.Xy
+        # the standard behavior is to keep the best threshold only
         self.tpr, self.fpr, self.threshold = self._eval_candidate_thresholds(decision_scores, y)[0]
         return self
 
@@ -1270,8 +1272,8 @@ class MS(ThresholdOptimization):
         return 1
 
     def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
-        # keeps all candidates
         decision_scores, y = classif_predictions.Xy
+        # keeps all candidates
         tprs_fprs_thresholds = self._eval_candidate_thresholds(decision_scores, y)
         self.tprs = tprs_fprs_thresholds[:, 0]
         self.fprs = tprs_fprs_thresholds[:, 1]