plot adapted to qcpanel save feature

This commit is contained in:
Lorenzo Volpi 2023-11-16 01:36:07 +01:00
parent 81b92157a5
commit 1105709a4c
1 changed files with 16 additions and 4 deletions

View File

@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from quacc.environment import env
from quacc import utils
matplotlib.use("agg")
@ -30,6 +30,7 @@ def plot_delta(
legend=True,
avg=None,
return_fig=False,
base_path=None,
) -> Path:
_base_title = "delta_stdev" if stdevs is not None else "delta"
if train_prev is not None:
@ -38,6 +39,9 @@ def plot_delta(
else:
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
@ -89,7 +93,7 @@ def plot_delta(
if return_fig:
return fig
output_path = env.PLOT_OUT_DIR / f"{title}.png"
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path
@ -105,6 +109,7 @@ def plot_diagonal(
train_prev=None,
legend=True,
return_fig=False,
base_path=None,
):
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
@ -112,6 +117,9 @@ def plot_diagonal(
else:
title = f"diagonal_{name}_{metric}"
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
@ -178,7 +186,7 @@ def plot_diagonal(
if return_fig:
return fig
output_path = env.PLOT_OUT_DIR / f"{title}.png"
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path
@ -195,6 +203,7 @@ def plot_shift(
train_prev=None,
legend=True,
return_fig=False,
base_path=None,
) -> Path:
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
@ -202,6 +211,9 @@ def plot_shift(
else:
title = f"shift_{name}_avg_{metric}"
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
@ -247,7 +259,7 @@ def plot_shift(
if return_fig:
return fig
output_path = env.PLOT_OUT_DIR / f"{title}.png"
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path