plot update
This commit is contained in:
parent
3129187df8
commit
f64654d6f0
1
TODO.txt
1
TODO.txt
|
|
@ -6,3 +6,4 @@
|
||||||
- [TODO] add Friedman's method and DeBias
|
- [TODO] add Friedman's method and DeBias
|
||||||
- [TODO] check ignore warning stuff
|
- [TODO] check ignore warning stuff
|
||||||
check https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings
|
check https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings
|
||||||
|
- [TODO] nmd and md are not selectable from qp.evaluation.evaluate as a string
|
||||||
|
|
@ -567,4 +567,40 @@ def _join_data_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, x_error
|
||||||
if method not in method_order:
|
if method not in method_order:
|
||||||
method_order.append(method)
|
method_order.append(method)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def calibration_plot(prob_classifier, X, y, nbins=10, savepath=None):
|
||||||
|
posteriors = prob_classifier.predict_proba(X)
|
||||||
|
assert posteriors.ndim==2, 'calibration plot only works for binary problems'
|
||||||
|
posteriors = posteriors[:,1]
|
||||||
|
pred_y = posteriors>=0.5
|
||||||
|
bins = np.linspace(0, 1, nbins + 1)
|
||||||
|
binned_values = np.digitize(posteriors, bins, right=False)
|
||||||
|
print(np.unique(binned_values))
|
||||||
|
correct = pred_y == y
|
||||||
|
bin_centers = (bins[:-1] + bins[1:]) / 2
|
||||||
|
bins_names = np.arange(nbins)
|
||||||
|
y_axis = [correct[binned_values==bin].mean() for bin in bins_names]
|
||||||
|
y_axis = [v if not np.isnan(v) else 0 for v in y_axis]
|
||||||
|
# Crear el gráfico de barras
|
||||||
|
plt.bar(bin_centers, y_axis, width=bins[1]-bins[0], edgecolor='black', alpha=0.7)
|
||||||
|
|
||||||
|
# Etiquetas y título
|
||||||
|
plt.xlabel("Bin")
|
||||||
|
plt.ylabel("Value")
|
||||||
|
plt.title("Bar plot of calculated values per bin")
|
||||||
|
plt.xticks(bin_centers, [f"{b:.2f}" for b in bin_centers], rotation=45)
|
||||||
|
|
||||||
|
# Mostrar el gráfico
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import quapy as qp
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
data = qp.datasets.fetch_UCIBinaryDataset(qp.datasets.UCI_BINARY_DATASETS[6])
|
||||||
|
train, test = data.train_test
|
||||||
|
classifier = LogisticRegression()
|
||||||
|
classifier.fit(*train.Xy)
|
||||||
|
calibration_plot(classifier, *test.Xy)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue