Merge branch 'dash'

Lorenzo Volpi, 2023-11-30 18:30:44 +01:00, commit 44d820d4ab
17 changed files with 1592 additions and 587 deletions

.gitignore (vendored, 2 lines changed)

@ -6,11 +6,13 @@ quavenv/*
__pycache__/*
baselines/__pycache__/*
baselines/densratio/__pycache__/*
qcdash/__pycache__/*
qcpanel/__pycache__/*
quacc/__pycache__/*
quacc/evaluation/__pycache__/*
quacc/method/__pycache__/*
quacc/quantification/__pycache__/*
quacc/plot/__pycache__/*
tests/__pycache__/*
tests/*/__pycache__/*
tests/*/*/__pycache__/*

poetry.lock (generated, 417 lines changed)
File diff suppressed because it is too large.


@ -18,6 +18,7 @@ abstention = "^0.1.3.1"
main = "quacc.main:main"
run = "run:run"
panel = "qcpanel.run:run"
dash = "qcdash.app:run"
sync_up = "remote:sync_code"
sync_down = "remote:sync_output"
merge_data = "merge_data:run"
@ -27,6 +28,7 @@ poetry_command = ""
[tool.poe.tasks]
ilona = "ssh volpi@ilona.isti.cnr.it"
dash = "gunicorn qcdash.app:server -b ilona.isti.cnr.it:33421"
[tool.poe.tasks.logr]
shell = """
@ -48,6 +50,9 @@ ipympl = "^0.9.3"
ipykernel = "^6.26.0"
ipywidgets-bokeh = "^1.5.0"
pandas-stubs = "^2.1.1.230928"
dash = "^2.14.1"
dash-bootstrap-components = "^1.5.0"
gunicorn = "^21.2.0"
[tool.pytest.ini_options]
addopts = "--cov=quacc --capture=tee-sys"
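
The project-configuration hunks above register two ways to launch the new dashboard: a "dash" console script pointing at qcdash.app:run (the built-in Dash development server) and a "dash" poe task that serves the same application through gunicorn on a fixed host and port. A minimal sketch, not part of the diff, of what the two entry points resolve to, assuming the package is installed and qcdash/app.py (added below) is importable:

from qcdash.app import app, server   # "server" is app.server, the WSGI object gunicorn binds to

# console script "dash" -> qcdash.app:run -> the Dash development server
app.run(debug=True)

# the "dash" poe task is equivalent to running:
#   gunicorn qcdash.app:server -b <host>:<port>
# i.e. the same application behind a production WSGI server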

qcdash/__init__.py (new empty file)

qcdash/app.py (new file, 413 lines)

@ -0,0 +1,413 @@
import json
import os
from collections import defaultdict
from json import JSONDecodeError
from pathlib import Path
from typing import List
from urllib.parse import parse_qsl, quote, urlencode, urlparse
import dash_bootstrap_components as dbc
import numpy as np
from dash import Dash, Input, Output, State, callback, ctx, dash_table, dcc, html
from dash.dash_table.Format import Format, Scheme
from quacc import plot
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import CompReport, DatasetReport
from quacc.evaluation.stats import ttest_rel
backend = plot.get_backend("plotly")
valid_plot_modes = defaultdict(lambda: CompReport._default_modes)
valid_plot_modes["avg"] = DatasetReport._default_dr_modes
def get_datasets(root: str | Path) -> List[DatasetReport]:
def load_dataset(dataset):
dataset = Path(dataset)
return DatasetReport.unpickle(dataset)
def explore_datasets(root: str | Path) -> List[Path]:
if isinstance(root, str):
root = Path(root)
if root.name == "plot":
return []
if not root.exists():
return []
dr_paths = []
for f in os.listdir(root):
if (root / f).is_dir():
dr_paths += explore_datasets(root / f)
elif f == f"{root.name}.pickle":
dr_paths.append(root / f)
return dr_paths
dr_paths = sorted(explore_datasets(root), key=lambda t: (-len(t.parts), t))
return {str(drp.parent): load_dataset(drp) for drp in dr_paths}
def get_fig(dr: DatasetReport, metric, estimators, view, mode):
estimators = CE.name[estimators]
match (view, mode):
case ("avg", _):
return dr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="plotly",
save_fig=False,
backend=backend,
)
case (_, _):
cr = dr.crs[[str(round(c.train_prev[1] * 100)) for c in dr.crs].index(view)]
return cr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="plotly",
save_fig=False,
backend=backend,
)
def get_table(dr: DatasetReport, metric, estimators, view, mode):
estimators = CE.name[estimators]
_prevs = [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
match (view, mode):
case ("avg", "train_table"):
return dr.data(metric=metric, estimators=estimators).groupby(level=1).mean()
case ("avg", "test_table"):
return dr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
case ("avg", "shift_table"):
return (
dr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
case ("avg", "stats_table"):
return ttest_rel(dr, metric=metric, estimators=estimators)
case (_, "train_table"):
cr = dr.crs[_prevs.index(view)]
return cr.data(metric=metric, estimators=estimators).groupby(level=0).mean()
case (_, "shift_table"):
cr = dr.crs[_prevs.index(view)]
return (
cr.shift_data(metric=metric, estimators=estimators)
.groupby(level=0)
.mean()
)
def get_DataTable(df):
_primary = "#0d6efd"
if df.empty:
return None
df = df.reset_index()
columns = {
c: dict(
id=c,
name=c,
type="numeric",
format=Format(precision=6, scheme=Scheme.exponent),
)
for c in df.columns
}
columns["index"]["format"] = Format(precision=2, scheme=Scheme.fixed)
columns = list(columns.values())
data = df.to_dict("records")
return html.Div(
[
dash_table.DataTable(
data=data,
columns=columns,
id="table1",
style_cell={
"padding": "0 12px",
"border": "0",
"border-bottom": f"1px solid {_primary}",
},
style_table={
"margin": "6vh 15px",
"padding": "15px",
"maxWidth": "80vw",
"overflowX": "auto",
"border": f"0px solid {_primary}",
"border-radius": "6px",
},
)
],
style={
"display": "flex",
"flex-direction": "column",
# "justify-content": "center",
"align-items": "center",
"height": "100vh",
},
)
def get_Graph(fig):
if fig is None:
return None
return dcc.Graph(
id="graph1",
figure=fig,
style={
"margin": 0,
"height": "100vh",
},
)
datasets = get_datasets("output")
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
# app.config.suppress_callback_exceptions = True
sidebar_style = {
"top": 0,
"left": 0,
"bottom": 0,
"padding": "1vw",
"padding-top": "2vw",
"margin": "0px",
"flex": 1,
"overflow": "scroll",
"height": "100vh",
}
content_style = {
"flex": 5,
"maxWidth": "84vw",
}
def parse_href(href: str):
parse_result = urlparse(href)
params = parse_qsl(parse_result.query)
return dict(params)
def get_sidebar():
return [
html.H4("Parameters:", style={"margin-bottom": "1vw"}),
dbc.Select(
# options=list(datasets.keys()),
# value=list(datasets.keys())[0],
id="dataset",
),
dbc.Select(
id="metric",
style={"margin-top": "1vh"},
),
html.Div(
[
dbc.RadioItems(
id="view",
class_name="btn-group mt-3",
input_class_name="btn-check",
label_class_name="btn btn-outline-primary",
label_checked_class_name="active",
),
dbc.RadioItems(
id="mode",
class_name="btn-group mt-3",
input_class_name="btn-check",
label_class_name="btn btn-outline-primary",
label_checked_class_name="active",
),
],
className="radio-group-v d-flex justify-content-around",
),
html.Div(
[
dbc.Checklist(
id="estimators",
className="btn-group mt-3",
inputClassName="btn-check",
labelClassName="btn btn-outline-primary",
labelCheckedClassName="active",
),
],
className="radio-group-wide",
),
]
app.layout = html.Div(
[
dcc.Interval(id="reload", interval=10 * 60 * 1000),
dcc.Location(id="url", refresh=False),
html.Div(
[
html.Div(get_sidebar(), id="app_sidebar", style=sidebar_style),
html.Div(id="app_content", style=content_style),
],
id="page_layout",
style={"display": "flex", "flexDirection": "row"},
),
]
)
server = app.server
def apply_param(href, triggered_id, id, curr):
match triggered_id:
case "url":
params = parse_href(href)
return params.get(id, None)
case _:
return curr
@callback(
Output("dataset", "value"),
Output("dataset", "options"),
Input("url", "href"),
Input("reload", "n_intervals"),
State("dataset", "value"),
)
def update_dataset(href, n_intervals, dataset):
match ctx.triggered_id:
case "reload":
new_datasets = get_datasets("output")
global datasets
datasets = new_datasets
req_dataset = dataset
case "url":
params = parse_href(href)
req_dataset = params.get("dataset", None)
available_datasets = list(datasets.keys())
new_dataset = (
req_dataset if req_dataset in available_datasets else available_datasets[0]
)
return new_dataset, available_datasets
@callback(
Output("metric", "options"),
Output("metric", "value"),
Input("url", "href"),
Input("dataset", "value"),
State("metric", "value"),
)
def update_metrics(href, dataset, curr_metric):
dr = datasets[dataset]
old_metric = apply_param(href, ctx.triggered_id, "metric", curr_metric)
valid_metrics = [m for m in dr.data().columns.unique(0) if not m.endswith("_score")]
new_metric = old_metric if old_metric in valid_metrics else valid_metrics[0]
return valid_metrics, new_metric
@callback(
Output("estimators", "options"),
Output("estimators", "value"),
Input("url", "href"),
Input("dataset", "value"),
Input("metric", "value"),
State("estimators", "value"),
)
def update_estimators(href, dataset, metric, curr_estimators):
dr = datasets[dataset]
old_estimators = apply_param(href, ctx.triggered_id, "estimators", curr_estimators)
if isinstance(old_estimators, str):
try:
old_estimators = json.loads(old_estimators)
except JSONDecodeError:
old_estimators = []
valid_estimators = dr.data(metric=metric).columns.unique(0).to_numpy()
new_estimators = valid_estimators[
np.isin(valid_estimators, old_estimators)
].tolist()
return valid_estimators, new_estimators
@callback(
Output("view", "options"),
Output("view", "value"),
Input("url", "href"),
Input("dataset", "value"),
State("view", "value"),
)
def update_view(href, dataset, curr_view):
dr = datasets[dataset]
old_view = apply_param(href, ctx.triggered_id, "view", curr_view)
valid_views = ["avg"] + [str(round(cr.train_prev[1] * 100)) for cr in dr.crs]
new_view = old_view if old_view in valid_views else valid_views[0]
return valid_views, new_view
@callback(
Output("mode", "options"),
Output("mode", "value"),
Input("url", "href"),
Input("view", "value"),
State("mode", "value"),
)
def update_mode(href, view, curr_mode):
old_mode = apply_param(href, ctx.triggered_id, "mode", curr_mode)
valid_modes = valid_plot_modes[view]
new_mode = old_mode if old_mode in valid_modes else valid_modes[0]
return valid_modes, new_mode
@callback(
Output("app_content", "children"),
Output("url", "search"),
Input("dataset", "value"),
Input("metric", "value"),
Input("estimators", "value"),
Input("view", "value"),
Input("mode", "value"),
)
def update_content(dataset, metric, estimators, view, mode):
search = urlencode(
dict(
dataset=dataset,
metric=metric,
estimators=json.dumps(estimators),
view=view,
mode=mode,
),
quote_via=quote,
)
dr = datasets[dataset]
match mode:
case m if m.endswith("table"):
df = get_table(
dr=dr,
metric=metric,
estimators=estimators,
view=view,
mode=mode,
)
dt = get_DataTable(df)
app_content = [] if dt is None else [dt]
case _:
fig = get_fig(
dr=dr,
metric=metric,
estimators=estimators,
view=view,
mode=mode,
)
g = get_Graph(fig)
app_content = [] if g is None else [g]
return app_content, f"?{search}"
def run():
app.run(debug=True)
if __name__ == "__main__":
run()
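
qcdash/app.py persists the current selection in the URL query string: update_content() url-encodes the widget values (JSON-encoding the estimator list), and parse_href() together with the apply_param()/update_* callbacks restores them on page load. A small round-trip sketch with hypothetical values:

import json
from urllib.parse import parse_qsl, quote, urlencode, urlparse

state = dict(
    dataset="output/imdb",                            # hypothetical dataset key
    metric="acc",
    estimators=json.dumps(["method_a", "method_b"]),  # lists travel as JSON strings
    view="avg",
    mode="delta_train",
)
search = "?" + urlencode(state, quote_via=quote)      # what update_content() returns as url.search

params = dict(parse_qsl(urlparse("http://localhost:8050/" + search).query))
assert params["dataset"] == "output/imdb"
assert json.loads(params["estimators"]) == ["method_a", "method_b"]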


@ -0,0 +1,86 @@
/* restyle radio items */
.radio-group .form-check {
padding-left: 0;
}
.radio-group .btn-group > .form-check:not(:last-child) > .btn {
border-top-right-radius: 0;
border-bottom-right-radius: 0;
}
.radio-group .btn-group > .form-check:not(:first-child) > .btn {
border-top-left-radius: 0;
border-bottom-left-radius: 0;
margin-left: -1px;
}
.radio-group-v{
padding: 0 10px;
}
.radio-group-v .form-check {
padding-top: 0;
padding-left: 2px;
}
.radio-group-v .btn-group {
flex-direction: column;
justify-content: center;
align-items: stretch;
flex-grow: 1;
}
.radio-group-v > .btn-group:first-child {
flex-grow: 0;
}
.radio-group-v > .btn-group:last-child {
margin-left: 20px;
}
.radio-group-v .btn-group > .form-check:not(:last-child) > .btn {
border-bottom-right-radius: 0;
border-bottom-left-radius: 0;
}
.radio-group-v .btn-group > .form-check:not(:first-child) > .btn {
border-top-right-radius: 0;
border-top-left-radius: 0;
margin-top: -3px;
}
.radio-group-v .btn-group .btn{
width: 100%;
}
.radio-group-wide .form-check {
padding-left: 0px;
}
.radio-group-wide .btn-group{
flex-wrap: wrap;
}
.radio-group-wide .btn-group .form-check{
flex: 1;
margin-top: -3px;
margin-left: -1px;
}
.radio-group-wide .btn-group .form-check .btn{
width: 100%;
border-radius: 0;
}
.radio-group-wide .btn-group .form-check:first-child > .btn{
border-top-left-radius: 10px;
}
.radio-group-wide .btn-group .form-check:last-child > .btn{
border-bottom-right-radius: 10px;
}
div#app-sidebar{
border-right: 2px solid var(--primary);
}


@ -1,290 +0,0 @@
import argparse
import os
from pathlib import Path
import panel as pn
import param
from quacc.evaluation.estimators import CE
from quacc.evaluation.report import DatasetReport
pn.extension(design="bootstrap")
def create_cr_plots(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
prev=None,
):
idx = [round(cr.train_prev[1] * 100) for cr in dr.crs].index(prev)
cr = dr.crs[idx]
estimators = CE.name[estimators]
_dpi = 112
return pn.pane.Matplotlib(
cr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="panel",
return_fig=True,
),
tight=True,
format="png",
sizing_mode="scale_height",
# sizing_mode="scale_both",
)
def create_avg_plots(
dr: DatasetReport,
mode="delta",
metric="acc",
estimators=None,
prev=None,
):
estimators = CE.name[estimators]
return pn.pane.Matplotlib(
dr.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf="panel",
return_fig=True,
),
tight=True,
format="png",
sizing_mode="scale_height",
# sizing_mode="scale_both",
)
def build_cr_tab(dr: DatasetReport):
_data = dr.data()
_metrics = _data.columns.unique(0)
_estimators = _data.columns.unique(1)
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
metric_widget = pn.widgets.Select(
name="metric",
value="acc",
options=valid_metrics,
align="center",
)
valid_estimators = [e for e in _estimators if e != "ref"]
estimators_widget = pn.widgets.CheckButtonGroup(
name="estimators",
options=valid_estimators,
value=valid_estimators,
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
sizing_mode="scale_width",
)
valid_plot_modes = ["delta", "delta_stdev", "diagonal", "shift"]
plot_mode_widget = pn.widgets.RadioButtonGroup(
name="mode",
value=valid_plot_modes[0],
options=valid_plot_modes,
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
sizing_mode="scale_width",
)
valid_prevs = [round(cr.train_prev[1] * 100) for cr in dr.crs]
prevs_widget = pn.widgets.RadioButtonGroup(
name="train prevalence",
value=valid_prevs[0],
options=valid_prevs,
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
)
plot_pane = pn.bind(
create_cr_plots,
dr=dr,
mode=plot_mode_widget,
metric=metric_widget,
estimators=estimators_widget,
prev=prevs_widget,
)
return pn.Row(
pn.Spacer(width=20),
pn.Column(
metric_widget,
pn.Row(
prevs_widget,
plot_mode_widget,
),
estimators_widget,
align="center",
),
pn.Spacer(sizing_mode="scale_width"),
plot_pane,
)
def build_avg_tab(dr: DatasetReport):
_data = dr.data()
_metrics = _data.columns.unique(0)
_estimators = _data.columns.unique(1)
valid_metrics = [m for m in _metrics if not m.endswith("_score")]
metric_widget = pn.widgets.Select(
name="metric",
value="acc",
options=valid_metrics,
align="center",
)
valid_estimators = [e for e in _estimators if e != "ref"]
estimators_widget = pn.widgets.CheckButtonGroup(
name="estimators",
options=valid_estimators,
value=valid_estimators,
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
sizing_mode="scale_width",
)
valid_plot_modes = [
"delta_train",
"stdev_train",
"delta_test",
"stdev_test",
"shift",
]
plot_mode_widget = pn.widgets.RadioButtonGroup(
name="mode",
value=valid_plot_modes[0],
options=valid_plot_modes,
button_style="outline",
button_type="primary",
align="center",
orientation="vertical",
sizing_mode="scale_width",
)
plot_pane = pn.bind(
create_avg_plots,
dr=dr,
mode=plot_mode_widget,
metric=metric_widget,
estimators=estimators_widget,
)
return pn.Row(
pn.Spacer(width=20),
pn.Column(
metric_widget,
plot_mode_widget,
estimators_widget,
align="center",
),
pn.Spacer(sizing_mode="scale_width"),
plot_pane,
)
def build_dataset(dataset_path: Path):
dr: DatasetReport = DatasetReport.unpickle(dataset_path)
prevs_tab = ("train prevs.", build_cr_tab(dr))
avg_tab = ("avg", build_avg_tab(dr))
app = pn.Tabs(objects=[avg_tab, prevs_tab], dynamic=False)
app.servable()
return app
def explore_datasets(root: Path | str):
if isinstance(root, str):
root = Path(root)
if root.name == "plot":
return []
if not root.exists():
return []
drs = []
for f in os.listdir(root):
if (root / f).is_dir():
drs += explore_datasets(root / f)
elif f == f"{root.name}.pickle":
drs.append((root, build_dataset(root / f)))
# drs.append((str(root),))
return drs
class PlotSelector(param.Parameterized):
metric = param.Selector(objects=["acc", "f1"])
view = param.Selector(objects=["train prevs", "avg"])
def plot_selector_widget():
return pn.Param(
PlotSelector.param,
widgets={
"metric": pn.widgets.Select,
"view": pn.widgets.Select,
},
)
def serve(address="localhost"):
# app = build_dataset(Path("output/rcv1_CCAT_9prevs/rcv1_CCAT_9prevs.pickle"))
__base_path = "output"
__tabs = sorted(
explore_datasets(__base_path), key=lambda t: (len(t[0].parts), t[0])
)
__tabs = [(str(p.relative_to(Path(__base_path))), d) for (p, d) in __tabs]
if len(__tabs) > 0:
app = pn.Tabs(
objects=__tabs,
tabs_location="left",
dynamic=False,
)
else:
app = pn.Column(
pn.pane.Str("No Dataset Found", styles={"font-size": "24pt"}),
align="center",
)
__port = 33420
__allowed = [address]
if address == "localhost":
__allowed.append("127.0.0.1")
pn.serve(
app,
autoreload=True,
port=__port,
show=False,
address=address,
websocket_origin=[f"{_a}:{__port}" for _a in __allowed],
)
def run():
parser = argparse.ArgumentParser()
parser.add_argument(
"--address",
action="store",
dest="address",
default="localhost",
)
args = parser.parse_args()
serve(address=args.address)


@ -1,10 +1,11 @@
import argparse
import panel as pn
from panel.theme.fast import FastDarkTheme, FastDefaultTheme
from qcpanel.viewer import QuaccTestViewer
# pn.config.design = pn.theme.Bootstrap
# pn.config.design = Fast
# pn.config.theme = "dark"
pn.config.notifications = True
@ -59,8 +60,8 @@ def app_instance():
],
main=[pn.Column(qtv.get_plot, sizing_mode="stretch_both")],
modal=[qtv.modal_pane],
theme=pn.theme.DarkTheme,
theme_toggle=False,
# theme=FastDefaultTheme,
theme_toggle=True,
)
app.servable()


@ -52,7 +52,7 @@ def create_plots(
metric=metric,
estimators=estimators,
conf="panel",
return_fig=True,
save_fig=False,
)
return (
pn.pane.Matplotlib(
@ -91,7 +91,7 @@ def create_plots(
metric=metric,
estimators=estimators,
conf="panel",
return_fig=True,
save_fig=False,
)
return (
pn.pane.Matplotlib(


@ -6,7 +6,7 @@ from typing import List, Tuple
import numpy as np
import pandas as pd
from quacc import plot
import quacc.plot as plot
from quacc.utils import fmt_line_md
@ -215,16 +215,17 @@ class CompReport:
def get_plots(
self,
mode="delta",
mode="delta_train",
metric="acc",
estimators=None,
conf="default",
return_fig=False,
save_fig=True,
base_path=None,
backend=None,
) -> List[Tuple[str, Path]]:
if mode == "delta_train":
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
if avg_data.empty is True:
if avg_data.empty:
return None
return plot.plot_delta(
@ -234,8 +235,9 @@ class CompReport:
metric=metric,
name=conf,
train_prev=self.train_prev,
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "stdev_train":
avg_data = self.avg_by_prevs(metric=metric, estimators=estimators)
@ -251,8 +253,9 @@ class CompReport:
name=conf,
train_prev=self.train_prev,
stdevs=st_data.T.to_numpy(),
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "diagonal":
f_data = self.data(metric=metric + "_score", estimators=estimators)
@ -268,8 +271,9 @@ class CompReport:
metric=metric,
name=conf,
train_prev=self.train_prev,
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "shift":
_shift_data = self.shift_data(metric=metric, estimators=estimators)
@ -290,8 +294,9 @@ class CompReport:
name=conf,
train_prev=self.train_prev,
counts=shift_counts.T.to_numpy(),
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
def to_md(
@ -323,11 +328,12 @@ class CompReport:
plot_modes = [m for m in modes if not m.endswith("table")]
for mode in plot_modes:
res += f"### {mode}\n"
op = self.get_plots(
_, op = self.get_plots(
mode=mode,
metric=metric,
estimators=estimators,
conf=conf,
save_fig=True,
base_path=plot_path,
)
res += f"![plot_{mode}]({op.relative_to(op.parents[1]).as_posix()})\n"
@ -424,12 +430,15 @@ class DatasetReport:
metric="acc",
estimators=None,
conf="default",
return_fig=False,
save_fig=True,
base_path=None,
backend=None,
):
if mode == "delta_train":
_data = self.data(metric, estimators) if data is None else data
avg_on_train = _data.groupby(level=1).mean()
if avg_on_train.empty:
return None
prevs_on_train = np.sort(avg_on_train.index.unique(0))
return plot.plot_delta(
base_prevs=np.around(
@ -441,12 +450,15 @@ class DatasetReport:
name=conf,
train_prev=None,
avg="train",
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "stdev_train":
_data = self.data(metric, estimators) if data is None else data
avg_on_train = _data.groupby(level=1).mean()
if avg_on_train.empty:
return None
prevs_on_train = np.sort(avg_on_train.index.unique(0))
stdev_on_train = _data.groupby(level=1).std()
return plot.plot_delta(
@ -460,12 +472,15 @@ class DatasetReport:
train_prev=None,
stdevs=stdev_on_train.T.to_numpy(),
avg="train",
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "delta_test":
_data = self.data(metric, estimators) if data is None else data
avg_on_test = _data.groupby(level=0).mean()
if avg_on_test.empty:
return None
prevs_on_test = np.sort(avg_on_test.index.unique(0))
return plot.plot_delta(
base_prevs=np.around([(1.0 - p, p) for p in prevs_on_test], decimals=2),
@ -475,12 +490,15 @@ class DatasetReport:
name=conf,
train_prev=None,
avg="test",
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "stdev_test":
_data = self.data(metric, estimators) if data is None else data
avg_on_test = _data.groupby(level=0).mean()
if avg_on_test.empty:
return None
prevs_on_test = np.sort(avg_on_test.index.unique(0))
stdev_on_test = _data.groupby(level=0).std()
return plot.plot_delta(
@ -492,12 +510,15 @@ class DatasetReport:
train_prev=None,
stdevs=stdev_on_test.T.to_numpy(),
avg="test",
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
elif mode == "shift":
_shift_data = self.shift_data(metric, estimators) if data is None else data
avg_shift = _shift_data.groupby(level=0).mean()
if avg_shift.empty:
return None
count_shift = _shift_data.groupby(level=0).count()
prevs_shift = np.sort(avg_shift.index.unique(0))
return plot.plot_shift(
@ -508,8 +529,9 @@ class DatasetReport:
name=conf,
train_prev=None,
counts=count_shift.T.to_numpy(),
return_fig=return_fig,
save_fig=save_fig,
base_path=base_path,
backend=backend,
)
def to_md(
@ -545,24 +567,26 @@ class DatasetReport:
res += avg_on_train_tbl.to_html() + "\n\n"
if "delta_train" in dr_modes:
delta_op = self.get_plots(
_, delta_op = self.get_plots(
data=_data,
mode="delta_train",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
save_fig=True,
)
res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n"
if "stdev_train" in dr_modes:
delta_stdev_op = self.get_plots(
_, delta_stdev_op = self.get_plots(
data=_data,
mode="stdev_train",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
save_fig=True,
)
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n"
@ -575,24 +599,26 @@ class DatasetReport:
res += avg_on_test_tbl.to_html() + "\n\n"
if "delta_test" in dr_modes:
delta_op = self.get_plots(
_, delta_op = self.get_plots(
data=_data,
mode="delta_test",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
save_fig=True,
)
res += f"![plot_delta]({delta_op.relative_to(delta_op.parents[1]).as_posix()})\n"
if "stdev_test" in dr_modes:
delta_stdev_op = self.get_plots(
_, delta_stdev_op = self.get_plots(
data=_data,
mode="stdev_test",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
save_fig=True,
)
res += f"![plot_delta_stdev]({delta_stdev_op.relative_to(delta_stdev_op.parents[1]).as_posix()})\n"
@ -605,13 +631,14 @@ class DatasetReport:
res += shift_on_train_tbl.to_html() + "\n\n"
if "shift" in dr_modes:
shift_op = self.get_plots(
_, shift_op = self.get_plots(
data=_shift_data,
mode="shift",
metric=metric,
estimators=estimators,
conf=conf,
base_path=plot_path,
save_fig=True,
)
res += f"![plot_shift]({shift_op.relative_to(shift_op.parents[1]).as_posix()})\n"
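
The CompReport/DatasetReport changes above replace return_fig with save_fig and add a backend parameter: with save_fig=False the get_plots() methods return only the figure, while with save_fig=True they propagate the (figure, output_path) tuple produced by quacc.plot, which is why the markdown writers now unpack "_, op = self.get_plots(...)". A short usage sketch, with a hypothetical pickle path following the <dir>/<dir>.pickle layout used elsewhere in the diff:

from pathlib import Path
from quacc.evaluation.report import DatasetReport

dr = DatasetReport.unpickle(Path("output/imdb/imdb.pickle"))   # hypothetical report produced by a run

# figure only, e.g. for the dashboard
fig = dr.get_plots(mode="delta_train", metric="acc", save_fig=False)

# figure plus the path the backend saved it to, e.g. for the markdown export
_, out_path = dr.get_plots(mode="delta_train", metric="acc", save_fig=True, base_path=Path("."))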


@ -1,265 +0,0 @@
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from quacc import utils
matplotlib.use("agg")
def _get_markers(n: int):
ls = "ovx+sDph*^1234X><.Pd"
if n > len(ls):
ls = ls * (n / len(ls) + 1)
return list(ls)[:n]
def plot_delta(
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
avg=None,
return_fig=False,
base_path=None,
) -> Path:
_base_title = "delta_stdev" if stdevs is not None else "delta"
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
else:
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
base_prevs = base_prevs[:, pos_class]
for method, deltas, _cy in zip(columns, data, cy):
ax.plot(
base_prevs,
deltas,
label=method,
color=_cy["color"],
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
if stdevs is not None:
_col_idx = np.where(columns == method)[0]
stdev = stdevs[_col_idx].flatten()
nn_idx = np.intersect1d(
np.where(deltas != np.nan)[0],
np.where(stdev != np.nan)[0],
)
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
ax.fill_between(
_bps,
_ds - _st,
_ds + _st,
color=_cy["color"],
alpha=0.25,
)
x_label = "test" if avg is None or avg == "train" else "train"
ax.set(
xlabel=f"{x_label} prevalence",
ylabel=metric,
title=title,
)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
if return_fig:
return fig
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path
def plot_diagonal(
reference,
columns,
data,
*,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
return_fig=False,
base_path=None,
):
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
else:
title = f"diagonal_{name}_{metric}"
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
ax.set_aspect("equal")
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(
color=[cm(i) for i in range(NUM_COLORS)],
marker=_get_markers(NUM_COLORS),
)
reference = np.array(reference)
x_ticks = np.unique(reference)
x_ticks.sort()
for deltas, _cy in zip(data, cy):
ax.plot(
reference,
deltas,
color=_cy["color"],
linestyle="None",
marker=_cy["marker"],
markersize=3,
zorder=2,
alpha=0.25,
)
# ensure limits are equal for both axes
_alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
for method, deltas, _cy in zip(columns, data, cy):
slope, interc = np.polyfit(reference, deltas, 1)
y_lr = np.array([slope * x + interc for x in _lims])
ax.plot(
_lims,
y_lr,
label=method,
color=_cy["color"],
linestyle="-",
markersize="0",
zorder=1,
)
# plot reference line
ax.plot(
_lims,
_lims,
color="black",
linestyle="--",
markersize=0,
zorder=1,
)
ax.set(xlabel=f"true {metric}", ylabel=f"estim. {metric}", title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
if return_fig:
return fig
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path
def plot_shift(
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
return_fig=False,
base_path=None,
) -> Path:
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"shift_{name}_{t_prev_pos}_{metric}"
else:
title = f"shift_{name}_avg_{metric}"
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
shift_prevs = shift_prevs[:, pos_class]
for method, shifts, _cy in zip(columns, data, cy):
ax.plot(
shift_prevs,
shifts,
label=method,
color=_cy["color"],
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
if counts is not None:
_col_idx = np.where(columns == method)[0]
count = counts[_col_idx].flatten()
for prev, shift, cnt in zip(shift_prevs, shifts, count):
label = f"{cnt}"
plt.annotate(
label,
(prev, shift),
textcoords="offset points",
xytext=(0, 10),
ha="center",
color=_cy["color"],
fontsize=12.0,
)
ax.set(xlabel="dataset shift", ylabel=metric, title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
if return_fig:
return fig
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path

quacc/plot/__init__.py (new file, 1 line)

@ -0,0 +1 @@
from quacc.plot.plot import get_backend, plot_delta, plot_diagonal, plot_shift

quacc/plot/base.py (new file, 54 lines)

@ -0,0 +1,54 @@
from pathlib import Path
class BasePlot:
@classmethod
def save_fig(cls, fig, base_path, title) -> Path:
...
@classmethod
def plot_diagonal(
cls,
reference,
columns,
data,
*,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
...
@classmethod
def plot_delta(
cls,
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
title="default",
x_label="prevs.",
y_label="error",
legend=True,
):
...
@classmethod
def plot_shift(
cls,
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
...
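
quacc/plot/base.py defines the backend contract that mpl.py and plotly.py implement below. A hypothetical extra backend, shown only to illustrate the interface and not part of this commit, would subclass BasePlot, implement the plot_* methods, and return whatever figure object its save_fig() knows how to persist:

from pathlib import Path
from quacc.plot.base import BasePlot

class TextPlot(BasePlot):
    """Toy backend that 'renders' a plot as a plain-text summary."""

    def save_fig(self, fig, base_path, title) -> Path:
        base_path = Path(".") if base_path is None else Path(base_path)
        output_path = base_path / f"{title}.txt"
        output_path.write_text(fig)
        return output_path

    def plot_delta(self, base_prevs, columns, data, *, stdevs=None, pos_class=1,
                   title="default", x_label="prevs.", y_label="error", legend=True):
        lines = [f"{title} ({x_label} vs {y_label})"]
        lines += [f"  {method}: {len(values)} points" for method, values in zip(columns, data)]
        return "\n".join(lines)

    # plot_diagonal and plot_shift would follow the same pattern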

quacc/plot/mpl.py (new file, 222 lines)

@ -0,0 +1,222 @@
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from quacc import utils
from quacc.plot.base import BasePlot
matplotlib.use("agg")
class MplPlot(BasePlot):
def _get_markers(self, n: int):
ls = "ovx+sDph*^1234X><.Pd"
if n > len(ls):
ls = ls * (n / len(ls) + 1)
return list(ls)[:n]
def save_fig(self, fig, base_path, title) -> Path:
if base_path is None:
base_path = utils.get_quacc_home() / "plots"
output_path = base_path / f"{title}.png"
fig.savefig(output_path, bbox_inches="tight")
return output_path
def plot_delta(
self,
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
title="default",
x_label="prevs.",
y_label="error",
legend=True,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
base_prevs = base_prevs[:, pos_class]
for method, deltas, _cy in zip(columns, data, cy):
ax.plot(
base_prevs,
deltas,
label=method,
color=_cy["color"],
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
if stdevs is not None:
_col_idx = np.where(columns == method)[0]
stdev = stdevs[_col_idx].flatten()
nn_idx = np.intersect1d(
np.where(deltas != np.nan)[0],
np.where(stdev != np.nan)[0],
)
_bps, _ds, _st = base_prevs[nn_idx], deltas[nn_idx], stdev[nn_idx]
ax.fill_between(
_bps,
_ds - _st,
_ds + _st,
color=_cy["color"],
alpha=0.25,
)
ax.set(
xlabel=f"{x_label} prevalence",
ylabel=y_label,
title=title,
)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig
def plot_diagonal(
self,
reference,
columns,
data,
*,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
ax.set_aspect("equal")
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(
color=[cm(i) for i in range(NUM_COLORS)],
marker=self._get_markers(NUM_COLORS),
)
reference = np.array(reference)
x_ticks = np.unique(reference)
x_ticks.sort()
for deltas, _cy in zip(data, cy):
ax.plot(
reference,
deltas,
color=_cy["color"],
linestyle="None",
marker=_cy["marker"],
markersize=3,
zorder=2,
alpha=0.25,
)
# ensure limits are equal for both axes
_alims = np.stack(((ax.get_xlim(), ax.get_ylim())), axis=-1)
_lims = np.array([f(ls) for f, ls in zip([np.min, np.max], _alims)])
ax.set(xlim=tuple(_lims), ylim=tuple(_lims))
for method, deltas, _cy in zip(columns, data, cy):
slope, interc = np.polyfit(reference, deltas, 1)
y_lr = np.array([slope * x + interc for x in _lims])
ax.plot(
_lims,
y_lr,
label=method,
color=_cy["color"],
linestyle="-",
markersize="0",
zorder=1,
)
# plot reference line
ax.plot(
_lims,
_lims,
color="black",
linestyle="--",
markersize=0,
zorder=1,
)
ax.set(xlabel=x_label, ylabel=y_label, title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig
def plot_shift(
self,
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
):
fig, ax = plt.subplots()
ax.set_aspect("auto")
ax.grid()
NUM_COLORS = len(data)
cm = plt.get_cmap("tab10")
if NUM_COLORS > 10:
cm = plt.get_cmap("tab20")
cy = cycler(color=[cm(i) for i in range(NUM_COLORS)])
shift_prevs = shift_prevs[:, pos_class]
for method, shifts, _cy in zip(columns, data, cy):
ax.plot(
shift_prevs,
shifts,
label=method,
color=_cy["color"],
linestyle="-",
marker="o",
markersize=3,
zorder=2,
)
if counts is not None:
_col_idx = np.where(columns == method)[0]
count = counts[_col_idx].flatten()
for prev, shift, cnt in zip(shift_prevs, shifts, count):
label = f"{cnt}"
plt.annotate(
label,
(prev, shift),
textcoords="offset points",
xytext=(0, 10),
ha="center",
color=_cy["color"],
fontsize=12.0,
)
ax.set(xlabel=x_label, ylabel=y_label, title=title)
if legend:
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
return fig
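
MplPlot carries over the matplotlib plotting code from the removed module above, but each plot_* method now returns the figure instead of writing it to disk; saving is a separate save_fig() step, normally driven through the quacc.plot helpers. A short sketch with made-up data:

from pathlib import Path
import numpy as np
from quacc import plot

base_prevs = np.array([[0.8, 0.2], [0.5, 0.5], [0.2, 0.8]])
columns = np.array(["method_a", "method_b"])      # hypothetical estimator names
data = np.array([[0.11, 0.06, 0.13], [0.09, 0.07, 0.10]])

fig, out_path = plot.plot_delta(
    base_prevs, columns, data,
    metric="acc", name="example",
    backend=plot.get_backend("mpl"),
    save_fig=True, base_path=Path("."),
)
# out_path is ./delta_example_avg_None_acc.png, using the title built by plot.plot_delta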

quacc/plot/plot.py (new file, 144 lines)

@ -0,0 +1,144 @@
from quacc.plot.base import BasePlot
from quacc.plot.mpl import MplPlot
from quacc.plot.plotly import PlotlyPlot
__backend: BasePlot = MplPlot()
def get_backend(be, theme=None):
match be:
case "matplotlib" | "mpl":
return MplPlot()
case "plotly":
return PlotlyPlot(theme=theme)
case _:
return MplPlot()
def plot_delta(
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
avg=None,
save_fig=False,
base_path=None,
backend=None,
):
backend = __backend if backend is None else backend
_base_title = "delta_stdev" if stdevs is not None else "delta"
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"{_base_title}_{name}_{t_prev_pos}_{metric}"
else:
title = f"{_base_title}_{name}_avg_{avg}_{metric}"
x_label = f"{'test' if avg is None or avg == 'train' else 'train'} prevalence"
y_label = f"{metric} error"
fig = backend.plot_delta(
base_prevs,
columns,
data,
stdevs=stdevs,
pos_class=pos_class,
title=title,
x_label=x_label,
y_label=y_label,
legend=legend,
)
if save_fig:
output_path = backend.save_fig(fig, base_path, title)
return fig, output_path
return fig
def plot_diagonal(
reference,
columns,
data,
*,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
save_fig=False,
base_path=None,
backend=None,
):
backend = __backend if backend is None else backend
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"diagonal_{name}_{t_prev_pos}_{metric}"
else:
title = f"diagonal_{name}_{metric}"
x_label = f"true {metric}"
y_label = f"estim. {metric}"
fig = backend.plot_diagonal(
reference,
columns,
data,
pos_class=pos_class,
title=title,
x_label=x_label,
y_label=y_label,
legend=legend,
)
if save_fig:
output_path = backend.save_fig(fig, base_path, title)
return fig, output_path
return fig
def plot_shift(
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
metric="acc",
name="default",
train_prev=None,
legend=True,
save_fig=False,
base_path=None,
backend=None,
):
backend = __backend if backend is None else backend
if train_prev is not None:
t_prev_pos = int(round(train_prev[pos_class] * 100))
title = f"shift_{name}_{t_prev_pos}_{metric}"
else:
title = f"shift_{name}_avg_{metric}"
x_label = "dataset shift"
y_label = f"{metric} error"
fig = backend.plot_shift(
shift_prevs,
columns,
data,
counts=counts,
pos_class=pos_class,
title=title,
x_label=x_label,
y_label=y_label,
legend=legend,
)
if save_fig:
output_path = backend.save_fig(fig, base_path, title)
return fig, output_path
return fig
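
quacc/plot/plot.py is the new dispatch layer: get_backend() selects the rendering backend (matplotlib by default, plotly for the dashboard), and each plot_* helper builds the title and axis labels, then returns either the figure alone or a (figure, output_path) tuple when save_fig=True. The dashboard path looks roughly like this, with illustrative data values:

import numpy as np
from quacc import plot

backend = plot.get_backend("plotly")              # what qcdash/app.py selects
base_prevs = np.array([[0.9, 0.1], [0.5, 0.5], [0.1, 0.9]])
columns = np.array(["method_a", "method_b"])      # hypothetical estimator names
data = np.array([[0.10, 0.05, 0.12], [0.08, 0.07, 0.09]])

fig = plot.plot_delta(base_prevs, columns, data,
                      metric="acc", name="plotly", backend=backend, save_fig=False)
# fig is a plotly.graph_objects.Figure ready to be handed to dcc.Graph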

quacc/plot/plotly.py (new file, 201 lines)

@ -0,0 +1,201 @@
from collections import defaultdict
from pathlib import Path
import numpy as np
import plotly
import plotly.graph_objects as go
from quacc.plot.base import BasePlot
class PlotlyPlot(BasePlot):
__themes = defaultdict(
lambda: {
"template": "seaborn",
}
)
__themes = __themes | {
"dark": {
"template": "plotly_dark",
},
}
def __init__(self, theme=None):
self.theme = PlotlyPlot.__themes[theme]
def hex_to_rgb(self, hex: str, t: float | None = None):
hex = hex.lstrip("#")
rgb = [int(hex[i : i + 2], 16) for i in [0, 2, 4]]
if t is not None:
rgb.append(t)
return f"{'rgb' if t is None else 'rgba'}{str(tuple(rgb))}"
def get_colors(self, num):
match num:
case v if v > 10:
__colors = plotly.colors.qualitative.Light24
case _:
__colors = plotly.colors.qualitative.Plotly
def __generator(cs):
while True:
for c in cs:
yield c
return __generator(__colors)
def update_layout(self, fig, title, x_label, y_label):
fig.update_layout(
title=title,
xaxis_title=x_label,
yaxis_title=y_label,
template=self.theme["template"],
)
def save_fig(self, fig, base_path, title) -> Path:
return None
def plot_delta(
self,
base_prevs,
columns,
data,
*,
stdevs=None,
pos_class=1,
title="default",
x_label="prevs.",
y_label="error",
legend=True,
) -> go.Figure:
fig = go.Figure()
x = base_prevs[:, pos_class]
line_colors = self.get_colors(len(columns))
for name, delta in zip(columns, data):
color = next(line_colors)
_line = [
go.Scatter(
x=x,
y=delta,
mode="lines+markers",
name=name,
line=dict(color=self.hex_to_rgb(color)),
hovertemplate="prev.: %{x}<br>error: %{y:,.4f}",
)
]
_error = []
if stdevs is not None:
_col_idx = np.where(columns == name)[0]
stdev = stdevs[_col_idx].flatten()
_error = [
go.Scatter(
x=np.concatenate([x, x[::-1]]),
y=np.concatenate([delta - stdev, (delta + stdev)[::-1]]),
name=int(_col_idx[0]),
fill="toself",
fillcolor=self.hex_to_rgb(color, t=0.2),
line=dict(color="rgba(255, 255, 255, 0)"),
hoverinfo="skip",
showlegend=False,
)
]
fig.add_traces(_line + _error)
self.update_layout(fig, title, x_label, y_label)
return fig
def plot_diagonal(
self,
reference,
columns,
data,
*,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
) -> go.Figure:
fig = go.Figure()
x = reference
line_colors = self.get_colors(len(columns))
_edges = (np.min([np.min(x), np.min(data)]), np.max([np.max(x), np.max(data)]))
_lims = np.array([[_edges[0], _edges[1]], [_edges[0], _edges[1]]])
for name, val in zip(columns, data):
color = next(line_colors)
slope, interc = np.polyfit(x, val, 1)
y_lr = np.array([slope * _x + interc for _x in _lims[0]])
fig.add_traces(
[
go.Scatter(
x=x,
y=val,
customdata=np.stack((val - x,), axis=-1),
mode="markers",
name=name,
line=dict(color=self.hex_to_rgb(color, t=0.5)),
hovertemplate="true acc: %{x:,.4f}<br>estim. acc: %{y:,.4f}<br>acc err.: %{customdata[0]:,.4f}",
),
go.Scatter(
x=_lims[0],
y=y_lr,
mode="lines",
name=name,
line=dict(color=self.hex_to_rgb(color), width=3),
showlegend=False,
),
]
)
fig.add_trace(
go.Scatter(
x=_lims[0],
y=_lims[1],
mode="lines",
name="reference",
showlegend=False,
line=dict(color=self.hex_to_rgb("#000000"), dash="dash"),
)
)
self.update_layout(fig, title, x_label, y_label)
fig.update_layout(yaxis_scaleanchor="x", yaxis_scaleratio=1.0)
return fig
def plot_shift(
self,
shift_prevs,
columns,
data,
*,
counts=None,
pos_class=1,
title="default",
x_label="true",
y_label="estim.",
legend=True,
) -> go.Figure:
fig = go.Figure()
x = shift_prevs[:, pos_class]
line_colors = self.get_colors(len(columns))
for name, delta in zip(columns, data):
col_idx = (columns == name).nonzero()[0][0]
color = next(line_colors)
fig.add_trace(
go.Scatter(
x=x,
y=delta,
customdata=np.stack((counts[col_idx],), axis=-1),
mode="lines+markers",
name=name,
line=dict(color=self.hex_to_rgb(color)),
hovertemplate="shift: %{x}<br>error: %{y}"
+ "<br>count: %{customdata[0]}"
if counts is not None
else "",
)
)
self.update_layout(fig, title, x_label, y_label)
return fig
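
PlotlyPlot keeps its per-theme layout templates in a defaultdict, so theme names other than "dark" are meant to fall back to the seaborn template, and hex_to_rgb() converts palette colours into the rgb()/rgba() strings plotly expects. A quick worked check of those helpers:

from quacc.plot.plotly import PlotlyPlot

p = PlotlyPlot(theme="dark")
assert p.theme["template"] == "plotly_dark"
assert p.hex_to_rgb("#0d6efd") == "rgb(13, 110, 253)"
assert p.hex_to_rgb("#0d6efd", t=0.2) == "rgba(13, 110, 253, 0.2)"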


@ -24,6 +24,7 @@ __to_sync_up = {
"quacc",
"baselines",
"qcpanel",
"qcdash",
],
"file": [
"conf.yaml",