sampling GLAMI1-M dataset

This commit is contained in:
Andrea Pedrotti 2023-03-16 18:10:05 +01:00
parent ee38bcda10
commit ee2a9481de
4 changed files with 13 additions and 37 deletions

View File

@ -108,30 +108,13 @@ class gFunDataset:
return dataset, labels, data_langs
def _load_glami(self, dataset_dir, nrows):
# TODO: a better way to get a stratified sampling of the dataset (see: groupby + sample)
def _balanced_sample(data, n, remainder=0, random_state=None):
    """Draw a language-balanced random sample from ``data``.

    Rows are partitioned by the value of the ``geo`` column. ``n`` rows are
    sampled from every language; the first language in sorted order receives
    ``remainder`` additional rows so the total can match a requested size
    that is not divisible by the number of languages.

    Args:
        data: DataFrame exposing a ``geo`` column.
        n: Number of rows to draw for each language.
        remainder: Extra rows drawn for the first language (sorted order).
        random_state: Forwarded to ``DataFrame.sample``. The default
            ``None`` keeps the original unseeded behaviour; pass a seed to
            make the sample reproducible.

    Returns:
        A DataFrame with the per-language samples concatenated in sorted
        language order. An empty copy of ``data`` is returned when ``data``
        contains no languages.

    Raises:
        ValueError: Propagated from ``DataFrame.sample`` when a language has
            fewer rows than requested (sampling is without replacement).
    """
    import pandas as pd

    langs = sorted(data.geo.unique().tolist())
    if not langs:
        # Nothing to sample: avoid IndexError on langs[0] and the
        # "No objects to concatenate" error from pd.concat([]).
        return data.iloc[0:0].copy()

    # Only the first language deviates from the common per-language quota.
    extra = {langs[0]: remainder}
    sampled = [
        data[data.geo == lang].sample(
            n=n + extra.get(lang, 0), random_state=random_state
        )
        for lang in langs
    ]
    return pd.concat(sampled, axis=0)
# TODO: make this sampling deterministic/dependent on the seed
lang_nrows = (
nrows // 13 if self.data_langs is None else nrows // len(self.data_langs)
) # GLAMI 1-M has 13 languages
remainder = (
nrows % 13 if self.data_langs is None else nrows % len(self.data_langs)
train_split = get_dataframe("train", dataset_dir=dataset_dir).sample(n=nrows)
test_split = get_dataframe("test", dataset_dir=dataset_dir).sample(
n=int(nrows / 10)
)
train_split = get_dataframe("train", dataset_dir=dataset_dir)
train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder)
gb_train = train_split.groupby("geo")
gb_test = test_split.groupby("geo")
if self.data_langs is None:
data_langs = sorted(train_split.geo.unique().tolist())
@ -139,14 +122,6 @@ class gFunDataset:
if self.labels is None:
labels = train_split.category_name.unique().tolist()
# TODO: atm test data should contain same languages as train data
test_split = get_dataframe("test", dataset_dir=dataset_dir)
# TODO: atm we're using 1:1 train-test
test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder)
gb_train = train_split.groupby("geo")
gb_test = test_split.groupby("geo")
def _format_glami(data_df):
text = (data_df.name + " " + data_df.description).tolist()
image = data_df.image_file.tolist()

View File

@ -340,7 +340,7 @@ class EarlyStopping:
self.experiment_name = experiment_name
def __call__(self, validation, model, epoch):
if validation > self.best_score:
if validation >= self.best_score:
if self.verbose:
print(
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"

View File

@ -100,7 +100,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
return "bert-base-uncased"
elif "mbert" == model_name:
return "bert-base-multilingual-uncased"
elif "xlm" == model_name:
elif "xlm-roberta" == model_name:
return "xlm-roberta-base"
elif "mt5" == model_name:
return "google/mt5-small"
@ -270,8 +270,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
elif "bert" in model_name:
if "multilingual" in model_name:
return "mbert"
elif "xlm" in model_name:
return "xlm"
elif "xlm-roberta" in model_name:
return "xlm-roberta"
else:
return model_name

View File

@ -13,6 +13,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
"""
TODO:
- Transformers VGFs:
- scheduler with warmup and cosine
- freeze params method
- General:
[!] zero-shot setup
@ -177,17 +178,17 @@ if __name__ == "__main__":
parser.add_argument("--features", action="store_false")
parser.add_argument("--aggfunc", type=str, default="mean")
# transformer parameters ---------------
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--textual_trf_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--eval_batch_size", type=int, default=128)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--textual_lr", type=float, default=1e-5)
parser.add_argument("--visual_lr", type=float, default=1e-5)
parser.add_argument("--textual_lr", type=float, default=1e-4)
parser.add_argument("--max_length", type=int, default=128)
parser.add_argument("--patience", type=int, default=5)
parser.add_argument("--evaluate_step", type=int, default=10)
# Visual Transformer parameters --------------
parser.add_argument("--visual_trf_name", type=str, default="vit")
parser.add_argument("--visual_lr", type=float, default=1e-4)
args = parser.parse_args()