From 29993386ae98d4abec1004254f29689f71556425 Mon Sep 17 00:00:00 2001
From: Alex Moreo <alejandro.moreo@isti.cnr.it>
Date: Wed, 23 Feb 2022 15:29:39 +0100
Subject: [PATCH] plotting y distributions over time

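Replace the per-iteration posteriors plot with a new InOutDistPlot that
shows, at each step of the active learning loop, the distribution of the
classifier's posterior probabilities on the training set ("in") and on
the pool ("out"), broken down by true label. Also add two columns to the
log, ba-prev and ba-estim, reporting the true prevalence of the batch
about to be annotated and the quantifier's estimate for it.

A minimal usage sketch of the new plot (assuming a fitted probabilistic
classifier clf and LabelledCollection objects train and pool; the names
are illustrative):

    from plot import InOutDistPlot

    fig_dist = InOutDistPlot(refreshEach=1)
    in_post = clf.predict_proba(train.instances)   # posteriors on the labelled set
    out_post = clf.predict_proba(pool.instances)   # posteriors on the pool
    fig_dist.plot(in_post, train.labels, out_post, pool.labels)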
---
 eDiscovery/functions.py | 12 +--------
 eDiscovery/main.py      | 31 +++++++++++++----------
 eDiscovery/plot.py      | 54 +++++++++++++++++++++++++++++++++++++++++
 3 files changed, 73 insertions(+), 24 deletions(-)

diff --git a/eDiscovery/functions.py b/eDiscovery/functions.py
index 45e68f9..d36feb8 100644
--- a/eDiscovery/functions.py
+++ b/eDiscovery/functions.py
@@ -162,14 +162,6 @@ def estimate_prev_CC(train, pool: LabelledCollection, classifiername:str):
 
 
 def estimate_prev_Q(train, pool, quantifiername, classifiername):
-    # q = qp.model_selection.GridSearchQ(
-    #     ACC(LogisticRegression()),
-    #     param_grid={'C':np.logspace(-3,3,7), 'class_weight':[None, 'balanced']},
-    #     sample_size=len(train),
-    #     protocol='app',
-    #     n_prevpoints=21,
-    #     n_repetitions=10)
-
     q = NewQuantifier(quantifiername, classifiername)
     # q._find_regions((train+pool).instances)
     q.fit(train)
@@ -181,16 +173,14 @@ def estimate_prev_Q(train, pool, quantifiername, classifiername):
 def eval_classifier(learner, test:LabelledCollection):
     predictions = learner.predict(test.instances)
     true_labels = test.labels
-    # f1 = f1_score(true_labels, predictions, average='macro')
     f1 = f1_score(true_labels, predictions, average='binary')
-    # f1 = (true_labels==predictions).mean()
     return f1
 
 
 def ideal_cost(classifier, pool):
     # returns the cost (in terms of number of documents) to review until the last relevant document
     # is processed, assuming the rank produced by this classifier. The cost is said to be "idealized" since
-    # one assumes to be able to stop reviewing when the last relevant is encountered
+    # the reviewer is assumed to know the optimal stopping point (reached right after the last relevant document)
 
     prob = classifier.predict_proba(pool.instances)
     order = np.argsort(prob[:,0])  # col 0 has negative posterior prob, so the natural order is "by relevance"
diff --git a/eDiscovery/main.py b/eDiscovery/main.py
index 0a740c1..f8b213e 100644
--- a/eDiscovery/main.py
+++ b/eDiscovery/main.py
@@ -6,8 +6,7 @@ import functions as fn
 import quapy as qp
 import argparse
 from quapy.data import LabelledCollection
-from plot import eDiscoveryPlot
-
+from plot import eDiscoveryPlot, InOutDistPlot
 
 
 def main(args):
@@ -23,9 +22,10 @@ def main(args):
     collection = qp.util.pickled_resource(f'./dataset/{datasetname}.pkl', fn.create_dataset, datasetname)
     nD = len(collection)
 
-    fig = eDiscoveryPlot(args.output)
+    # fig = eDiscoveryPlot(args.output)
+    fig_dist = InOutDistPlot()  # plots the in/out posterior distributions at each iteration
 
-    skip_first_steps = 20
+    skip_first_steps = 1
 
     with qp.util.temp_seed(args.seed):
         # initial labelled data selection
@@ -34,23 +34,20 @@ def main(args):
         else:
             idx = collection.sampling_index(init_nD, *[1 - args.initprev, args.initprev])
         train, pool = fn.split_from_index(collection, idx)
-        #first_train = LabelledCollection(train.instances, train.labels)
 
         # recall_target = 0.99
         i = 0
 
-        # q = fn.NewQuantifier(q_name, clf_name)
-        # print('searching regions')
-        # q._find_regions((train+pool).instances)
-        # print('[done]')
-
         with open(args.output, 'wt') as foo:
             def tee(msg):
                 foo.write(msg + '\n')
                 foo.flush()
                 print(msg)
 
-            tee('it\t%\ttr-size\tte-size\ttr-prev\tte-prev\tte-estim\tte-estimCC\tR\tRhat\tRhatCC\tShift\tAE\tAE_CC\tMF1_Q\tMF1_Clf\tICost\tremaining')
+            tee('it\t%\ttr-size\tte-size\ttr-prev\tte-prev\tte-estim\tte-estimCC\tR\tRhat\tRhatCC\tShift\tAE\tAE_CC'
+                '\tMF1_Q\tMF1_Clf\tICost\tremaining\tba-prev\tba-estim')
+
+            batch_prev_estim, batch_prev_true, q = 0, 0, None  # placeholders until the first quantifier is fitted
 
             while True:
 
@@ -85,10 +82,12 @@ def main(args):
 
                     tee(f'{i}\t{progress:.2f}\t{nDtr}\t{nDte}\t{tr_p[1]:.3f}\t{te_p[1]:.3f}\t{pool_p_hat_q[1]:.3f}\t{pool_p_hat_cc[1]:.3f}'
                         f'\t{r:.3f}\t{r_hat_q:.3f}\t{r_hat_cc:.3f}\t{tr_te_shift:.5f}\t{ae_q:.4f}\t{ae_cc:.4f}\t{f1_q:.3f}\t{f1_clf:.3f}'
-                        f'\t{ideal_cost}\t{pool.labels.sum()}')
+                        f'\t{ideal_cost}\t{pool.labels.sum()}\t{batch_prev_true:.3f}\t{batch_prev_estim:.3f}')
 
                     posteriors = classifier.predict_proba(pool.instances)
-                    fig.plot(posteriors, pool.labels)
+                    in_posteriors = classifier.predict_proba(train.instances)  # posteriors on the labelled (in) set
+                    # fig.plot(posteriors, pool.labels)
+                    fig_dist.plot(in_posteriors, train.labels, posteriors, pool.labels)
 
                     if nDte < k:
                         print('[stop] too few documents remaining')
@@ -98,6 +97,12 @@ def main(args):
                         break
 
                 top_relevant_idx = sampling_fn(pool, classifier, k, progress)
+
+                if q is not None:
+                    batch = pool.sampling_from_index(top_relevant_idx)  # the batch about to be annotated
+                    batch_prev_estim = q.quantify(batch.instances)[1]  # estimated prevalence of the positive class
+                    batch_prev_true  = batch.prevalence()[1]  # true prevalence of the positive class
+
                 train, pool = fn.move_documents(train, pool, top_relevant_idx)
 
                 i += 1
diff --git a/eDiscovery/plot.py b/eDiscovery/plot.py
index 67f33ca..649ef21 100644
--- a/eDiscovery/plot.py
+++ b/eDiscovery/plot.py
@@ -142,6 +142,60 @@ class eDiscoveryPlot:
         self.calls += 1
 
 
+class InOutDistPlot:
+
+    def __init__(self, refreshEach=1):
+        self.refreshEach = refreshEach
+
+        # figure with two stacked axes: training (in) distribution on top, pool (out) below
+        self.fig, self.axs = plt.subplots(2)
+        self.calls = 0
+
+    def _plot_dist(self, posteriors, y, aXn, title):
+        positive_posteriors = posteriors[y == 1, 1]  # positive-class posteriors of the truly relevant documents
+        negative_posteriors = posteriors[y == 0, 1]  # positive-class posteriors of the truly irrelevant documents
+        self.axs[aXn].hist(negative_posteriors, bins=50, label=r'$Pr(x|\ominus)$', density=False, alpha=.75)
+        self.axs[aXn].hist(positive_posteriors, bins=50, label=r'$Pr(x|\oplus)$', density=False, alpha=.75)
+        self.axs[aXn].legend()
+        self.axs[aXn].grid()
+        self.axs[aXn].set_xlim(0, 1)
+        self.axs[aXn].set_ylabel(title)
+
+    def plot(self, in_posteriors, in_y, out_posteriors, out_y):
+
+        if (self.calls+1) % self.refreshEach != 0:  # redraw only once every refreshEach calls
+            self.calls += 1
+            return
+
+        fig, axs = self.fig, self.axs
+
+        aXn = 0
+
+        # in-posteriors distribution
+        self._plot_dist(in_posteriors, in_y, aXn, title='training distribution')
+        aXn += 1
+
+        # out-posteriors distribution
+        self._plot_dist(out_posteriors, out_y, aXn, title='pool distribution')
+        aXn += 1
+
+        for i in range(aXn):
+            if self.calls==0:
+                # Shrink current axis by 20%
+                box = axs[i].get_position()
+                axs[i].set_position([box.x0, box.y0, box.width * 0.8, box.height])
+                fig.tight_layout()
+
+            # Put a legend to the right of the current axis
+            axs[i].legend(loc='center left', bbox_to_anchor=(1, 0.5))
+
+        plt.pause(.5)  # brief pause so the interactive figure refreshes
+        for i in range(aXn):
+            axs[i].cla()  # clear the axes for the next call
+
+        self.calls += 1
+
+
 if __name__ == '__main__':
 
     assert len(sys.argv) == 3, f'wrong args, syntax is: python {sys.argv[0]} <result_input_path> <dynamic (0|1)>'