Compare commits
23 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
fbd740fabd | |
|
|
ae92199613 | |
|
|
b6b1d33fdb | |
|
|
8354d76513 | |
|
|
6995854e3d | |
|
|
55e12505c0 | |
|
|
d36e185ffe | |
|
|
317fb93da6 | |
|
|
86fbd90bd4 | |
|
|
1a1c48e136 | |
|
|
c63c35269a | |
|
|
2800694672 | |
|
|
e8b6396366 | |
|
|
e3e6f061d8 | |
|
|
60171c1b5e | |
|
|
2554c58fac | |
|
|
9437ccc837 | |
|
|
de98926d00 | |
|
|
bef086ab50 | |
|
|
732ffbefb1 | |
|
|
9ce0001047 | |
|
|
b3b7c69263 | |
|
|
770e8e62be |
|
|
@ -182,4 +182,11 @@ scripts/
|
|||
logger/*
|
||||
explore_data.ipynb
|
||||
run.sh
|
||||
wandb
|
||||
wandb
|
||||
local_datasets
|
||||
hf_models
|
||||
embeddings
|
||||
results
|
||||
net.py
|
||||
lel.py
|
||||
stats.py
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
Appliances
|
||||
Arts Crafts and Sewing
|
||||
Automotive
|
||||
CDs and Vinyl
|
||||
Cell Phones and Accessories
|
||||
Electronics
|
||||
Grocery and Gourmet Food
|
||||
Home and Kitchen
|
||||
Industrial and Scientific
|
||||
Luxury Beauty
|
||||
Magazine Subscriptions
|
||||
Movies and TV
|
||||
Musical Instruments
|
||||
Office Products
|
||||
Patio Lawn and Garden
|
||||
Pet Supplies
|
||||
Software
|
||||
Sports and Outdoors
|
||||
Tools and Home Improvement
|
||||
Toys and Games
|
||||
Video Games
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
from argparse import ArgumentParser
|
||||
from csvlogger import CsvLogger
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_absolute_error
|
||||
|
||||
from os.path import join
|
||||
|
||||
"""
|
||||
MAE (mean absolute error) for classification is meaningful only in "ordinal" tasks, e.g., sentiment classification.
|
||||
Otherwise the distance between the categories has no semantics!
|
||||
|
||||
- NB: we want to get the macro-averaged class specific MAE!
|
||||
"""
|
||||
|
||||
def main():
    """Evaluate every configured setting and print the combined result table.

    Each call to ``evalaute`` returns a list of rows for one setting; the rows
    of all settings are stacked into a single DataFrame, which is printed.
    """
    # Other settings used previously: "p", "m", "w", "t", "mp", "mpw", "mpt", "mptw"
    SETTINGS = ["mbert"]

    # DataFrame.append was removed in pandas 2.0: build one frame per setting
    # and concatenate them. Row order and per-setting index are unchanged.
    frames = [pd.DataFrame(evalaute(setting)) for setting in SETTINGS]
    df = pd.concat(frames) if frames else pd.DataFrame()
    print(df)
|
||||
|
||||
|
||||
def evalaute(setting):
    """Compute class-specific MAE per language from a results CSV.

    Reads ``results/lang-specific.mbert.webis.csv``, which must provide the
    columns ``langs``, ``labels`` and ``preds``. For every language the mean
    absolute error is computed separately on the documents whose true label is
    0 (negative), 1 (neutral) and 2 (positive), together with their macro
    average.

    Args:
        setting: identifier of the evaluated configuration; it is only copied
            into the returned rows.

    Returns:
        A list with one row per language:
        ``[lang, neg_mae, neutral_mae, pos_mae, setting, macro_mae]``.
        All MAE values are rounded to 3 decimals; a class without documents
        yields ``nan``. ``macro_mae`` is appended last so the first five
        positions keep their previous meaning.
    """
    result_dir = "results"
    # result_file = f"lang-specific.gfun.{setting}.webis.csv"
    # NOTE(review): the file name does not depend on `setting`; every setting
    # currently reads the same mbert results file - confirm this is intended.
    result_file = "lang-specific.mbert.webis.csv"
    df = pd.read_csv(join(result_dir, result_file))

    def class_mae(rows):
        # mean_absolute_error rejects empty input; report nan for that class
        # instead of aborting the whole evaluation.
        if rows.empty:
            return float("nan")
        # Built-in round() works whether the metric is a numpy or Python float.
        return round(float(mean_absolute_error(rows.labels, rows.preds)), 3)

    res = []
    for lang in df.langs.unique():
        lang_df = df[df.langs == lang]
        neg_mae, neutral_mae, pos_mae = (
            class_mae(lang_df[lang_df.labels == label]) for label in (0, 1, 2)
        )
        macro_mae = round((neg_mae + neutral_mae + pos_mae) / 3, 3)
        res.append([lang, neg_mae, neutral_mae, pos_mae, setting, macro_mae])
    return res
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
import csv
|
||||
import pandas as pd
|
||||
import os
|
||||
|
||||
class CsvLogger:
    """Append per-language result tables to a CSV log file."""

    # Logged column name -> key read from config["gFun"].
    _GFUN_COLUMNS = {
        "config": "simple_id",
        "aggfunc": "aggfunc",
        "dataset": "dataset",
        "id": "id",
        "optimc": "optimc",
        "timing": "timing",
    }

    def __init__(self, outfile="log.csv"):
        """Remember the path of the CSV file that results are appended to."""
        self.outfile = outfile

    def log_lang_results(self, results: dict, config="gfun-lello"):
        """Append one block of results, tagged with the gFun configuration.

        Args:
            results: mapping of column name to column values, turned into a
                DataFrame with ``orient="columns"``.
            config: configuration mapping with a ``"gFun"`` entry providing
                the keys ``simple_id``, ``aggfunc``, ``dataset``, ``id``,
                ``optimc`` and ``timing``. The string default is a
                placeholder and cannot be logged.

        Raises:
            TypeError: if ``config`` is a string (e.g. the default).
            KeyError: if ``"gFun"`` or one of its required keys is missing.
        """
        if isinstance(config, str):
            # Indexing a str with "gFun" raised an opaque TypeError before;
            # keep the exception type but say what is actually wrong.
            raise TypeError(
                f"config must be a mapping with a 'gFun' entry, got the string {config!r}"
            )
        gfun_config = config["gFun"]

        df = pd.DataFrame.from_dict(results, orient="columns")
        for column, key in self._GFUN_COLUMNS.items():
            df[column] = gfun_config[key]

        # newline="" is what pandas asks for when it is handed a file object.
        # The header is written only when the file is still empty.
        with open(self.outfile, "a", newline="") as f:
            df.to_csv(f, header=f.tell() == 0)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
import sys
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
|
|
@ -8,13 +9,90 @@ import re
|
|||
from dataManager.multilingualDataset import MultilingualDataset
|
||||
|
||||
CLS_PROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-processed/")
|
||||
LANGS = ["de", "en", "fr", "jp"]
|
||||
CLS_UNPROCESSED_DATA_DIR = os.path.expanduser("~/datasets/cls-acl10-unprocessed/")
|
||||
# LANGS = ["de", "en", "fr", "jp"]
|
||||
LANGS = ["de", "en", "fr"]
|
||||
DOMAINS = ["books", "dvd", "music"]
|
||||
|
||||
regex = r":\d+"
|
||||
subst = ""
|
||||
|
||||
|
||||
def load_unprocessed_cls(reduce_target_space=False, data_dir=None, langs=None, domains=None):
    """Load the unprocessed CLS reviews from their XML files.

    For every language, domain and split (``train``/``test``) the file
    ``<data_dir>/<lang>/<domain>/<split>.review`` is parsed and each child of
    the XML root becomes one record.

    Args:
        reduce_target_space: if True, the star rating is collapsed to three
            classes (rating < 2 -> 0, rating > 4 -> 2, anything else -> 1);
            otherwise the class is ``rating - 1``.
        data_dir: root directory of the corpus; defaults to
            ``CLS_UNPROCESSED_DATA_DIR``.
        langs: languages to load; defaults to ``LANGS``.
        domains: domains to load; defaults to ``DOMAINS``.

    Returns:
        Nested dict ``data[lang][domain][split]`` -> list of records. Each
        record is a dict with the keys ``asin``, ``category`` (the domain),
        ``original_rating``, ``rating`` (class index), ``title``, ``text``,
        ``summary`` and ``lang``. Optional text fields are None when the
        element is missing. Every item must contain a ``rating`` element
        (presumably in 1..5 - verify against the corpus).
    """
    data_dir = CLS_UNPROCESSED_DATA_DIR if data_dir is None else data_dir
    langs = LANGS if langs is None else langs
    domains = DOMAINS if domains is None else domains

    def field_text(item, tag):
        # Text of an optional child element, or None when it is absent.
        node = item.find(tag)
        return node.text if node is not None else None

    def to_class(original_rating):
        # Map a star rating to the zero-based class index.
        if not reduce_target_space:
            return original_rating - 1
        if original_rating < 2:
            return 0
        if original_rating > 4:
            return 2
        return 1

    data = {}
    for lang in langs:
        data[lang] = {}
        for domain in domains:
            data[lang][domain] = {}
            print(f"lang: {lang}, domain: {domain}")
            for split in ["train", "test"]:
                review_path = os.path.join(data_dir, lang, domain, f"{split}.review")
                root = ET.parse(review_path).getroot()
                domain_data = []
                for child in root:
                    # Ratings are stored as float strings (e.g. "5.0"); parse once.
                    original_rating = int(float(child.find("rating").text))
                    domain_data.append(
                        {
                            "asin": field_text(child, "asin"),
                            "category": domain,
                            "original_rating": original_rating,
                            "rating": to_class(original_rating),
                            "title": field_text(child, "title"),
                            "text": field_text(child, "text"),
                            "summary": field_text(child, "summary"),
                            "lang": lang,
                        }
                    )
                data[lang][domain][split] = domain_data
    return data
|
||||
|
||||
|
||||
def load_cls():
|
||||
data = {}
|
||||
for lang in LANGS:
|
||||
|
|
@ -24,7 +102,7 @@ def load_cls():
|
|||
train = (
|
||||
open(
|
||||
os.path.join(
|
||||
CLS_PROCESSED_DATA_DIR, lang, domain, "train.processed"
|
||||
CLS_UNPROCESSED_DATA_DIR, lang, domain, "train.processed"
|
||||
),
|
||||
"r",
|
||||
)
|
||||
|
|
@ -34,7 +112,7 @@ def load_cls():
|
|||
test = (
|
||||
open(
|
||||
os.path.join(
|
||||
CLS_PROCESSED_DATA_DIR, lang, domain, "test.processed"
|
||||
CLS_UNPROCESSED_DATA_DIR, lang, domain, "test.processed"
|
||||
),
|
||||
"r",
|
||||
)
|
||||
|
|
@ -59,18 +137,33 @@ def process_data(line):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"datapath: {CLS_PROCESSED_DATA_DIR}")
|
||||
data = load_cls()
|
||||
multilingualDataset = MultilingualDataset(dataset_name="cls")
|
||||
for lang in LANGS:
|
||||
# TODO: just using book domain atm
|
||||
Xtr = [text[0] for text in data[lang]["books"]["train"]]
|
||||
# Ytr = np.expand_dims([text[1] for text in data[lang]["books"]["train"]], axis=1)
|
||||
Ytr = np.vstack([text[1] for text in data[lang]["books"]["train"]])
|
||||
print(f"datapath: {CLS_UNPROCESSED_DATA_DIR}")
|
||||
# data = load_cls()
|
||||
data = load_unprocessed_cls(reduce_target_space=True)
|
||||
multilingualDataset = MultilingualDataset(dataset_name="webis-cls-unprocessed")
|
||||
|
||||
Xte = [text[0] for text in data[lang]["books"]["test"]]
|
||||
# Yte = np.expand_dims([text[1] for text in data[lang]["books"]["test"]], axis=1)
|
||||
Yte = np.vstack([text[1] for text in data[lang]["books"]["test"]])
|
||||
for lang in LANGS:
|
||||
# Xtr = [text["summary"] for text in data[lang]["books"]["train"]]
|
||||
Xtr = [text["text"] for text in data[lang]["books"]["train"]]
|
||||
Xtr += [text["text"] for text in data[lang]["dvd"]["train"]]
|
||||
Xtr += [text["text"] for text in data[lang]["music"]["train"]]
|
||||
|
||||
Ytr =[text["rating"] for text in data[lang]["books"]["train"]]
|
||||
Ytr += [text["rating"] for text in data[lang]["dvd"]["train"]]
|
||||
Ytr += [text["rating"] for text in data[lang]["music"]["train"]]
|
||||
|
||||
Ytr = np.vstack(Ytr)
|
||||
|
||||
Xte = [text["text"] for text in data[lang]["books"]["test"]]
|
||||
Xte += [text["text"] for text in data[lang]["dvd"]["test"]]
|
||||
Xte += [text["text"] for text in data[lang]["music"]["test"]]
|
||||
|
||||
|
||||
Yte = [text["rating"] for text in data[lang]["books"]["test"]]
|
||||
Yte += [text["rating"] for text in data[lang]["dvd"]["test"]]
|
||||
Yte += [text["rating"] for text in data[lang]["music"]["test"]]
|
||||
|
||||
Yte = np.vstack(Yte)
|
||||
|
||||
multilingualDataset.add(
|
||||
lang=lang,
|
||||
|
|
@ -81,6 +174,8 @@ if __name__ == "__main__":
|
|||
tr_ids=None,
|
||||
te_ids=None,
|
||||
)
|
||||
multilingualDataset.save(
|
||||
os.path.expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
|
||||
)
|
||||
# multilingualDataset.save(
|
||||
# os.path.expanduser(
|
||||
# "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
|
||||
# )
|
||||
# )
|
||||
|
|
|
|||
|
|
@ -1,10 +1,134 @@
|
|||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.expanduser("~/devel/gfun_multimodal"))
|
||||
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
import numpy as np
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
||||
from dataManager.glamiDataset import get_dataframe
|
||||
from dataManager.multilingualDataset import MultilingualDataset
|
||||
|
||||
|
||||
class SimpleGfunDataset:
    """Multilingual dataset read from ``train.balanced.csv`` / ``test.balanced.csv``.

    The CSV files must provide the columns ``lang``, ``text`` and ``label``.
    The training file is split 80/20 into training and validation data,
    stratified on the language.
    """

    # NOTE(review): both CSV files are subsampled to this many rows, which
    # looks like a debugging leftover - confirm before running real experiments.
    _SAMPLE_SIZE = 100

    def __init__(self, datadir="~/datasets/rai/csv/", textual=True, visual=False, multilabel=False, set_tr_langs=None, set_te_langs=None):
        """Load the CSV files found in ``datadir`` and print dataset statistics.

        Args:
            datadir: directory holding the two CSV files (``~`` is expanded).
            textual, visual, multilabel: stored on the instance; not used by
                the loading code in this class.
            set_tr_langs: optional iterable restricting the training languages.
            set_te_langs: optional iterable restricting the test languages.
        """
        self.datadir = os.path.expanduser(datadir)
        self.textual = textual
        self.visual = visual
        self.multilabel = multilabel
        self.load_csv(set_tr_langs, set_te_langs)
        self.print_stats()

    def print_stats(self):
        """Print the number of train/val/test documents per language and in total."""
        print(f"Dataset statistics {'-' * 15}")
        tr = 0
        va = 0
        te = 0
        for lang in self.all_langs:
            # Languages that are test-only contribute no train/val documents.
            n_tr = len(self.train_data[lang]) if lang in self.tr_langs else 0
            n_va = len(self.val_data[lang]) if lang in self.tr_langs else 0
            n_te = len(self.test_data[lang])
            tr += n_tr
            va += n_va
            te += n_te
            print(f"{lang} - tr: {n_tr} - va: {n_va} - te: {n_te}")
        print(f"Total {'-' * 15}")
        print(f"tr: {tr} - va: {va} - te: {te}")

    def _read_sample(self, filename):
        """Read ``filename`` from the data directory and subsample it reproducibly."""
        frame = pd.read_csv(os.path.join(self.datadir, filename))
        # Cap the sample size so files with fewer rows do not raise.
        return frame.sample(min(self._SAMPLE_SIZE, len(frame)), random_state=42)

    def load_csv(self, set_tr_langs, set_te_langs):
        """Read the CSV files and populate the language, label and data attributes."""
        _data_tr = self._read_sample("train.balanced.csv")
        train, val = train_test_split(_data_tr, test_size=0.2, random_state=42, stratify=_data_tr.lang)  # TODO stratify on lang or label?
        test = self._read_sample("test.balanced.csv")
        self._set_langs(train, test, set_tr_langs, set_te_langs)
        self._set_labels(_data_tr)
        self.full_train = _data_tr
        # Store the test frame itself (previously the bound method self.test
        # was stored here by mistake).
        self.full_test = test
        self.train_data = self._set_datalang(train)
        self.val_data = self._set_datalang(val)
        self.test_data = self._set_datalang(test)
        return

    def _set_labels(self, data):
        """Store the sorted unique values of ``data.label`` as ``self.labels``."""
        self.labels = sorted(list(data.label.unique()))

    def _set_langs(self, train, test, set_tr_langs=None, set_te_langs=None):
        """Determine training, test and overall language sets.

        The languages found in each frame are optionally intersected with the
        requested ones.

        Returns:
            Tuple ``(tr_langs, te_langs, all_langs)`` of sets.
        """
        self.tr_langs = set(train.lang.unique().tolist())
        self.te_langs = set(test.lang.unique().tolist())
        if set_tr_langs is not None:
            print(f"-- [SETTING TRAINING LANGS TO: {list(set_tr_langs)}]")
            self.tr_langs = self.tr_langs.intersection(set(set_tr_langs))
        if set_te_langs is not None:
            # Report the test languages (the training ones were printed here
            # before, which also crashed when set_tr_langs was None).
            print(f"-- [SETTING TESTING LANGS TO: {list(set_te_langs)}]")
            self.te_langs = self.te_langs.intersection(set(set_te_langs))
        self.all_langs = self.tr_langs.union(self.te_langs)

        return self.tr_langs, self.te_langs, self.all_langs

    def _set_datalang(self, data: pd.DataFrame):
        """Split ``data`` into one sub-frame per language in ``self.all_langs``."""
        return {lang: data[data.lang == lang] for lang in self.all_langs}

    @staticmethod
    def _texts(frame, mask_number):
        """Return the ``text`` column as a list, number-masked when requested."""
        texts = frame.text.tolist()
        return _mask_numbers(texts) if mask_number else texts

    def training(self, merge_validation=False, mask_number=False, target_as_csr=False):
        """Return the training documents and one-hot targets per training language.

        Args:
            merge_validation: if True, validation documents and labels are
                appended to the training ones.
            mask_number: if True, numbers in the texts are replaced by mask
                tokens via ``_mask_numbers``.
            target_as_csr: currently unused.

        Returns:
            ``(lXtr, lYtr)``: ``lXtr[lang]`` is ``{"text": [...]}`` and
            ``lYtr[lang]`` is a ``(n_docs, n_labels)`` one-hot array.
        """
        # TODO some additional pre-processing on the textual data?
        lXtr = {
            lang: {"text": self._texts(self.train_data[lang], mask_number)}  # TODO inserting dict for textual data - we still have to manage visual
            for lang in self.tr_langs
        }
        lYtr = {
            lang: self.train_data[lang].label.tolist() for lang in self.tr_langs
        }
        if merge_validation:
            for lang in self.tr_langs:
                lXtr[lang]["text"] += self._texts(self.val_data[lang], mask_number)
                lYtr[lang] += self.val_data[lang].label.tolist()

        for lang in self.tr_langs:
            lYtr[lang] = self.indices_to_one_hot(
                indices=lYtr[lang],
                n_labels=self.num_labels(),
            )

        return lXtr, lYtr

    def test(self, mask_number=False, target_as_csr=False):
        """Return the test documents and one-hot targets per test language.

        Args:
            mask_number: if True, numbers in the texts are replaced by mask
                tokens via ``_mask_numbers``.
            target_as_csr: currently unused.

        Returns:
            ``(lXte, lYte)`` with the same layout as in ``training``.
        """
        # TODO some additional pre-processing on the textual data?
        lXte = {
            lang: {"text": self._texts(self.test_data[lang], mask_number)}
            for lang in self.te_langs
        }
        lYte = {
            lang: self.indices_to_one_hot(
                indices=self.test_data[lang].label.tolist(),
                n_labels=self.num_labels(),
            )
            for lang in self.te_langs
        }
        return lXte, lYte

    def langs(self):
        """Return all languages (training and test) as a list."""
        return list(self.all_langs)

    def num_labels(self):
        """Return the number of distinct labels found in the training CSV."""
        return len(self.labels)

    def indices_to_one_hot(self, indices, n_labels):
        """Turn a sequence of label indices into a ``(len(indices), n_labels)`` one-hot array.

        The labels are used directly as column indices, so they are assumed to
        lie in ``0..n_labels-1``.
        """
        one_hot_matrix = np.zeros((len(indices), n_labels))
        # Explicit integer dtype keeps an empty index list valid.
        columns = np.asarray(indices, dtype=int)
        one_hot_matrix[np.arange(len(indices)), columns] = 1
        return one_hot_matrix
|
||||
|
||||
|
||||
class gFunDataset:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -27,8 +151,8 @@ class gFunDataset:
|
|||
self._load_dataset()
|
||||
|
||||
def get_label_binarizer(self, labels):
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
||||
mlb = "Labels are already binarized for rcv1-2 dataset"
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
|
||||
mlb = f"Labels are already binarized for {self.dataset_name} dataset"
|
||||
elif self.is_multilabel:
|
||||
mlb = MultiLabelBinarizer()
|
||||
mlb.fit([labels])
|
||||
|
|
@ -62,16 +186,39 @@ class gFunDataset:
|
|||
)
|
||||
self.mlb = self.get_label_binarizer(self.labels)
|
||||
|
||||
elif "cls" in self.dataset_dir.lower():
|
||||
print(f"- Loading CLS dataset from {self.dataset_dir}")
|
||||
# WEBIS-CLS (processed)
|
||||
elif (
|
||||
"cls" in self.dataset_dir.lower()
|
||||
and "unprocessed" not in self.dataset_dir.lower()
|
||||
):
|
||||
print(f"- Loading WEBIS-CLS (processed) dataset from {self.dataset_dir}")
|
||||
self.dataset_name = "cls"
|
||||
self.dataset, self.labels, self.data_langs = self._load_multilingual(
|
||||
self.dataset_name, self.dataset_dir, self.nrows
|
||||
)
|
||||
self.mlb = self.get_label_binarizer(self.labels)
|
||||
|
||||
self.show_dimension()
|
||||
# WEBIS-CLS (unprocessed)
|
||||
elif (
|
||||
"cls" in self.dataset_dir.lower()
|
||||
and "unprocessed" in self.dataset_dir.lower()
|
||||
):
|
||||
print(f"- Loading WEBIS-CLS (unprocessed) dataset from {self.dataset_dir}")
|
||||
self.dataset_name = "cls"
|
||||
self.dataset, self.labels, self.data_langs = self._load_multilingual(
|
||||
self.dataset_name, self.dataset_dir, self.nrows
|
||||
)
|
||||
self.mlb = self.get_label_binarizer(self.labels)
|
||||
|
||||
elif "rai" in self.dataset_dir.lower():
|
||||
print(f"- Loading RAI-CORPUS dataset from {self.dataset_dir}")
|
||||
self.dataset_name = "rai"
|
||||
self.dataset, self.labels, self.data_langs = self._load_multilingual(dataset_name=self.dataset_name,
|
||||
dataset_dir="~/datasets/rai/csv/train-split-rai.csv",
|
||||
nrows=self.nrows)
|
||||
self.mlb = self.get_label_binarizer(self.labels)
|
||||
|
||||
self.show_dimension()
|
||||
return
|
||||
|
||||
def show_dimension(self):
|
||||
|
|
@ -80,13 +227,20 @@ class gFunDataset:
|
|||
print(
|
||||
f"-- Lang: {lang} - train docs: {len(data['train']['text'])} - test docs: {len(data['test']['text'])}"
|
||||
)
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
|
||||
print(f"-- Labels: {self.labels}")
|
||||
else:
|
||||
print(f"-- Labels: {len(self.labels)}")
|
||||
|
||||
def _load_multilingual(self, dataset_name, dataset_dir, nrows):
|
||||
old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
|
||||
if "csv" in dataset_dir:
|
||||
old_dataset = MultilingualDataset(dataset_name=dataset_name).from_csv(
|
||||
# path_tr="~/datasets/rai/csv/train-rai-multilingual-2000.csv",
|
||||
#path_te="~/datasets/rai/csv/test-rai-multilingual-2000.csv")
|
||||
path_tr="~/datasets/rai/csv/train-split-rai.csv",
|
||||
path_te="~/datasets/rai/csv/test-split-rai.csv")
|
||||
else:
|
||||
old_dataset = MultilingualDataset(dataset_name=dataset_name).load(dataset_dir)
|
||||
if nrows is not None:
|
||||
if dataset_name == "cls":
|
||||
old_dataset.reduce_data(langs=["de", "en", "fr"], maxn=nrows)
|
||||
|
|
@ -139,7 +293,7 @@ class gFunDataset:
|
|||
return dataset, labels, data_langs
|
||||
|
||||
def binarize_labels(self, labels):
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls", "rai"]:
|
||||
# labels are already binarized for rcv1-2 dataset
|
||||
return labels
|
||||
if hasattr(self, "mlb"):
|
||||
|
|
@ -177,7 +331,7 @@ class gFunDataset:
|
|||
return self.data_langs
|
||||
|
||||
def num_labels(self):
|
||||
if self.dataset_name not in ["rcv1-2", "jrc", "cls"]:
|
||||
if self.dataset_name not in ["rcv1-2", "jrc", "cls", "rai"]:
|
||||
return len(self.labels)
|
||||
else:
|
||||
return self.labels
|
||||
|
|
@ -190,30 +344,48 @@ class gFunDataset:
|
|||
print(f"- saving dataset in {filepath}")
|
||||
pickle.dump(self, f)
|
||||
|
||||
def _mask_numbers(data):
|
||||
mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
|
||||
mask_4digit = re.compile(r"\s[\+-]?\d{4}([\.,]\d*)*\b")
|
||||
mask_3digit = re.compile(r"\s[\+-]?\d{3}([\.,]\d*)*\b")
|
||||
mask_2digit = re.compile(r"\s[\+-]?\d{2}([\.,]\d*)*\b")
|
||||
mask_1digit = re.compile(r"\s[\+-]?\d{1}([\.,]\d*)*\b")
|
||||
masked = []
|
||||
for text in tqdm(data, desc="masking numbers", disable=True):
|
||||
text = " " + text
|
||||
text = mask_moredigit.sub(" MoreDigitMask", text)
|
||||
text = mask_4digit.sub(" FourDigitMask", text)
|
||||
text = mask_3digit.sub(" ThreeDigitMask", text)
|
||||
text = mask_2digit.sub(" TwoDigitMask", text)
|
||||
text = mask_1digit.sub(" OneDigitMask", text)
|
||||
masked.append(text.replace(".", "").replace(",", "").strip())
|
||||
return masked
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
data_rai = SimpleGfunDataset()
|
||||
lXtr, lYtr = data_rai.training(mask_number=False)
|
||||
lXte, lYte = data_rai.test(mask_number=False)
|
||||
exit()
|
||||
# import os
|
||||
|
||||
GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
|
||||
RCV_DATAPATH = os.path.expanduser(
|
||||
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
||||
)
|
||||
JRC_DATAPATH = os.path.expanduser(
|
||||
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
|
||||
)
|
||||
# GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")
|
||||
# RCV_DATAPATH = os.path.expanduser(
|
||||
# "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
||||
# )
|
||||
# JRC_DATAPATH = os.path.expanduser(
|
||||
# "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
|
||||
# )
|
||||
|
||||
print("Hello gFunDataset")
|
||||
dataset = gFunDataset(
|
||||
# dataset_dir=GLAMI_DATAPATH,
|
||||
# dataset_dir=RCV_DATAPATH,
|
||||
dataset_dir=JRC_DATAPATH,
|
||||
data_langs=None,
|
||||
is_textual=True,
|
||||
is_visual=True,
|
||||
is_multilabel=False,
|
||||
labels=None,
|
||||
nrows=13,
|
||||
)
|
||||
lXtr, lYtr = dataset.training()
|
||||
lXte, lYte = dataset.test()
|
||||
exit(0)
|
||||
# print("Hello gFunDataset")
|
||||
# dataset = gFunDataset(
|
||||
# dataset_dir=JRC_DATAPATH,
|
||||
# data_langs=None,
|
||||
# is_textual=True,
|
||||
# is_visual=True,
|
||||
# is_multilabel=False,
|
||||
# labels=None,
|
||||
# nrows=13,
|
||||
# )
|
||||
# lXtr, lYtr = dataset.training()
|
||||
# lXte, lYte = dataset.test()
|
||||
# exit(0)
|
||||
|
|
|
|||
|
|
@ -222,6 +222,37 @@ class MultilingualDataset:
|
|||
new_data.append((docs[:maxn], labels[:maxn], None))
|
||||
return new_data
|
||||
|
||||
def from_csv(self, path_tr, path_te, n_labels=28):
    """Populate the dataset from a training and a test CSV file.

    Both files must provide the columns ``lang``, ``text``, ``label`` and
    ``id``. For every language found in the training file, the documents of
    that language are taken from both files and registered through
    ``self.add``. Languages present only in the test file are not added.

    Args:
        path_tr: path of the training CSV (``~`` is expanded).
        path_te: path of the test CSV (``~`` is expanded).
        n_labels: width of the one-hot label matrices; labels are used as
            column indices and must be smaller than this value. Defaults to
            28, the value that was previously hard-coded.

    Returns:
        ``self``, so the call can be chained.
    """
    import pandas as pd
    from os.path import expanduser

    def one_hot(labels):
        # (n_docs, n_labels) integer matrix with a single 1 per row.
        matrix = np.zeros((len(labels), n_labels), dtype=int)
        matrix[np.arange(len(labels)), np.asarray(labels, dtype=int)] = 1
        return matrix

    train = pd.read_csv(expanduser(path_tr))
    test = pd.read_csv(expanduser(path_te))
    for lang in train.lang.unique():
        tr_datalang = train.loc[train["lang"] == lang]
        te_datalang = test.loc[test["lang"] == lang]
        self.add(
            lang=lang,
            Xtr=tr_datalang.text.to_list(),
            Ytr=one_hot(tr_datalang.label.to_list()),
            Xte=te_datalang.text.to_list(),
            Yte=one_hot(te_datalang.label.to_list()),
            tr_ids=tr_datalang.id.to_list(),
            te_ids=te_datalang.id.to_list(),
        )
    return self
|
||||
|
||||
|
||||
def _mask_numbers(data):
|
||||
mask_moredigit = re.compile(r"\s[\+-]?\d{5,}([\.,]\d*)*\b")
|
||||
|
|
@ -240,7 +271,6 @@ def _mask_numbers(data):
|
|||
masked.append(text.replace(".", "").replace(",", "").strip())
|
||||
return masked
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
DATAPATH = expanduser(
|
||||
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from os.path import expanduser, join
|
||||
from dataManager.gFunDataset import gFunDataset
|
||||
from dataManager.gFunDataset import gFunDataset, SimpleGfunDataset
|
||||
from dataManager.multiNewsDataset import MultiNewsDataset
|
||||
from dataManager.amazonDataset import AmazonDataset
|
||||
|
||||
|
|
@ -23,6 +23,8 @@ def get_dataset(dataset_name, args):
|
|||
"rcv1-2",
|
||||
"glami",
|
||||
"cls",
|
||||
"webis",
|
||||
"rai",
|
||||
], "dataset not supported"
|
||||
|
||||
RCV_DATAPATH = expanduser(
|
||||
|
|
@ -37,6 +39,13 @@ def get_dataset(dataset_name, args):
|
|||
|
||||
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
|
||||
|
||||
WEBIS_CLS = expanduser(
|
||||
# "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
|
||||
"~/datasets/cls-acl10-unprocessed/csv"
|
||||
)
|
||||
|
||||
RAI_DATAPATH = expanduser("~/datasets/rai/rai_corpus.pkl")
|
||||
|
||||
if dataset_name == "multinews":
|
||||
# TODO: convert to gFunDataset
|
||||
raise NotImplementedError
|
||||
|
|
@ -82,7 +91,6 @@ def get_dataset(dataset_name, args):
|
|||
)
|
||||
|
||||
dataset.save_as_pickle(GLAMI_DATAPATH)
|
||||
|
||||
elif dataset_name == "cls":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=CLS_DATAPATH,
|
||||
|
|
@ -91,6 +99,36 @@ def get_dataset(dataset_name, args):
|
|||
is_multilabel=False,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
elif dataset_name == "webis":
|
||||
dataset = SimpleGfunDataset(
|
||||
datadir=WEBIS_CLS,
|
||||
textual=True,
|
||||
visual=False,
|
||||
multilabel=False,
|
||||
set_tr_langs=args.tr_langs,
|
||||
set_te_langs=args.te_langs
|
||||
)
|
||||
# dataset = gFunDataset(
|
||||
# dataset_dir=WEBIS_CLS,
|
||||
# is_textual=True,
|
||||
# is_visual=False,
|
||||
# is_multilabel=False,
|
||||
# nrows=args.nrows,
|
||||
# )
|
||||
elif dataset_name == "rai":
|
||||
dataset = SimpleGfunDataset(
|
||||
datadir="~/datasets/rai/csv",
|
||||
textual=True,
|
||||
visual=False,
|
||||
multilabel=False
|
||||
)
|
||||
# dataset = gFunDataset(
|
||||
# dataset_dir=RAI_DATAPATH,
|
||||
# is_textual=True,
|
||||
# is_visual=False,
|
||||
# is_multilabel=False,
|
||||
# nrows=args.nrows
|
||||
# )
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return dataset
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from joblib import Parallel, delayed
|
||||
from collections import defaultdict
|
||||
|
||||
from evaluation.metrics import *
|
||||
from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score
|
||||
# from evaluation.metrics import *
|
||||
import numpy as np
|
||||
from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score, precision_score, recall_score
|
||||
|
||||
|
||||
def evaluation_metrics(y, y_, clf_type):
|
||||
|
|
@ -13,13 +14,17 @@ def evaluation_metrics(y, y_, clf_type):
|
|||
# TODO: we need logits top_k_accuracy_score(y, y_, k=10),
|
||||
f1_score(y, y_, average="macro", zero_division=1),
|
||||
f1_score(y, y_, average="micro"),
|
||||
precision_score(y, y_, zero_division=1, average="macro"),
|
||||
recall_score(y, y_, zero_division=1, average="macro"),
|
||||
)
|
||||
elif clf_type == "multilabel":
|
||||
return (
|
||||
macroF1(y, y_),
|
||||
microF1(y, y_),
|
||||
macroK(y, y_),
|
||||
microK(y, y_),
|
||||
f1_score(y, y_, average="macro", zero_division=1),
|
||||
f1_score(y, y_, average="micro"),
|
||||
0,
|
||||
0,
|
||||
# macroK(y, y_),
|
||||
# microK(y, y_),
|
||||
)
|
||||
else:
|
||||
raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")
|
||||
|
|
@ -47,9 +52,11 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
|
|||
metrics = []
|
||||
|
||||
if clf_type == "multilabel":
|
||||
for lang in l_eval.keys():
|
||||
macrof1, microf1, macrok, microk = l_eval[lang]
|
||||
metrics.append([macrof1, microf1, macrok, microk])
|
||||
for lang in sorted(l_eval.keys()):
|
||||
# macrof1, microf1, macrok, microk = l_eval[lang]
|
||||
# metrics.append([macrof1, microf1, macrok, microk])
|
||||
macrof1, microf1, precision, recall = l_eval[lang]
|
||||
metrics.append([macrof1, microf1, precision, recall])
|
||||
if phase != "validation":
|
||||
print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
|
||||
averages = np.mean(np.array(metrics), axis=0)
|
||||
|
|
@ -69,12 +76,15 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
|
|||
# "acc10", # "accuracy-at-10",
|
||||
"MF1", # "macro-F1",
|
||||
"mF1", # "micro-F1",
|
||||
"precision",
|
||||
"recall"
|
||||
]
|
||||
for lang in l_eval.keys():
|
||||
for lang in sorted(l_eval.keys()):
|
||||
# acc, top5, top10, macrof1, microf1 = l_eval[lang]
|
||||
acc, macrof1, microf1 = l_eval[lang]
|
||||
acc, macrof1, microf1, precision, recall= l_eval[lang]
|
||||
# metrics.append([acc, top5, top10, macrof1, microf1])
|
||||
metrics.append([acc, macrof1, microf1])
|
||||
# metrics.append([acc, macrof1, microf1])
|
||||
metrics.append([acc, macrof1, microf1, precision, recall])
|
||||
|
||||
for m, v in zip(_metrics, l_eval[lang]):
|
||||
lang_metrics[m][lang] = v
|
||||
|
|
@ -82,7 +92,8 @@ def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
|
|||
if phase != "validation":
|
||||
print(
|
||||
# f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
|
||||
f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
|
||||
# f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
|
||||
f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f} pr = {precision:.3f} re = {recall:.3f}"
|
||||
)
|
||||
averages = np.mean(np.array(metrics), axis=0)
|
||||
if verbose:
|
||||
|
|
|
|||
|
|
@ -124,6 +124,16 @@ class GeneralizedFunnelling:
|
|||
epochs=self.epochs,
|
||||
attn_stacking_type=attn_stacking,
|
||||
)
|
||||
|
||||
self._model_id = get_unique_id(
|
||||
self.dataset_name,
|
||||
self.posteriors_vgf,
|
||||
self.multilingual_vgf,
|
||||
self.wce_vgf,
|
||||
self.textual_trf_vgf,
|
||||
self.visual_trf_vgf,
|
||||
self.aggfunc,
|
||||
)
|
||||
return self
|
||||
|
||||
if self.posteriors_vgf:
|
||||
|
|
@ -241,7 +251,7 @@ class GeneralizedFunnelling:
|
|||
self.metaclassifier.fit(agg, lY)
|
||||
return self
|
||||
|
||||
self.vectorizer.fit(lX)
|
||||
self.vectorizer.fit(lX) # TODO this should fit also out-of-voc languages (for muses)
|
||||
self.init_vgfs_vectorizers()
|
||||
|
||||
projections = []
|
||||
|
|
@ -314,16 +324,19 @@ class GeneralizedFunnelling:
|
|||
|
||||
def get_config(self):
|
||||
c = {}
|
||||
simple_config = ""
|
||||
|
||||
for vgf in self.first_tier_learners:
|
||||
vgf_config = vgf.get_config()
|
||||
c.update(vgf_config)
|
||||
c.update({vgf_config["name"]: vgf_config})
|
||||
simple_config += vgf_config["simple_id"]
|
||||
|
||||
gfun_config = {
|
||||
"id": self._model_id,
|
||||
"aggfunc": self.aggfunc,
|
||||
"optimc": self.optimc,
|
||||
"dataset": self.dataset_name,
|
||||
"simple_id": "".join(sorted(simple_config))
|
||||
}
|
||||
|
||||
c["gFun"] = gfun_config
|
||||
|
|
@ -372,6 +385,7 @@ class GeneralizedFunnelling:
|
|||
"rb",
|
||||
) as vgf:
|
||||
first_tier_learners.append(pickle.load(vgf))
|
||||
print(f"- loaded trained VanillaFun VGF")
|
||||
if self.multilingual_vgf:
|
||||
with open(
|
||||
os.path.join(
|
||||
|
|
@ -380,6 +394,7 @@ class GeneralizedFunnelling:
|
|||
"rb",
|
||||
) as vgf:
|
||||
first_tier_learners.append(pickle.load(vgf))
|
||||
print(f"- loaded trained Multilingual VGF")
|
||||
if self.wce_vgf:
|
||||
with open(
|
||||
os.path.join(
|
||||
|
|
@ -388,20 +403,38 @@ class GeneralizedFunnelling:
|
|||
"rb",
|
||||
) as vgf:
|
||||
first_tier_learners.append(pickle.load(vgf))
|
||||
print(f"- loaded trained WCE VGF")
|
||||
if self.textual_trf_vgf:
|
||||
with open(
|
||||
os.path.join(
|
||||
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
|
||||
"models",
|
||||
"vgfs",
|
||||
"textual_transformer",
|
||||
f"textualTransformerGen_{model_id}.pkl",
|
||||
),
|
||||
"rb",
|
||||
) as vgf:
|
||||
first_tier_learners.append(pickle.load(vgf))
|
||||
print(f"- loaded trained Textual Transformer VGF")
|
||||
# Load the pickled visual-transformer VGF when that view is enabled.
# The log line has to come after the load: passing print(...) as the third
# positional argument of open() hands None to `buffering`, which raises
# TypeError before the file is opened.
if self.visual_trf_vgf:
    with open(
        os.path.join(
            "models",
            "vgfs",
            "visual_transformer",
            f"visualTransformerGen_{model_id}.pkl",
        ),
        "rb",
    ) as vgf:
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only load checkpoints produced by this project.
        first_tier_learners.append(pickle.load(vgf))
        print(f"- loaded trained Visual Transformer VGF")
|
||||
|
||||
if load_meta:
|
||||
with open(
|
||||
os.path.join("models", "metaclassifier", f"meta_{model_id}.pkl"), "rb"
|
||||
) as f:
|
||||
metaclassifier = pickle.load(f)
|
||||
print(f"- loaded trained metaclassifier")
|
||||
else:
|
||||
metaclassifier = None
|
||||
return first_tier_learners, metaclassifier, vectorizer
|
||||
|
|
|
|||
|
|
@ -103,6 +103,11 @@ def predict(logits, clf_type="multilabel"):
|
|||
class TfidfVectorizerMultilingual:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
||||
def update(self, X, lang):
    """Fit a TF-IDF vectorizer for one additional language and register it.

    Args:
        X: data of the new language; only ``X["text"]`` is read.
        lang: key under which the fitted vectorizer is stored.

    Returns:
        self, so that calls can be chained.
    """
    # Register the language only once: a duplicated entry would make
    # `self.langs` diverge from the keys of `self.vectorizer`.
    if lang not in self.langs:
        self.langs.append(lang)
    self.vectorizer[lang] = TfidfVectorizer(**self.kwargs).fit(X["text"])
    return self
|
||||
|
||||
def fit(self, lX, ly=None):
|
||||
self.langs = sorted(lX.keys())
|
||||
|
|
@ -112,7 +117,12 @@ class TfidfVectorizerMultilingual:
|
|||
return self
|
||||
|
||||
def transform(self, lX):
|
||||
return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs}
|
||||
in_langs = lX.keys()
|
||||
for in_l in in_langs:
|
||||
if in_l not in self.langs:
|
||||
print(f"[NB: found unvectorized language! Updatding vectorizer for {in_l=}]")
|
||||
self.update(X=lX[in_l], lang=in_l)
|
||||
return {l: self.vectorizer[l].transform(lX[l]["text"]) for l in self.langs} # TODO we can update the vectorizer with new languages here!
|
||||
|
||||
def fit_transform(self, lX, ly=None):
    """Fit on `lX` (with optional `ly`) and return the transformation of `lX`."""
    fitted = self.fit(lX, ly)
    return fitted.transform(lX)
|
||||
|
|
|
|||
|
|
@ -56,6 +56,13 @@ class MultilingualGen(ViewGen):
|
|||
|
||||
def transform(self, lX):
|
||||
lX = self.vectorizer.transform(lX)
|
||||
if self.langs != sorted(self.vectorizer.vectorizer.keys()):
|
||||
# new_langs = set(self.vectorizer.vectorizer.keys()) - set(self.langs)
|
||||
old_langs = self.langs
|
||||
self.langs = sorted(self.vectorizer.vectorizer.keys())
|
||||
new_load, _ = self._load_embeddings(embed_dir=self.embed_dir, cached=self.cached, exclude=old_langs)
|
||||
for k, v in new_load.items():
|
||||
self.multi_embeddings[k] = v
|
||||
|
||||
XdotMulti = Parallel(n_jobs=self.n_jobs)(
|
||||
delayed(XdotM)(lX[lang], self.multi_embeddings[lang], sif=self.sif)
|
||||
|
|
@ -70,10 +77,12 @@ class MultilingualGen(ViewGen):
|
|||
def fit_transform(self, lX, lY):
    """Fit the view generator on ``(lX, lY)`` and return its projection of `lX`."""
    fitted = self.fit(lX, lY)
    return fitted.transform(lX)
|
||||
|
||||
def _load_embeddings(self, embed_dir, cached):
|
||||
def _load_embeddings(self, embed_dir, cached, exclude=None):
|
||||
if "muse" in self.embed_dir.lower():
|
||||
if exclude is not None:
|
||||
langs = set(self.langs) - set(exclude)
|
||||
multi_embeddings = load_MUSEs(
|
||||
langs=self.langs,
|
||||
langs=self.langs if exclude is None else langs,
|
||||
l_vocab=self.vectorizer.vocabulary(),
|
||||
dir_path=embed_dir,
|
||||
cached=cached,
|
||||
|
|
@ -89,6 +98,7 @@ class MultilingualGen(ViewGen):
|
|||
"cached": self.cached,
|
||||
"sif": self.sif,
|
||||
"probabilistic": self.probabilistic,
|
||||
"simple_id": "m"
|
||||
}
|
||||
|
||||
def save_vgf(self, model_id):
|
||||
|
|
@ -164,6 +174,8 @@ def extract(l_voc, l_embeddings):
|
|||
"""
|
||||
l_extracted = {}
|
||||
for lang, words in l_voc.items():
|
||||
if lang not in l_embeddings:
|
||||
continue
|
||||
source_id, target_id = reindex(words, l_embeddings[lang].stoi)
|
||||
extraction = torch.zeros((len(words), l_embeddings[lang].vectors.shape[-1]))
|
||||
extraction[source_id] = l_embeddings[lang].vectors[target_id]
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from dataManager.torchDataset import MultilingualDatasetTorch
|
|||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
# TODO should pass also attention_mask to transformer model!
|
||||
|
||||
class MT5ForSequenceClassification(nn.Module):
|
||||
def __init__(self, model_name, num_labels, output_hidden_states):
|
||||
|
|
@ -45,11 +46,12 @@ class MT5ForSequenceClassification(nn.Module):
|
|||
|
||||
def save_pretrained(self, checkpoint_dir):
    """Serialize the model's ``state_dict`` to ``checkpoint_dir + ".pt"``.

    Args:
        checkpoint_dir: path prefix; the ``.pt`` suffix is appended here.

    Returns:
        self, mirroring ``from_pretrained`` so calls can be chained.
    """
    torch.save(self.state_dict(), checkpoint_dir + ".pt")
    # A leftover bare `return` used to precede this line and made the
    # method return None.
    return self
|
||||
|
||||
def from_pretrained(self, checkpoint_dir):
    """Load weights saved by ``save_pretrained`` into this model.

    Args:
        checkpoint_dir: path prefix used when saving; ``.pt`` is appended here.

    Returns:
        self with the loaded parameters. The key-matching report returned by
        ``load_state_dict`` is no longer what callers receive.
    """
    checkpoint_dir += ".pt"
    self.load_state_dict(torch.load(checkpoint_dir))
    return self
|
||||
|
||||
|
||||
class TextualTransformerGen(ViewGen, TransformerGen):
|
||||
|
|
@ -99,7 +101,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
if "bert" == model_name:
|
||||
return "bert-base-uncased"
|
||||
elif "mbert" == model_name:
|
||||
return "bert-base-multilingual-uncased"
|
||||
return "bert-base-multilingual-cased"
|
||||
elif "xlm-roberta" == model_name:
|
||||
return "xlm-roberta-base"
|
||||
elif "mt5" == model_name:
|
||||
|
|
@ -113,11 +115,16 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
model_name, num_labels=num_labels, output_hidden_states=True
|
||||
)
|
||||
else:
|
||||
# model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500" # TODO hardcoded to pre-traiend mbert
|
||||
# model_name = "hf_models/mbert-rai-fewshot-second/checkpoint-19000" # TODO hardcoded to pre-traiend mbert
|
||||
# model_name = "hf_models/mbert-sentiment/checkpoint-1150" # TODO hardcoded to pre-traiend mbert
|
||||
model_name = "hf_models/mbert-sentiment-balanced/checkpoint-1600"
|
||||
return AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, num_labels=num_labels, output_hidden_states=True
|
||||
)
|
||||
|
||||
def load_tokenizer(self, model_name):
    """Return the tokenizer that ``AutoTokenizer`` resolves for `model_name`."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer
|
||||
|
||||
def init_model(self, model_name, num_labels):
|
||||
|
|
@ -144,61 +151,64 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
self.model_name, num_labels=self.num_labels
|
||||
)
|
||||
|
||||
tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
|
||||
lX, lY, split=0.2, seed=42, modality="text"
|
||||
)
|
||||
self.model.to("cuda")
|
||||
|
||||
tra_dataloader = self.build_dataloader(
|
||||
tr_lX,
|
||||
tr_lY,
|
||||
processor_fn=self._tokenize,
|
||||
torchDataset=MultilingualDatasetTorch,
|
||||
batch_size=self.batch_size,
|
||||
split="train",
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
val_dataloader = self.build_dataloader(
|
||||
val_lX,
|
||||
val_lY,
|
||||
processor_fn=self._tokenize,
|
||||
torchDataset=MultilingualDatasetTorch,
|
||||
batch_size=self.batch_size_eval,
|
||||
split="val",
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
|
||||
|
||||
trainer = Trainer(
|
||||
model=self.model,
|
||||
optimizer_name="adamW",
|
||||
lr=self.lr,
|
||||
device=self.device,
|
||||
loss_fn=torch.nn.CrossEntropyLoss(),
|
||||
print_steps=self.print_steps,
|
||||
evaluate_step=self.evaluate_step,
|
||||
patience=self.patience,
|
||||
experiment_name=experiment_name,
|
||||
checkpoint_path=os.path.join(
|
||||
"models",
|
||||
"vgfs",
|
||||
"transformer",
|
||||
self._format_model_name(self.model_name),
|
||||
),
|
||||
vgf_name="textual_trf",
|
||||
classification_type=self.clf_type,
|
||||
n_jobs=self.n_jobs,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
trainer.train(
|
||||
train_dataloader=tra_dataloader,
|
||||
eval_dataloader=val_dataloader,
|
||||
epochs=self.epochs,
|
||||
)
|
||||
# tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
|
||||
# lX, lY, split=0.2, seed=42, modality="text"
|
||||
# )
|
||||
#
|
||||
# tra_dataloader = self.build_dataloader(
|
||||
# tr_lX,
|
||||
# tr_lY,
|
||||
# processor_fn=self._tokenize,
|
||||
# torchDataset=MultilingualDatasetTorch,
|
||||
# batch_size=self.batch_size,
|
||||
# split="train",
|
||||
# shuffle=True,
|
||||
# )
|
||||
#
|
||||
# val_dataloader = self.build_dataloader(
|
||||
# val_lX,
|
||||
# val_lY,
|
||||
# processor_fn=self._tokenize,
|
||||
# torchDataset=MultilingualDatasetTorch,
|
||||
# batch_size=self.batch_size_eval,
|
||||
# split="val",
|
||||
# shuffle=False,
|
||||
# )
|
||||
#
|
||||
# experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
|
||||
#
|
||||
# trainer = Trainer(
|
||||
# model=self.model,
|
||||
# optimizer_name="adamW",
|
||||
# lr=self.lr,
|
||||
# device=self.device,
|
||||
# loss_fn=torch.nn.CrossEntropyLoss(),
|
||||
# print_steps=self.print_steps,
|
||||
# evaluate_step=self.evaluate_step,
|
||||
# patience=self.patience,
|
||||
# experiment_name=experiment_name,
|
||||
# checkpoint_path=os.path.join(
|
||||
# "models",
|
||||
# "vgfs",
|
||||
# "trained_transformer",
|
||||
# self._format_model_name(self.model_name),
|
||||
# ),
|
||||
# vgf_name="textual_trf",
|
||||
# classification_type=self.clf_type,
|
||||
# n_jobs=self.n_jobs,
|
||||
# scheduler_name=self.scheduler,
|
||||
# )
|
||||
# trainer.train(
|
||||
# train_dataloader=tra_dataloader,
|
||||
# eval_dataloader=val_dataloader,
|
||||
# epochs=self.epochs,
|
||||
# )
|
||||
|
||||
if self.probabilistic:
|
||||
self.feature2posterior_projector.fit(self.transform(lX), lY)
|
||||
transformed = self.transform(lX)
|
||||
self.feature2posterior_projector.fit(transformed, lY)
|
||||
|
||||
self.fitted = True
|
||||
|
||||
|
|
@ -222,9 +232,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
# TODO should pass also attention_mask !
|
||||
for input_ids, lang in dataloader:
|
||||
input_ids = input_ids.to(self.device)
|
||||
# TODO: check this
|
||||
if isinstance(self.model, MT5ForSequenceClassification):
|
||||
batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
|
||||
else:
|
||||
|
|
@ -277,4 +287,4 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
|
||||
def get_config(self):
    """Return this VGF's configuration in the shape the aggregate config expects.

    Returns:
        dict with ``"name"`` (the key the gFun config files this entry
        under), ``"textual_trf"`` (the parent class configuration) and
        ``"simple_id"`` (the one-letter tag joined into the gFun id).
    """
    c = super().get_config()
    # The old single-key return that preceded this one lacked "name" and
    # "simple_id" and made this statement unreachable.
    return {"name": "textual-transformer VGF", "textual_trf": c, "simple_id": "t"}
|
||||
|
|
|
|||
|
|
@ -65,3 +65,6 @@ class VanillaFunGen(ViewGen):
|
|||
with open(_path, "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
return self
|
||||
|
||||
def get_config(self):
    """Return the display name and one-letter id of the Vanilla Funnelling VGF."""
    config = {"name": "Vanilla Funnelling VGF"}
    config["simple_id"] = "p"
    return config
|
||||
|
|
|
|||
|
|
@ -186,4 +186,4 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
return self
|
||||
|
||||
def get_config(self):
    """Return this VGF's configuration in the shape the aggregate config expects.

    Returns:
        dict with ``"name"``, ``"visual_trf"`` (the parent class
        configuration) and ``"simple_id"``.
    """
    return {
        "name": "visual-transformer VGF",
        "visual_trf": super().get_config(),
        # The aggregate config reads vgf_config["simple_id"] for every VGF;
        # without this key it fails with KeyError for the visual view.
        # NOTE(review): "v" is assumed by analogy with p/m/w/t; confirm.
        "simple_id": "v",
    }
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ class WceGen(ViewGen):
|
|||
"name": "Word-Class Embeddings VGF",
|
||||
"n_jobs": self.n_jobs,
|
||||
"sif": self.sif,
|
||||
"simple_id": "w"
|
||||
}
|
||||
|
||||
def save_vgf(self, model_id):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,225 @@
|
|||
from os.path import expanduser
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
TrainingArguments,
|
||||
)
|
||||
from gfun.vgfs.commons import Trainer
|
||||
from datasets import load_dataset, DatasetDict
|
||||
|
||||
from transformers import Trainer
|
||||
from pprint import pprint
|
||||
|
||||
import transformers
|
||||
import evaluate
|
||||
import pandas as pd
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
IWSLT_D_COLUMNS = ["text", "category", "rating", "summary", "title"]
|
||||
RAI_D_COLUMNS = ["id", "provider", "date", "title", "text", "label"] # "lang"
|
||||
WEBIS_D_COLUMNS = ['Unnamed: 0', 'asin', 'category', 'original_rating', 'label', 'title', 'text', 'summary'] # "lang"
|
||||
MAX_LEN = 128
|
||||
# DATASET_NAME = "rai"
|
||||
# DATASET_NAME = "rai-multilingual-2000"
|
||||
# DATASET_NAME = "webis-cls"
|
||||
|
||||
|
||||
def init_callbacks(patience=-1, nosave=False):
    """Build the list of trainer callbacks.

    Early stopping is added only when `patience` is not -1 and checkpoints
    are being saved (`nosave` is false); otherwise the list is empty.
    """
    early_stopping_disabled = patience == -1 or nosave
    if early_stopping_disabled:
        return []
    return [transformers.EarlyStoppingCallback(early_stopping_patience=patience)]
|
||||
|
||||
|
||||
def init_model(model_name, nlabels):
    """Load the tokenizer and sequence-classification model for `model_name`.

    Args:
        model_name: ``"mbert"`` or ``"xlm-roberta"``.
        nlabels: number of output labels of the classification head.

    Returns:
        ``(tokenizer, model)``.

    Raises:
        NotImplementedError: for any other `model_name`.
    """
    # "mbert" currently resolves to a local fine-tuned checkpoint rather than
    # to the hub model "bert-base-multilingual-cased".
    checkpoints = {
        "mbert": "hf_models/mbert-sentiment-balanced/checkpoint-1600",
        "xlm-roberta": "xlm-roberta-base",
    }
    if model_name not in checkpoints:
        raise NotImplementedError
    hf_name = checkpoints[model_name]
    tokenizer = AutoTokenizer.from_pretrained(hf_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        hf_name, num_labels=nlabels
    )
    return tokenizer, model
|
||||
|
||||
|
||||
def main(args):
    """Fine-tune and/or test a sequence classifier on the balanced CLS CSV data.

    Reads the train/test CSV files under
    ``~/datasets/cls-acl10-unprocessed/csv``, tokenizes the ``text`` column,
    holds out 20% of the training split for validation, trains unless
    ``args.onlytest`` is set, then predicts on the test split, prints the
    test metrics and stores the predictions through ``save_preds``.

    Args:
        args: parsed command-line namespace. The attributes read here are
            model, nlabels, patience, nosave, batch, gradacc, lr, epochs,
            scheduler, steplog, fp16, stepeval, wandb and onlytest.
    """
    tokenizer, model = init_model(args.model, args.nlabels)

    data = load_dataset(
        "csv",
        data_files={
            "train": expanduser("~/datasets/cls-acl10-unprocessed/csv/train.balanced.csv"),
            "test": expanduser("~/datasets/cls-acl10-unprocessed/csv/test.balanced.csv"),
        },
    )

    def process_sample_rai(sample):
        # Alternative pre-processing for the RAI data (title + text). It is
        # not used in the current setup; swap it in together with
        # RAI_D_COLUMNS below.
        inputs = [f"{title}. {text}" for title, text in zip(sample["title"], sample["text"])]
        labels = sample["label"]
        # TODO pre-process text cause there's a lot of noise in there...
        model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True)
        model_inputs["labels"] = labels
        return model_inputs

    def process_sample_webis(sample):
        inputs = sample["text"]
        labels = sample["label"]
        # TODO pre-process text cause there's a lot of noise in there...
        model_inputs = tokenizer(inputs, max_length=MAX_LEN, truncation=True)
        model_inputs["labels"] = labels
        return model_inputs

    data = data.map(
        process_sample_webis,
        batched=True,
        num_proc=4,
        load_from_cache_file=True,
        remove_columns=WEBIS_D_COLUMNS,
    )
    # NOTE(review): the split is taken before set_format, so only
    # data["test"] receives the "torch" format; confirm this is intended.
    train_val_splits = data["train"].train_test_split(test_size=0.2, seed=42)
    data.set_format("torch")
    data = DatasetDict(
        {
            "train": train_val_splits["train"],
            "validation": train_val_splits["test"],
            "test": data["test"],
        }
    )

    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    callbacks = init_callbacks(args.patience, args.nosave)

    f1_metric = evaluate.load("f1")
    accuracy_metric = evaluate.load("accuracy")
    precision_metric = evaluate.load("precision")
    recall_metric = evaluate.load("recall")

    training_args = TrainingArguments(
        output_dir=f"hf_models/{args.model}-sentiment-balanced",
        do_train=True,
        evaluation_strategy="steps",
        per_device_train_batch_size=args.batch,
        per_device_eval_batch_size=args.batch,
        gradient_accumulation_steps=args.gradacc,
        eval_accumulation_steps=10,
        learning_rate=args.lr,
        weight_decay=0.1,
        max_grad_norm=5.0,
        num_train_epochs=args.epochs,
        lr_scheduler_type=args.scheduler,
        warmup_ratio=0.01,
        logging_strategy="steps",
        logging_first_step=True,
        logging_steps=args.steplog,
        seed=42,
        fp16=args.fp16,
        load_best_model_at_end=not args.nosave,
        save_strategy="no" if args.nosave else "steps",
        save_total_limit=2,
        eval_steps=args.stepeval,
        run_name=f"{args.model}-sentiment",
        disable_tqdm=False,
        log_level="warning",
        report_to=["wandb"] if args.wandb else "none",
        optim="adamw_torch",
        # Checkpoints are written on the same cadence as evaluation.
        save_steps=args.stepeval,
    )

    def compute_metrics(eval_preds):
        # Labels are class indices, so only the predictions need argmax.
        preds = eval_preds.predictions.argmax(-1)
        targets = eval_preds.label_ids
        setting = "macro"
        f1_score_macro = f1_metric.compute(
            predictions=preds, references=targets, average="macro"
        )
        f1_score_micro = f1_metric.compute(
            predictions=preds, references=targets, average="micro"
        )
        accuracy_score = accuracy_metric.compute(predictions=preds, references=targets)
        precision_score = precision_metric.compute(
            predictions=preds, references=targets, average=setting, zero_division=1
        )
        recall_score = recall_metric.compute(
            predictions=preds, references=targets, average=setting, zero_division=1
        )
        results = {
            "macro_f1score": f1_score_macro["f1"],
            "micro_f1score": f1_score_micro["f1"],
            "accuracy": accuracy_score["accuracy"],
            "precision": precision_score["precision"],
            "recall": recall_score["recall"],
        }
        return {k: round(v, 4) for k, v in results.items()}

    if args.wandb:
        import wandb

        wandb.init(entity="andreapdr", project="gfun-rai-hf", name="mbert-rai", config=vars(args))

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=data["train"],
        eval_dataset=data["validation"],
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
    )

    if not args.onlytest:
        print("- Training:")
        trainer.train()

    print("- Testing:")
    test_results = trainer.predict(test_dataset=data["test"], metric_key_prefix="test")
    pprint(test_results.metrics)
    save_preds(data["test"], test_results.predictions)
    # Return normally instead of calling the interactive `exit()` helper, so
    # the function does not kill the interpreter when called from other code.
|
||||
|
||||
def save_preds(dataset, predictions):
    """Write per-sample language, gold label and predicted label to a CSV file.

    Args:
        dataset: mapping-like object with ``"lang"`` and ``"labels"`` columns.
        predictions: 2-D array of class scores; the row-wise argmax is stored.

    The output goes to ``results/lang-specific.mbert.webis.csv``.
    """
    import os

    df = pd.DataFrame()
    df["langs"] = dataset["lang"]
    df["labels"] = dataset["labels"]
    df["preds"] = predictions.argmax(axis=1)
    # Create the output directory first so that a missing "results" folder
    # does not throw away the predictions of a finished run.
    os.makedirs("results", exist_ok=True)
    df.to_csv("results/lang-specific.mbert.webis.csv", index=False)
    return
|
||||
|
||||
|
||||
|
||||
# Command-line entry point: parse the fine-tuning options and run `main`.
if __name__ == "__main__":
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, metavar="", default="mbert")
    parser.add_argument("--nlabels", type=int, metavar="", default=28)
    parser.add_argument("--lr", type=float, metavar="", default=5e-5, help="Set learning rate")
    # The value is handed to TrainingArguments(lr_scheduler_type=...), so the
    # help text lists scheduler names that transformers accepts.
    parser.add_argument(
        "--scheduler",
        type=str,
        metavar="",
        default="cosine",
        help="transformers scheduler name, e.g. [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"]",
    )
    parser.add_argument("--batch", type=int, metavar="", default=8, help="Set batch size")
    parser.add_argument("--gradacc", type=int, metavar="", default=1, help="Gradient accumulation steps")
    parser.add_argument("--epochs", type=int, metavar="", default=100, help="Set epochs")
    parser.add_argument("--stepeval", type=int, metavar="", default=50, help="Run evaluation every n steps")
    parser.add_argument("--steplog", type=int, metavar="", default=50, help="Log training every n steps")
    parser.add_argument("--patience", type=int, metavar="", default=10, help="EarlyStopper patience")
    parser.add_argument("--fp16", action="store_true", help="Use fp16 precision")
    parser.add_argument("--wandb", action="store_true", help="Log to wandb")
    parser.add_argument("--nosave", action="store_true", help="Avoid saving model")
    parser.add_argument("--onlytest", action="store_true", help="Simply test model on test set")
    args = parser.parse_args()
    main(args)
|
||||
52
main.py
52
main.py
|
|
@ -1,27 +1,17 @@
|
|||
import os
|
||||
import wandb
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
from argparse import ArgumentParser
|
||||
from time import time
|
||||
|
||||
from csvlogger import CsvLogger
|
||||
from dataManager.utils import get_dataset
|
||||
from evaluation.evaluate import evaluate, log_eval
|
||||
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||
|
||||
import pandas as pd
|
||||
|
||||
"""
|
||||
TODO:
|
||||
- Transformers VGFs:
|
||||
- scheduler with warmup and cosine
|
||||
- freeze params method
|
||||
- General:
|
||||
[!] zero-shot setup
|
||||
- CLS dataset is loading only "books" domain data
|
||||
- documents should be trimmed to the same length (for SVMs we are using way too long tokens)
|
||||
- Attention Aggregator:
|
||||
- experiment with weight init of Attention-aggregator
|
||||
- FFNN posterior-probabilities' dependent
|
||||
- Docs:
|
||||
- add documentations sphinx
|
||||
"""
|
||||
|
|
@ -44,7 +34,7 @@ def get_config_name(args):
|
|||
|
||||
def main(args):
|
||||
dataset = get_dataset(args.dataset, args)
|
||||
lX, lY = dataset.training()
|
||||
lX, lY = dataset.training(merge_validation=True)
|
||||
lX_te, lY_te = dataset.test()
|
||||
|
||||
tinit = time()
|
||||
|
|
@ -106,7 +96,9 @@ def main(args):
|
|||
|
||||
config = gfun.get_config()
|
||||
|
||||
wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)
|
||||
if args.wandb:
|
||||
import wandb
|
||||
wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)
|
||||
|
||||
gfun.fit(lX, lY)
|
||||
|
||||
|
|
@ -149,8 +141,28 @@ def main(args):
|
|||
)
|
||||
wandb.log(gfun_res)
|
||||
|
||||
log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
|
||||
log_barplot_wandb(avg_metrics_gfun, title_affix="averages")
|
||||
if args.wandb:
|
||||
log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
|
||||
|
||||
config["gFun"]["timing"] = f"{timeval - tinit:.2f}"
|
||||
csvlogger = CsvLogger(outfile="results/random.log.csv").log_lang_results(lang_metrics_gfun, config)
|
||||
save_preds(gfun_preds, lY_te, config=config["gFun"]["simple_id"], dataset=config["gFun"]["dataset"])
|
||||
|
||||
|
||||
def save_preds(preds, targets, config="unk", dataset="unk"):
    """Dump per-sample language, gold label and predicted label to a CSV file.

    Args:
        preds: mapping ``lang -> 2-D score array``; the row-wise argmax is stored.
        targets: mapping ``lang -> 2-D array``; the row-wise argmax is stored
            as the gold label.
        config: tag of the gFun configuration, used in the file name.
        dataset: dataset name, used in the file name.

    Rows are grouped by language in sorted order and written to
    ``results/lang-specific.gfun.{config}.{dataset}.csv``.
    """
    df = pd.DataFrame()
    _preds = []
    _targets = []
    _langs = []
    for lang in sorted(preds.keys()):
        _preds.extend(preds[lang].argmax(axis=1).tolist())
        _targets.extend(targets[lang].argmax(axis=1).tolist())
        _langs.extend([lang] * len(preds[lang]))
    df["langs"] = _langs
    df["labels"] = _targets
    df["preds"] = _preds
    # Create the output directory first so that a missing "results" folder
    # does not throw away the predictions of a finished run.
    os.makedirs("results", exist_ok=True)
    df.to_csv(f"results/lang-specific.gfun.{config}.{dataset}.csv", index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -159,6 +171,8 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--meta", action="store_true")
|
||||
parser.add_argument("--nosave", action="store_true")
|
||||
parser.add_argument("--device", type=str, default="cuda")
|
||||
parser.add_argument("--tr_langs", nargs="+", default=None)
|
||||
parser.add_argument("--te_langs", nargs="+", default=None)
|
||||
# Dataset parameters -------------------
|
||||
parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
|
||||
parser.add_argument("--domains", type=str, default="all")
|
||||
|
|
@ -178,7 +192,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--features", action="store_false")
|
||||
parser.add_argument("--aggfunc", type=str, default="mean")
|
||||
# transformer parameters ---------------
|
||||
parser.add_argument("--epochs", type=int, default=100)
|
||||
parser.add_argument("--epochs", type=int, default=5)
|
||||
parser.add_argument("--textual_trf_name", type=str, default="mbert")
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--eval_batch_size", type=int, default=128)
|
||||
|
|
@ -189,6 +203,8 @@ if __name__ == "__main__":
|
|||
# Visual Transformer parameters --------------
|
||||
parser.add_argument("--visual_trf_name", type=str, default="vit")
|
||||
parser.add_argument("--visual_lr", type=float, default=1e-4)
|
||||
# logging
|
||||
parser.add_argument("--wandb", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,137 @@
|
|||
#!/bin/bash
# Run main.py on the "rai" dataset once for each gFun configuration flag.

njobs=-1
clf=singlelabel
patience=5
eval_every=5
text_len=512
text_lr=1e-4
bsize=2
txt_model=mbert
dataset=rai

# The "-p" run was commented out in the original script and stays disabled;
# add it to the list to re-enable it.
configs=(-m -w -t -pm -pmw -pt -pmt -pmwt)

for config in "${configs[@]}"; do
    echo "[Running gFun config: $config]"
    # Every continuation line has whitespace before the backslash so that
    # adjacent arguments can never be glued together.
    python main.py "$config" \
        -d "$dataset" \
        --nosave \
        --n_jobs "$njobs" \
        --clf_type "$clf" \
        --patience "$patience" \
        --evaluate_step "$eval_every" \
        --batch_size "$bsize" \
        --max_length "$text_len" \
        --textual_lr "$text_lr" \
        --textual_trf_name "$txt_model"
done
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
#!/bin/bash
#
# Runs `python main.py` once per entry in `configs`, in order, on the dataset
# named by `dataset`. Every run shares the settings defined below; only the
# first argument (the config flag group) changes between runs.
#
# A failing run does not stop the script: the remaining configs still run.

# Shared settings, passed unchanged to every run.
njobs=-1
clf=singlelabel
patience=5
eval_every=5
text_len=512
text_lr=1e-4
bsize=2
txt_model=mbert
dataset=webis

# Flag groups handed to main.py as its first argument, one run per entry.
configs=(-p -m -w -t -pm -pmw -pt -pmt -pmwt)

for config in "${configs[@]}"; do
    echo "[Running gFun config: $config]"
    # Each continuation backslash is preceded by a space so that arguments can
    # never fuse with the next line, and the final argument carries no
    # trailing backslash.
    python main.py "$config" \
        -d "$dataset" \
        --nosave \
        --n_jobs "$njobs" \
        --clf_type "$clf" \
        --patience "$patience" \
        --evaluate_step "$eval_every" \
        --batch_size "$bsize" \
        --max_length "$text_len" \
        --textual_lr "$text_lr" \
        --textual_trf_name "$txt_model"
done
|
||||
Loading…
Reference in New Issue