forked from moreo/QuaPy
26 lines
916 B
Python
26 lines
916 B
Python
import numpy as np
|
|
import pandas as pd
|
|
import seaborn as sns
|
|
import matplotlib.pyplot as plt
|
|
|
|
"""
|
|
This script generates plots of sensibility for the hyperparameter "bandwidth".
|
|
Plots results for MAE, MRAE, and KLD
|
|
The rest of hyperparameters were set to their default values
|
|
"""
|
|
|
|
# Load the tab-separated KDEy-MLE bandwidth-sweep results for the tweet and
# LeQua dataset families (read in that order), then stack them into one frame.
# NOTE(review): the plotting loop pivots on 'Bandwidth', 'Dataset' and the
# 'MAE'/'MRAE'/'KLD' columns, so both files presumably contain those columns.
df_tweet, df_lequa = (
    pd.read_csv(results_path, sep='\t')
    for results_path in (
        '../results_tweet_sensibility/KDEy-MLE.csv',
        '../results_lequa_sensibility/KDEy-MLE.csv',
    )
)
df = pd.concat([df_tweet, df_lequa])

# For each error metric, draw one line per dataset showing how the error varies
# with the KDE bandwidth, and save it as ./sensibility_<metric>.pdf.
for err in ['MAE', 'MRAE', 'KLD']:
    # Wide table: one row per bandwidth value, one column per dataset.
    # pivot_table averages duplicate (Bandwidth, Dataset) entries (aggfunc='mean').
    piv = df.pivot_table(index='Bandwidth', columns='Dataset', values=err)

    # Wide-form data: seaborn draws one line per column (dataset).
    g = sns.lineplot(data=piv, markers=True, dashes=False)
    g.set(xlim=(0.01, 0.2))
    # Anchor the legend's left-centre at the right edge of the axes so it sits
    # outside the plotting area instead of covering the lines.
    g.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    g.set_ylabel(err)
    # 20 evenly spaced ticks over [0.01, 0.2], i.e. one tick every 0.01.
    g.set_xticks(np.linspace(0.01, 0.2, 20))
    plt.xticks(rotation=90)
    # Enable the grid explicitly. A bare grid() call *toggles* visibility, so
    # the result would silently flip under a style whose default has grid on.
    g.grid(True)

    plt.savefig(f'./sensibility_{err}.pdf', bbox_inches='tight')
    # Clear the shared figure so the next metric is drawn on fresh axes.
    plt.clf()