1
0
Fork 0
QuaPy/distribution_matching/figures/sensibility_plot.py

26 lines
916 B
Python

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
"""
This script plots the sensitivity of KDEy-MLE to the hyperparameter "bandwidth".
One plot is generated for each error measure: MAE, MRAE, and KLD.
The remaining hyperparameters were set to their default values.
"""
# Read the tab-separated KDEy-MLE results of both experiments and stack them
# into a single table.
df_tweet = pd.read_csv('../results_tweet_sensibility/KDEy-MLE.csv', sep='\t')
df_lequa = pd.read_csv('../results_lequa_sensibility/KDEy-MLE.csv', sep='\t')
df = pd.concat([df_tweet, df_lequa])

# The x-axis window and its 20 evenly spaced ticks are the same for every
# figure, so they are computed once up front.
x_range = (0.01, 0.2)
x_ticks = np.linspace(*x_range, 20)

# One figure per error measure, saved as ./sensibility_<measure>.pdf
for err in ('MAE', 'MRAE', 'KLD'):
    # Rows are bandwidth values, columns are datasets; cells hold the error
    # aggregated by pivot_table's default aggregation.
    per_dataset = df.pivot_table(values=err, index='Bandwidth', columns='Dataset')

    # One marked, solid line per dataset column, drawn on the current axes.
    ax = sns.lineplot(data=per_dataset, dashes=False, markers=True)
    ax.set_xlim(*x_range)
    ax.set_ylabel(err)
    # Anchor the legend's center-left corner at the right edge of the axes,
    # i.e. place it outside the plotting area.
    ax.legend(bbox_to_anchor=(1, 0.5), loc="center left")
    ax.set_xticks(x_ticks)
    plt.xticks(rotation=90)
    plt.grid()

    # 'tight' keeps the out-of-axes legend inside the saved page.
    plt.savefig(f'./sensibility_{err}.pdf', bbox_inches='tight')
    # Reset the figure so the next measure starts from a blank canvas.
    plt.clf()