sampling GLAMI1-M dataset
This commit is contained in:
parent
ee38bcda10
commit
ee2a9481de
|
@ -108,30 +108,13 @@ class gFunDataset:
|
|||
return dataset, labels, data_langs
|
||||
|
||||
def _load_glami(self, dataset_dir, nrows):
|
||||
# TODO: a better way to get a stratified sampling of the dataset (see: groupby + sample)
|
||||
def _balanced_sample(data, n, remainder=0):
|
||||
import pandas as pd
|
||||
|
||||
langs = sorted(data.geo.unique().tolist())
|
||||
dict_n = {lang: n for lang in langs}
|
||||
dict_n[langs[0]] += remainder
|
||||
|
||||
sampled = []
|
||||
for lang in langs:
|
||||
sampled.append(data[data.geo == lang].sample(n=dict_n[lang]))
|
||||
|
||||
return pd.concat(sampled, axis=0)
|
||||
|
||||
# TODO: make this sampling deterministic/dependent on the seed
|
||||
lang_nrows = (
|
||||
nrows // 13 if self.data_langs is None else nrows // len(self.data_langs)
|
||||
) # GLAMI 1-M has 13 languages
|
||||
remainder = (
|
||||
nrows % 13 if self.data_langs is None else nrows % len(self.data_langs)
|
||||
train_split = get_dataframe("train", dataset_dir=dataset_dir).sample(n=nrows)
|
||||
test_split = get_dataframe("test", dataset_dir=dataset_dir).sample(
|
||||
n=int(nrows / 10)
|
||||
)
|
||||
|
||||
train_split = get_dataframe("train", dataset_dir=dataset_dir)
|
||||
train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder)
|
||||
gb_train = train_split.groupby("geo")
|
||||
gb_test = test_split.groupby("geo")
|
||||
|
||||
if self.data_langs is None:
|
||||
data_langs = sorted(train_split.geo.unique().tolist())
|
||||
|
@ -139,14 +122,6 @@ class gFunDataset:
|
|||
if self.labels is None:
|
||||
labels = train_split.category_name.unique().tolist()
|
||||
|
||||
# TODO: at the moment, test data should contain the same languages as train data
|
||||
test_split = get_dataframe("test", dataset_dir=dataset_dir)
|
||||
# TODO: at the moment we're using a 1:1 train-test ratio
|
||||
test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder)
|
||||
|
||||
gb_train = train_split.groupby("geo")
|
||||
gb_test = test_split.groupby("geo")
|
||||
|
||||
def _format_glami(data_df):
|
||||
text = (data_df.name + " " + data_df.description).tolist()
|
||||
image = data_df.image_file.tolist()
|
||||
|
|
|
@ -340,7 +340,7 @@ class EarlyStopping:
|
|||
self.experiment_name = experiment_name
|
||||
|
||||
def __call__(self, validation, model, epoch):
|
||||
if validation > self.best_score:
|
||||
if validation >= self.best_score:
|
||||
if self.verbose:
|
||||
print(
|
||||
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
|
||||
|
|
|
@ -100,7 +100,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
return "bert-base-uncased"
|
||||
elif "mbert" == model_name:
|
||||
return "bert-base-multilingual-uncased"
|
||||
elif "xlm" == model_name:
|
||||
elif "xlm-roberta" == model_name:
|
||||
return "xlm-roberta-base"
|
||||
elif "mt5" == model_name:
|
||||
return "google/mt5-small"
|
||||
|
@ -270,8 +270,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
elif "bert" in model_name:
|
||||
if "multilingual" in model_name:
|
||||
return "mbert"
|
||||
elif "xlm" in model_name:
|
||||
return "xlm"
|
||||
elif "xlm-roberta" in model_name:
|
||||
return "xlm-roberta"
|
||||
else:
|
||||
return model_name
|
||||
|
||||
|
|
7
main.py
7
main.py
|
@ -13,6 +13,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
|
|||
"""
|
||||
TODO:
|
||||
- Transformers VGFs:
|
||||
- scheduler with warmup and cosine
|
||||
- freeze params method
|
||||
- General:
|
||||
[!] zero-shot setup
|
||||
|
@ -177,17 +178,17 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--features", action="store_false")
|
||||
parser.add_argument("--aggfunc", type=str, default="mean")
|
||||
# transformer parameters ---------------
|
||||
parser.add_argument("--epochs", type=int, default=100)
|
||||
parser.add_argument("--textual_trf_name", type=str, default="mbert")
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--eval_batch_size", type=int, default=128)
|
||||
parser.add_argument("--epochs", type=int, default=100)
|
||||
parser.add_argument("--textual_lr", type=float, default=1e-5)
|
||||
parser.add_argument("--visual_lr", type=float, default=1e-5)
|
||||
parser.add_argument("--textual_lr", type=float, default=1e-4)
|
||||
parser.add_argument("--max_length", type=int, default=128)
|
||||
parser.add_argument("--patience", type=int, default=5)
|
||||
parser.add_argument("--evaluate_step", type=int, default=10)
|
||||
# Visual Transformer parameters --------------
|
||||
parser.add_argument("--visual_trf_name", type=str, default="vit")
|
||||
parser.add_argument("--visual_lr", type=float, default=1e-4)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
Loading…
Reference in New Issue