diff --git a/Ordinal/utils.py b/Ordinal/utils.py index 88a278e..ac22671 100644 --- a/Ordinal/utils.py +++ b/Ordinal/utils.py @@ -12,6 +12,28 @@ def load_samples(path_dir, classes): yield LabelledCollection.load(join(path_dir, f'{id}.txt'), loader_func=qp.data.reader.from_text, classes=classes) +def load_samples_as_csv(path_dir, debug=False): + import pandas as pd + import csv + import datasets + from datasets import Dataset + + nsamples = len(glob(join(path_dir, f'*.txt'))) + for id in range(nsamples): + df = pd.read_csv(join(path_dir, f'{id}.txt'), sep='\t', names=['labels', 'review'], quoting=csv.QUOTE_NONE) + labels = df.pop('labels').to_frame() + X = df + + features = datasets.Features({'review': datasets.Value('string')}) + if debug: + sample = Dataset.from_pandas(df=X, features=features).select(range(50)) + labels = labels[:50] + else: + sample = Dataset.from_pandas(df=X, features=features) + + yield sample, labels + + def load_samples_pkl(path_dir, filter=None): nsamples = len(glob(join(path_dir, f'*.pkl'))) for id in range(nsamples):