diff --git a/src/data/datamodule.py b/src/data/datamodule.py
index 067d47f..767f349 100644
--- a/src/data/datamodule.py
+++ b/src/data/datamodule.py
@@ -181,7 +181,7 @@ class BertDataModule(RecurrentDataModule):
     Pytorch Lightning Datamodule to be deployed with BertGen.
     https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
     """
-    def __init__(self, multilingualIndex, batchsize=64, max_len=512, zero_shot=False, zscl_langs=None):
+    def __init__(self, multilingualIndex, batchsize=64, max_len=512, zero_shot=False, zscl_langs=None, debug=False):
         """
         Init BertDataModule.
         :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents
@@ -196,28 +196,33 @@
             zscl_langs = []
         self.zero_shot = zero_shot
         self.train_langs = zscl_langs
+        self.debug = debug
+        if self.debug:
+            print('\n[Running in DEBUG mode - samples per language are reduced to 50 max!]\n')
 
     def setup(self, stage=None):
         if stage == 'fit' or stage is None:
             if self.zero_shot:
-                l_train_raw, l_train_target = self.multilingualIndex.l_train_raw_zero_shot(langs=self.train_langs)  # todo: check this!
+                l_train_raw, l_train_target = self.multilingualIndex.l_train_raw_zero_shot(langs=self.train_langs)
             else:
                 l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
-            # Debug settings: reducing number of samples
-            # l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
-            # l_train_target = {l: target[:5] for l, target in l_train_target.items()}
+            if self.debug:
+                # Debug settings: reducing number of samples
+                l_train_raw = {l: train[:50] for l, train in l_train_raw.items()}
+                l_train_target = {l: target[:50] for l, target in l_train_target.items()}
 
             l_train_index = tokenize(l_train_raw, max_len=self.max_len)
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                      lPad_index=self.multilingualIndex.l_pad())
 
             if self.zero_shot:
-                l_val_raw, l_val_target = self.multilingualIndex.l_val_raw_zero_shot(langs=self.train_langs)  # todo: check this!
+                l_val_raw, l_val_target = self.multilingualIndex.l_val_raw_zero_shot(langs=self.train_langs)
             else:
                 l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
-            # Debug settings: reducing number of samples
-            # l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
-            # l_val_target = {l: target[:5] for l, target in l_val_target.items()}
+            if self.debug:
+                # Debug settings: reducing number of samples
+                l_val_raw = {l: train[:50] for l, train in l_val_raw.items()}
+                l_val_target = {l: target[:50] for l, target in l_val_target.items()}
 
             l_val_index = tokenize(l_val_raw, max_len=self.max_len)
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
@@ -225,12 +230,13 @@ class BertDataModule(RecurrentDataModule):
 
         if stage == 'test' or stage is None:
             if self.zero_shot:
-                l_test_raw, l_test_target = self.multilingualIndex.l_test_raw_zero_shot(langs=self.train_langs)  # todo: check this!
+                l_test_raw, l_test_target = self.multilingualIndex.l_test_raw_zero_shot(langs=self.train_langs)
             else:
                 l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
-            # Debug settings: reducing number of samples
-            # l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
-            # l_test_target = {l: target[:5] for l, target in l_test_target.items()}
+            if self.debug:
+                # Debug settings: reducing number of samples
+                l_test_raw = {l: train[:50] for l, train in l_test_raw.items()}
+                l_test_target = {l: target[:50] for l, target in l_test_target.items()}
 
             l_test_index = tokenize(l_test_raw, max_len=self.max_len)
             self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
@@ -241,10 +247,16 @@ class BertDataModule(RecurrentDataModule):
         NB: Setting n_workers to > 0 will cause "OSError: [Errno 24] Too many open files"
         :return:
         """
-        return DataLoader(self.training_dataset, batch_size=self.batchsize)
+        return DataLoader(self.training_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)
 
     def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batchsize)
+        return DataLoader(self.val_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)
 
     def test_dataloader(self):
-        return DataLoader(self.test_dataset, batch_size=self.batchsize)
+        return DataLoader(self.test_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)
+
+    def collate_fn_bert(self, data):
+        x_batch = np.vstack([elem[0] for elem in data])
+        y_batch = np.vstack([elem[1] for elem in data])
+        lang_batch = [elem[2] for elem in data]
+        return torch.LongTensor(x_batch), torch.FloatTensor(y_batch), lang_batch
diff --git a/src/models/pl_bert.py b/src/models/pl_bert.py
index 1da9c69..dba9c8e 100644
--- a/src/models/pl_bert.py
+++ b/src/models/pl_bert.py
@@ -23,7 +23,7 @@ class BertModel(pl.LightningModule):
         self.macroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
         self.microK = CustomK(num_classes=output_size, average='micro', device=self.gpus)
         self.macroK = CustomK(num_classes=output_size, average='macro', device=self.gpus)
-        # Language specific metrics to compute metrics at epoch level
+        # Language specific metrics to compute at epoch level
         self.lang_macroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
         self.lang_microF1 = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
         self.lang_macroK = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
@@ -44,9 +44,7 @@ class BertModel(pl.LightningModule):
         return logits
 
     def training_step(self, train_batch, batch_idx):
-        X, y, _, batch_langs = train_batch
-        X = torch.cat(X).view([X[0].shape[0], len(X)])
-        y = y.type(torch.FloatTensor)
+        X, y, batch_langs = train_batch
         y = y.to('cuda' if self.gpus else 'cpu')
         logits, _ = self.forward(X)
         loss = self.loss(logits, y)
@@ -99,9 +97,7 @@ class BertModel(pl.LightningModule):
             self.logger.experiment.add_scalars('train-langs-microK', {f'{lang}': avg_microK}, self.current_epoch)
 
     def validation_step(self, val_batch, batch_idx):
-        X, y, _, batch_langs = val_batch
-        X = torch.cat(X).view([X[0].shape[0], len(X)])
-        y = y.type(torch.FloatTensor)
+        X, y, batch_langs = val_batch
         y = y.to('cuda' if self.gpus else 'cpu')
         logits, _ = self.forward(X)
         loss = self.loss(logits, y)
@@ -118,12 +114,10 @@ class BertModel(pl.LightningModule):
         return {'loss': loss}
 
     def test_step(self, test_batch, batch_idx):
-        X, y, _, batch_langs = test_batch
-        X = torch.cat(X).view([X[0].shape[0], len(X)])
-        y = y.type(torch.FloatTensor)
+        X, y, batch_langs = test_batch
         y = y.to('cuda' if self.gpus else 'cpu')
         logits, _ = self.forward(X)
-        loss = self.loss(logits, y)
+        # loss = self.loss(logits, y)
         # Squashing logits through Sigmoid in order to get confidence score
         predictions = torch.sigmoid(logits) > 0.5
         microF1 = self.microF1(predictions, y)
diff --git a/src/models/pl_gru.py b/src/models/pl_gru.py
index 4adb148..f6feb43 100644
--- a/src/models/pl_gru.py
+++ b/src/models/pl_gru.py
@@ -42,7 +42,7 @@ class RecurrentModel(pl.LightningModule):
         self.macroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
         self.microK = CustomK(num_classes=output_size, average='micro', device=self.gpus)
         self.macroK = CustomK(num_classes=output_size, average='macro', device=self.gpus)
-        # Language specific metrics to compute metrics at epoch level
+        # Language specific metrics to compute at epoch level
         self.lang_macroF1 = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
         self.lang_microF1 = CustomF1(num_classes=output_size, average='micro', device=self.gpus)
         self.lang_macroK = CustomF1(num_classes=output_size, average='macro', device=self.gpus)
diff --git a/src/view_generators.py b/src/view_generators.py
index 0804aec..cd992ba 100644
--- a/src/view_generators.py
+++ b/src/view_generators.py
@@ -474,7 +474,8 @@ class BertGen(ViewGen):
         create_if_not_exist(self.logger.save_dir)
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
-                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs)
+                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs,
+                                        debug=True)
         if self.zero_shot:
             print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
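
Note (not part of the diff): below is a minimal, self-contained sketch of what the new collate_fn_bert is expected to produce, assuming each RecurrentDataset item is an (index_vector, target_vector, lang) triple with token vectors already padded to a fixed length. The batch contents (512-token index vectors, 3 labels, languages 'en'/'it') are made-up illustration data, not taken from the repository.

import numpy as np
import torch

def collate_fn_bert(data):
    # Stack the token-index vectors and target vectors row-wise; keep language ids as a plain list.
    x_batch = np.vstack([elem[0] for elem in data])
    y_batch = np.vstack([elem[1] for elem in data])
    lang_batch = [elem[2] for elem in data]
    return torch.LongTensor(x_batch), torch.FloatTensor(y_batch), lang_batch

# Hypothetical batch of two dataset items: (token indices, multi-hot targets, language).
batch = [
    (np.zeros(512, dtype=np.int64), np.array([0., 1., 0.]), 'en'),
    (np.ones(512, dtype=np.int64), np.array([1., 0., 0.]), 'it'),
]
X, y, langs = collate_fn_bert(batch)
print(X.shape, y.shape, langs)  # torch.Size([2, 512]) torch.Size([2, 3]) ['en', 'it']
# This three-element output is what the simplified Lightning steps unpack as: X, y, batch_langs = train_batch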