From a3ffd689b1fb048165ec71d3eea5fc28f82a39b8 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Fri, 5 Apr 2024 15:57:05 +0200 Subject: [PATCH] plotly methods fixed, plot saving implemented --- quacc/plot/plotly.py | 52 +++++++++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/quacc/plot/plotly.py b/quacc/plot/plotly.py index 2fb6978..33d1365 100644 --- a/quacc/plot/plotly.py +++ b/quacc/plot/plotly.py @@ -1,9 +1,11 @@ -import os +from pathlib import Path import numpy as np import plotly import plotly.graph_objects as go +from quacc.utils.commons import get_plots_path + MODE = "lines" L_WIDTH = 5 LEGEND = { @@ -16,12 +18,14 @@ FONT = {"size": 24} TEMPLATE = "ggplot2" -def _save_or_return(fig, basedir, dataset_name, measure_name, plot_type): - if basedir is not None: - plotsubdir = dataset_name - os.path.join(basedir, "plots", measure_name, plotsubdir, plot_type + ".svg") +def _save_or_return( + fig: go.Figure, basedir, cls_name, acc_name, dataset_name, plot_type +) -> go.Figure | None: + if basedir is None: + return fig - return fig + path = get_plots_path(basedir, cls_name, acc_name, dataset_name, plot_type) + fig.write_image(path) def _update_layout(fig, title, x_label, y_label, **kwargs): @@ -72,9 +76,10 @@ def plot_diagonal( method_names, true_accs, estim_accs, + cls_name, + acc_name, + dataset_name, *, - measure_name="vanilla_accuracy", - dataset_name=None, basedir=None, ) -> go.Figure: fig = go.Figure() @@ -111,8 +116,8 @@ def plot_diagonal( _update_layout( fig, - x_label=f"True {measure_name}", - y_label=f"Estimated {measure_name}", + x_label=f"True {acc_name}", + y_label=f"Estimated {acc_name}", autosize=False, width=1300, height=1000, @@ -120,18 +125,19 @@ def plot_diagonal( yaxis_scaleratio=1.0, yaxis_range=[-0.1, 1.1], ) - return _save_or_return(fig, basedir, dataset_name, measure_name, "diagonal") + return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "diagonal") def plot_delta( method_names: list[str], prevs: np.ndarray, acc_errs: np.ndarray, + cls_name, + acc_mame, + dataset_name, + prev_name, *, stdevs: np.ndarray | None = None, - prev_name="Test", - measure_name="Vanilla Accuracy", - dataset_name=None, basedir=None, ) -> go.Figure: fig = go.Figure() @@ -170,10 +176,15 @@ def plot_delta( _update_layout( fig, x_label=f"{prev_name} Prevalence", - y_label=f"Prediction Error for {measure_name}", + y_label=f"Prediction Error for {acc_mame}", ) return _save_or_return( - fig, basedir, dataset_name, measure_name, "delta" if stdevs is None else "stdev" + fig, + basedir, + cls_name, + acc_mame, + dataset_name, + "delta" if stdevs is None else "stdev", ) @@ -181,10 +192,11 @@ def plot_shift( method_names: list[str], prevs: np.ndarray, acc_errs: np.ndarray, + cls_name, + acc_name, + dataset_name, *, counts: np.ndarray | None = None, - measure_name="Vanilla Accuracy", - dataset_name=None, basedir=None, ) -> go.Figure: fig = go.Figure() @@ -212,6 +224,6 @@ def plot_shift( _update_layout( fig, x_label="Amount of Prior Probability Shift", - y_label=f"Prediction Error for {measure_name}", + y_label=f"Prediction Error for {acc_name}", ) - return _save_or_return(fig, basedir, dataset_name, measure_name, "shift") + return _save_or_return(fig, basedir, cls_name, acc_name, dataset_name, "shift")