From 1105709a4c551a7b646055d07532887a673457dd Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Thu, 16 Nov 2023 01:36:07 +0100 Subject: [PATCH] plot adapted to qcpanel save feature --- quacc/plot.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/quacc/plot.py b/quacc/plot.py index 18a7d8a..e0bbefa 100644 --- a/quacc/plot.py +++ b/quacc/plot.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import numpy as np from cycler import cycler -from quacc.environment import env +from quacc import utils matplotlib.use("agg") @@ -30,6 +30,7 @@ def plot_delta( legend=True, avg=None, return_fig=False, + base_path=None, ) -> Path: _base_title = "delta_stdev" if stdevs is not None else "delta" if train_prev is not None: @@ -38,6 +39,9 @@ def plot_delta( else: title = f"{_base_title}_{name}_avg_{avg}_{metric}" + if base_path is None: + base_path = utils.get_quacc_home() / "plots" + fig, ax = plt.subplots() ax.set_aspect("auto") ax.grid() @@ -89,7 +93,7 @@ def plot_delta( if return_fig: return fig - output_path = env.PLOT_OUT_DIR / f"{title}.png" + output_path = base_path / f"{title}.png" fig.savefig(output_path, bbox_inches="tight") return output_path @@ -105,6 +109,7 @@ def plot_diagonal( train_prev=None, legend=True, return_fig=False, + base_path=None, ): if train_prev is not None: t_prev_pos = int(round(train_prev[pos_class] * 100)) @@ -112,6 +117,9 @@ def plot_diagonal( else: title = f"diagonal_{name}_{metric}" + if base_path is None: + base_path = utils.get_quacc_home() / "plots" + fig, ax = plt.subplots() ax.set_aspect("auto") ax.grid() @@ -178,7 +186,7 @@ def plot_diagonal( if return_fig: return fig - output_path = env.PLOT_OUT_DIR / f"{title}.png" + output_path = base_path / f"{title}.png" fig.savefig(output_path, bbox_inches="tight") return output_path @@ -195,6 +203,7 @@ def plot_shift( train_prev=None, legend=True, return_fig=False, + base_path=None, ) -> Path: if train_prev is not None: t_prev_pos = int(round(train_prev[pos_class] * 100)) @@ -202,6 +211,9 @@ def plot_shift( else: title = f"shift_{name}_avg_{metric}" + if base_path is None: + base_path = utils.get_quacc_home() / "plots" + fig, ax = plt.subplots() ax.set_aspect("auto") ax.grid() @@ -247,7 +259,7 @@ def plot_shift( if return_fig: return fig - output_path = env.PLOT_OUT_DIR / f"{title}.png" + output_path = base_path / f"{title}.png" fig.savefig(output_path, bbox_inches="tight") return output_path