from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

from gfun.vgfs.learners.svms import FeatureSet2Posteriors


class TransformerGen:
    """Base class for all transformers.

    Implements the basic methods for the creation of the datasets, the
    dataloaders, and the train-validation split. Designed to be used with
    MultilingualDataset data in the form of dictionaries {lang: data}.
    """

    def __init__(
        self,
        model_name,
        dataset_name,
        epochs=10,
        lr=1e-5,
        batch_size=4,
        batch_size_eval=32,
        max_length=512,
        print_steps=50,
        device="cpu",
        probabilistic=False,
        n_jobs=-1,
        evaluate_step=10,
        verbose=False,
        patience=5,
        scheduler=None,
    ):
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.device = device
        # Model and tokenizer are expected to be set by subclasses before use.
        self.model = None
        self.lr = lr
        self.epochs = epochs
        self.tokenizer = None
        self.max_length = max_length
        self.batch_size = batch_size
        self.batch_size_eval = batch_size_eval
        self.print_steps = print_steps
        self.probabilistic = probabilistic
        self.n_jobs = n_jobs
        self.fitted = False
        # Per-split torch datasets, populated by build_dataloader().
        # (The original assigned this twice; the duplicate was removed.)
        self.datasets = {}
        self.evaluate_step = evaluate_step
        self.verbose = verbose
        self.patience = patience
        self.scheduler = scheduler
        # Optional projector turning raw feature sets into posteriors.
        self.feature2posterior_projector = (
            self.make_probabilistic() if probabilistic else None
        )

    def make_probabilistic(self):
        """Return a FeatureSet2Posteriors projector when probabilistic mode is on.

        Returns None implicitly when self.probabilistic is False (the guard is
        redundant at the current call site, which only invokes this method
        when probabilistic is True, but it keeps the method safe to call
        directly).
        """
        if self.probabilistic:
            feature2posterior_projector = FeatureSet2Posteriors(
                n_jobs=self.n_jobs, verbose=False
            )
            return feature2posterior_projector

    def build_dataloader(
        self,
        lX,
        lY,
        torchDataset,
        processor_fn,
        batch_size,
        split="train",
        shuffle=False,
    ):
        """Build and return a DataLoader over {lang: data} dictionaries.

        Args:
            lX: dict {lang: raw inputs} to be processed by processor_fn.
            lY: dict {lang: labels}.
            torchDataset: Dataset subclass accepting (processed_data, lY, split=...).
            processor_fn: callable applied per-language to lX[lang]
                (e.g. a tokenizer).
            batch_size: batch size for the DataLoader.
            split: name under which the dataset is cached in self.datasets.
            shuffle: whether the DataLoader shuffles between epochs.
        """
        l_processed = {lang: processor_fn(lX[lang]) for lang in lX.keys()}
        # Cache the dataset per split so it can be reused/inspected later.
        self.datasets[split] = torchDataset(l_processed, lY, split=split)
        return DataLoader(
            self.datasets[split],
            batch_size=batch_size,
            shuffle=shuffle,
            # collate_fn=processor_fn,
        )

    def get_train_val_data(self, lX, lY, split=0.2, seed=42, modality="text"):
        """Split each language's data into train and validation portions.

        Args:
            lX: dict {lang: {modality: data}}.
            lY: dict {lang: labels}.
            split: fraction of each language's data used for validation.
            seed: random seed forwarded to train_test_split.
                NOTE(review): with shuffle=False below, random_state has no
                effect — confirm whether shuffling was intended.
            modality: which modality of lX to split ("text" or "image").

        Returns:
            Tuple (tr_lX, tr_lY, val_lX, val_lY) of per-language dicts.
        """
        assert modality in ["text", "image"], "modality must be either text or image"
        tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}

        for lang in lX.keys():
            tr_X, val_X, tr_Y, val_Y = train_test_split(
                lX[lang][modality],
                lY[lang],
                test_size=split,
                random_state=seed,
                shuffle=False,
            )
            tr_lX[lang] = tr_X
            tr_lY[lang] = tr_Y
            val_lX[lang] = val_X
            val_lY[lang] = val_Y

        return tr_lX, tr_lY, val_lX, val_lY

    def get_config(self):
        """Return the generator's hyperparameter configuration as a dict."""
        return {
            "model_name": self.model_name,
            "dataset_name": self.dataset_name,
            "epochs": self.epochs,
            "lr": self.lr,
            "scheduler": self.scheduler,
            "batch_size": self.batch_size,
            "batch_size_eval": self.batch_size_eval,
            "max_length": self.max_length,
            "print_steps": self.print_steps,
            "device": self.device,
            "probabilistic": self.probabilistic,
            "n_jobs": self.n_jobs,
            "evaluate_step": self.evaluate_step,
            "verbose": self.verbose,
            "patience": self.patience,
        }