forked from moreo/QuaPy
regenerating dataset
This commit is contained in:
parent 9ad4503153
commit b756871f21
@@ -9,7 +9,7 @@ from pathlib import Path
 import numpy as np
 
 
-datadir = '/mnt/1T/Datasets/Amazon/reviews'
+datadir = '/media/moreo/Volume/Datasets/Amazon/reviews'
 outdir = './data/'
 real_prev_path = './data/Books-real-prevalence-by-product_votes1_reviews100.csv'
 domain = 'Books'
@@ -22,7 +22,7 @@ nval = 1000
 nte = 5000
 
 
-def from_gz_text(path, encoding='utf-8', class2int=True):
+def from_text(path, encoding='utf-8', class2int=True):
     """
     Reads a labelled collection of documents.
     File format <0-4>\t<document>\n
@@ -32,7 +32,7 @@ def from_gz_text(path, encoding='utf-8', class2int=True):
     :return: a list of sentences, and a list of labels
     """
     all_sentences, all_labels = [], []
-    file = gzip.open(path, 'rt', encoding=encoding).readlines()
+    file = open(path, 'rt', encoding=encoding).readlines()
     for line in file:
         line = line.strip()
         if line:
@@ -40,7 +40,7 @@ def from_gz_text(path, encoding='utf-8', class2int=True):
             label, sentence = line.split('\t')
             sentence = sentence.strip()
             if class2int:
-                label = int(label) - 1
+                label = int(label)
             if label >= 0:
                 if sentence:
                     all_sentences.append(sentence)
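Note on the loader change above: a minimal sketch (not from the repo; the demo file name and contents are invented for illustration) of what the updated from_text consumes and returns. With the "- 1" shift dropped, labels are kept exactly as they appear in the file:

    # demo.txt holds one "<label>\t<document>" pair per line, as per the docstring
    with open('demo.txt', 'wt', encoding='utf-8') as f:
        f.write('3\tGreat book, highly recommended.\n')
        f.write('0\tNot worth the paper it is printed on.\n')

    sentences, labels = from_text('demo.txt')
    print(labels)  # [3, 0] -- raw ratings; before this commit it would be [2, -1], with -1 filtered out by "label >= 0"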
@@ -66,6 +66,7 @@ def gen_samples_APP(pool: LabelledCollection, nsamples, sample_size, outdir, pre
             write_txt_sample(sample, join(outdir, f'{i}.txt'))
             prevfile.write(f'{i},' + ','.join(f'{p:.3f}' for p in sample.prevalence()) + '\n')
 
+
 def gen_samples_NPP(pool: LabelledCollection, nsamples, sample_size, outdir, prevpath):
     os.makedirs(outdir, exist_ok=True)
     with open(prevpath, 'wt') as prevfile:
@@ -85,36 +86,39 @@ def gen_samples_real_prevalences(real_prevalences, pool: LabelledCollection, sam
             prevfile.write(f'{i},' + ','.join(f'{p:.3f}' for p in sample.prevalence()) + '\n')
 
 
-# fullpath = join(datadir,domain)+'.txt.gz'
-#
+# fullpath = join(datadir,domain)+'.txt.gz' <- deprecated; there were duplicates
 # data = LabelledCollection.load(fullpath, from_gz_text)
-# print(len(data))
-# print(data.classes_)
-# print(data.prevalence())
+
+fullpath = './data/Books/Books.txt'
+data = LabelledCollection.load(fullpath, from_text)
+
+print(len(data))
+print(data.classes_)
+print(data.prevalence())
 
 with qp.util.temp_seed(seed):
-    # train, rest = data.split_stratified(train_prop=tr_size)
-    #
-    # devel, test = rest.split_stratified(train_prop=0.5)
-    # print(len(train))
-    # print(len(devel))
-    # print(len(test))
-    #
+    train, rest = data.split_stratified(train_prop=tr_size)
+
+    devel, test = rest.split_stratified(train_prop=0.5)
+    print(len(train))
+    print(len(devel))
+    print(len(test))
+
     domaindir = join(outdir, domain)
 
-    # write_txt_sample(train, join(domaindir, 'training_data.txt'))
-    # write_txt_sample(devel, join(domaindir, 'development_data.txt'))
-    # write_txt_sample(test, join(domaindir, 'test_data.txt'))
+    write_txt_sample(train, join(domaindir, 'training_data.txt'))
+    write_txt_sample(devel, join(domaindir, 'development_data.txt'))
+    write_txt_sample(test, join(domaindir, 'test_data.txt'))
 
     # this part is to be used when the partitions have already been created, in order to avoid re-generating them
-    train = load_simple_sample_raw(domaindir, 'training_data')
-    devel = load_simple_sample_raw(domaindir, 'development_data')
-    test = load_simple_sample_raw(domaindir, 'test_data')
+    #train = load_simple_sample_raw(domaindir, 'training_data')
+    #devel = load_simple_sample_raw(domaindir, 'development_data')
+    #test = load_simple_sample_raw(domaindir, 'test_data')
 
-    # gen_samples_APP(devel, nsamples=nval, sample_size=val_size, outdir=join(domaindir, 'app', 'dev_samples'),
-    #                 prevpath=join(domaindir, 'app', 'dev_prevalences.txt'))
-    # gen_samples_APP(test, nsamples=nte, sample_size=te_size, outdir=join(domaindir, 'app', 'test_samples'),
-    #                 prevpath=join(domaindir, 'app', 'test_prevalences.txt'))
+    gen_samples_APP(devel, nsamples=nval, sample_size=val_size, outdir=join(domaindir, 'app', 'dev_samples'),
+                    prevpath=join(domaindir, 'app', 'dev_prevalences.txt'))
+    gen_samples_APP(test, nsamples=nte, sample_size=te_size, outdir=join(domaindir, 'app', 'test_samples'),
+                    prevpath=join(domaindir, 'app', 'test_prevalences.txt'))
 
     # gen_samples_NPP(devel, nsamples=nval, sample_size=val_size, outdir=join(domaindir, 'npp', 'dev_samples'),
     #                 prevpath=join(domaindir, 'npp', 'dev_prevalences.txt'))
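For context on the re-enabled sample-generation calls above: APP (the artificial prevalence protocol) draws fixed-size samples whose class prevalences are controlled rather than natural. A rough sketch of the idea, assuming QuaPy's LabelledCollection.sampling(size, *prevs), which draws a sample at the requested prevalence; the Dirichlet draw is a stand-in for whatever prevalence strategy gen_samples_APP actually uses in this repo:

    import numpy as np

    def app_samples_sketch(pool, nsamples, sample_size):
        n_classes = len(pool.classes_)
        for _ in range(nsamples):
            # a uniformly random point on the probability simplex
            prev = np.random.dirichlet(np.ones(n_classes))
            # draw sample_size documents at (approximately) that prevalence
            yield pool.sampling(sample_size, *prev)

gen_samples_NPP, by contrast, would sample uniformly at random, so each sample's prevalence stays close to that of the pool.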
@@ -49,7 +49,7 @@ if __name__ == '__main__':
     datapath = sys.argv[1] # './data/Books/training_data.txt'
     checkpoint = sys.argv[2] #e.g., 'bert-base-uncased' or 'distilbert-base-uncased' or 'roberta-base'
 
-    modelout = checkpoint+'-val-finetuned'
+    modelout = checkpoint+'-finetuned-new'
 
     # load the training set, and extract a held-out validation split of 1000 documents (stratified)
     df = pd.read_csv(datapath, sep='\t', names=['labels', 'review'], quoting=csv.QUOTE_NONE)
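The "held-out validation split of 1000 documents (stratified)" mentioned in the comment can be carved out of df along these lines (a sketch; the script's actual splitting code falls outside this hunk, and random_state is an assumption):

    from sklearn.model_selection import train_test_split

    # hold out exactly 1000 documents while preserving the label distribution
    train_df, val_df = train_test_split(df, test_size=1000, stratify=df['labels'], random_state=42)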
@@ -98,14 +98,14 @@ if __name__ == '__main__':
     assert torch.cuda.is_available(), 'cuda is not available'
 
     #checkpoint='roberta-base-val-finetuned'
-    #generation_mode = 'ave'
+    #generation_mode = 'average' #ave seemed to work slightly better
 
     n_args = len(sys.argv)
     assert n_args==3, 'wrong arguments, expected: <checkpoint> <generation-mode>\n' \
                       '\tgeneration-mode: last (last layer), ave (average pooling), or posteriors (posterior probabilities)'
 
     checkpoint = sys.argv[1] #e.g., 'bert-base-uncased'
-    generation_mode = sys.argv[2] # e.g., 'last'
+    generation_mode = sys.argv[2] # e.g., 'average' # ave seemed to work slightly better
 
     assert 'finetuned' in checkpoint, 'looks like this model is not finetuned'
 
@@ -115,7 +115,7 @@ if __name__ == '__main__':
 
     datapath = './data'
     domain = 'Books'
-    protocols = ['real'] # ['app', 'npp']
+    protocols = ['real', 'app'] # ['app', 'npp']
 
     assert generation_mode in ['last', 'average', 'posteriors'], 'unknown generation_mode'
     outname = domain + f'-{checkpoint}-{generation_mode}'
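Taken together, the two entry points would be invoked roughly as follows; the script names are hypothetical (the file names are not recoverable from this page), but the arguments follow the sys.argv parsing shown above:

    python finetune.py ./data/Books/training_data.txt roberta-base
    python generate_features.py roberta-base-finetuned-new average

Note that the second script requires the checkpoint name to contain 'finetuned', and its assertion accepts only 'last', 'average', or 'posteriors' as the generation mode, even though the usage string still advertises 'ave', which the assertion would reject.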