QuaPy/Ordinal/partition_dataset_by_shift.py

import numpy as np
import quapy as qp
from evaluation import nmd
from Ordinal.utils import load_samples_folder, load_single_sample_pkl
from quapy.data import LabelledCollection
import pickle
import os
from os.path import join
from tqdm import tqdm
"""
This scripts generates a partition of a dataset in terms of "shift".
The partition is only carried out by generating index vectors.
"""
def partition_by_drift(split, training_prevalence):
    assert split in ['dev', 'test'], 'invalid split name'
    total = 1000 if split == 'dev' else 5000
    drifts = []
    folderpath = join(datapath, domain, 'app', f'{split}_samples')
    # measure the shift of each sample as the NMD between its prevalence and the training prevalence
    for sample in tqdm(load_samples_folder(folderpath, load_fn=load_single_sample_pkl), total=total):
        drifts.append(nmd(training_prevalence, sample.prevalence()))
    drifts = np.asarray(drifts)
    # sort the samples by increasing shift and split them into three equally-sized bins
    order = np.argsort(drifts)
    nD = len(order)
    low_drift, mid_drift, high_drift = order[:nD // 3], order[nD // 3:2 * nD // 3], order[2 * nD // 3:]
    all_drift = np.arange(nD)
    # the partition is materialized as index vectors only
    np.save(join(datapath, domain, 'app', f'lowdrift.{split}.id.npy'), low_drift)
    np.save(join(datapath, domain, 'app', f'middrift.{split}.id.npy'), mid_drift)
    np.save(join(datapath, domain, 'app', f'highdrift.{split}.id.npy'), high_drift)
    np.save(join(datapath, domain, 'app', f'alldrift.{split}.id.npy'), all_drift)
    # report the shift interval and mean shift of each bin
    lows = drifts[low_drift]
    mids = drifts[mid_drift]
    highs = drifts[high_drift]
    alls = drifts[all_drift]
    print(f'low drift: interval [{lows.min():.4f}, {lows.max():.4f}] mean: {lows.mean():.4f}')
    print(f'mid drift: interval [{mids.min():.4f}, {mids.max():.4f}] mean: {mids.mean():.4f}')
    print(f'high drift: interval [{highs.min():.4f}, {highs.max():.4f}] mean: {highs.mean():.4f}')
    print(f'all drift: interval [{alls.min():.4f}, {alls.max():.4f}] mean: {alls.mean():.4f}')
domain = 'Books-roberta-base-finetuned-pkl/checkpoint-1188-posteriors'
datapath = './data'

# partition the dev and test samples w.r.t. the prevalence of the training set
training = pickle.load(open(join(datapath, domain, 'training_data.pkl'), 'rb'))
partition_by_drift('dev', training.prevalence())
partition_by_drift('test', training.prevalence())
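
# Usage sketch (illustrative; not part of the original script): a downstream evaluation
# could restrict a protocol to one drift bin by loading the corresponding index vector,
# whose entries refer to the iteration order of load_samples_folder. The variable names
# below are hypothetical.
#
#   low_ids = np.load(join(datapath, domain, 'app', 'lowdrift.test.id.npy'))
#   test_samples = list(load_samples_folder(join(datapath, domain, 'app', 'test_samples'),
#                                           load_fn=load_single_sample_pkl))
#   low_shift_samples = [test_samples[i] for i in low_ids]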