From e40c40960987da110e402d918020d690266e1d6f Mon Sep 17 00:00:00 2001
From: Alejandro Moreo <alejandro.moreo@isti.cnr.it>
Date: Tue, 4 Oct 2022 11:03:08 +0200
Subject: [PATCH] bugfix in NeuralClassifierTrainer; it was configured to
 work correctly only on binary problems

---
 quapy/classification/neural.py | 4 ++--
 quapy/data/preprocessing.py    | 2 +-
 quapy/method/aggregative.py    | 5 ++++-
 quapy/method/neural.py         | 1 +
 4 files changed, 8 insertions(+), 4 deletions(-)
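Note: the root cause was that `classes_` was hardcoded to `np.asarray([0, 1])`
in the constructor, so every consumer of that attribute treated the problem as
binary. A minimal sketch of the failure mode and of the post-patch behavior
(toy labels, not code from the diff):

    import numpy as np

    labels = np.array([0, 1, 2, 2, 1, 0])  # a 3-class problem
    classes_old = np.asarray([0, 1])       # before: hardcoded in __init__
    classes_new = np.unique(labels)        # after: derived from the training data
    # len(classes_old) == 2 silently under-sizes anything shaped by the number
    # of classes; len(classes_new) == 3 matches the actual label set
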

diff --git a/quapy/classification/neural.py b/quapy/classification/neural.py
index 0d576c5..18fd646 100644
--- a/quapy/classification/neural.py
+++ b/quapy/classification/neural.py
@@ -42,7 +42,7 @@ class NeuralClassifierTrainer:
                  batch_size=64,
                  batch_size_test=512,
                  padding_length=300,
-                 device='cpu',
+                 device='cuda',
                  checkpointpath='../checkpoint/classifier_net.dat'):
 
         super().__init__()
@@ -62,7 +62,6 @@ class NeuralClassifierTrainer:
         }
         self.learner_hyperparams = self.net.get_params()
         self.checkpointpath = checkpointpath
-        self.classes_ = np.asarray([0, 1])
 
         print(f'[NeuralNetwork running on {device}]')
         os.makedirs(Path(checkpointpath).parent, exist_ok=True)
@@ -174,6 +173,7 @@ class NeuralClassifierTrainer:
         :return:
         """
         train, val = LabelledCollection(instances, labels).split_stratified(1-val_split)
+        self.classes_ = train.classes_
         opt = self.trainer_hyperparams
         checkpoint = self.checkpointpath
         self.reset_net_params(self.vocab_size, train.n_classes)
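With the hardcoded attribute removed, `fit` now derives `classes_` from the
stratified training split. A hedged usage sketch (API names as in quapy; the
data and split proportion are illustrative):

    from quapy.data import LabelledCollection

    instances = ['doc%d' % i for i in range(60)]
    labels = [i % 3 for i in range(60)]       # 3 balanced classes
    data = LabelledCollection(instances, labels)
    train, val = data.split_stratified(0.6)   # stratified 60/40 split
    # stratification preserves the label set (given at least one example per
    # class), so train.classes_ matches data.classes_ and is a safe source
    # for the classifier's classes_
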
diff --git a/quapy/data/preprocessing.py b/quapy/data/preprocessing.py
index f04f010..99a267b 100644
--- a/quapy/data/preprocessing.py
+++ b/quapy/data/preprocessing.py
@@ -184,7 +184,7 @@ class IndexTransformer:
 
     def _index(self, documents):
         vocab = self.vocabulary_.copy()
-        return [[vocab.prevalence(word, self.unk) for word in self.analyzer(doc)] for doc in tqdm(documents, 'indexing')]
+        return [[vocab.get(word, self.unk) for word in self.analyzer(doc)] for doc in tqdm(documents, 'indexing')]
 
     def fit_transform(self, X, n_jobs=-1):
         """
diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py
index bb71525..c0029ac 100644
--- a/quapy/method/aggregative.py
+++ b/quapy/method/aggregative.py
@@ -282,6 +282,7 @@ class ACC(AggregativeQuantifier):
         """
         if val_split is None:
             val_split = self.val_split
+            classes = data.classes_
         if isinstance(val_split, int):
             assert fit_learner == True, \
                 'the parameters for the adjustment cannot be estimated with kFCV with fit_learner=False'
@@ -300,6 +301,7 @@ class ACC(AggregativeQuantifier):
             y = np.concatenate(y)
             y_ = np.concatenate(y_)
             class_count = data.counts()
+            classes = data.classes_
 
             # fit the learner on all data
             self.learner, _ = _training_helper(self.learner, data, fit_learner, val_split=None)
@@ -308,10 +310,11 @@ class ACC(AggregativeQuantifier):
             self.learner, val_data = _training_helper(self.learner, data, fit_learner, val_split=val_split)
             y_ = self.learner.predict(val_data.instances)
             y = val_data.labels
+            classes = val_data.classes_
 
         self.cc = CC(self.learner)
 
-        self.Pte_cond_estim_ = self.getPteCondEstim(data.classes_, y, y_)
+        self.Pte_cond_estim_ = self.getPteCondEstim(classes, y, y_)
 
         return self
 
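In ACC, the misclassification-rate matrix is now estimated with the classes
observed in the data the estimate is actually computed on (the k-fold
predictions or the held-out split) rather than always `data.classes_`. A
hedged sketch of what `getPteCondEstim` computes, i.e. P(y_hat=i | y=j) with
columns normalized over true classes (using sklearn for brevity; the real
implementation may differ, e.g. in how empty classes are handled):

    import numpy as np
    from sklearn.metrics import confusion_matrix

    def pte_cond_estim(classes, y_true, y_pred):
        # conf[i, j] = #(predicted as classes[i] while truly classes[j])
        conf = confusion_matrix(y_true, y_pred, labels=classes).T.astype(float)
        return conf / conf.sum(axis=0, keepdims=True)  # each column sums to 1

    y_true = [0, 0, 1, 1, 2, 2]
    y_pred = [0, 1, 1, 1, 2, 0]
    print(pte_cond_estim([0, 1, 2], y_true, y_pred))
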
diff --git a/quapy/method/neural.py b/quapy/method/neural.py
index bf1f375..b42ada7 100644
--- a/quapy/method/neural.py
+++ b/quapy/method/neural.py
@@ -82,6 +82,7 @@ class QuaNetTrainer(BaseQuantifier):
         assert hasattr(learner, 'predict_proba'), \
             f'the learner {learner.__class__.__name__} does not seem to be able to produce posterior probabilities ' \
                 f'since it does not implement the method "predict_proba"'
+        assert sample_size is not None, 'sample_size cannot be None'
         self.learner = learner
         self.sample_size = sample_size
         self.n_epochs = n_epochs
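The new assertion makes QuaNet fail fast on a missing `sample_size` instead of
erroring later, deep inside sample generation. A hedged construction sketch
(class and argument names as in quapy; the network arguments are placeholders):

    from quapy.classification.neural import NeuralClassifierTrainer, LSTMnet
    from quapy.method.neural import QuaNetTrainer

    learner = NeuralClassifierTrainer(LSTMnet(vocabulary_size=5000, n_classes=3))
    quanet = QuaNetTrainer(learner, sample_size=100)  # sample_size=None now raises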