# cra-window-rules/modules/plotting.py
# Mark Febrizio
# Documentation (#24)
# fe4f734 unverified
from pandas import DataFrame
from plotnine import (
ggplot,
aes,
geom_col,
geom_line,
annotate,
theme,
element_blank,
labs,
coord_flip,
scale_x_discrete,
scale_x_datetime,
scale_y_continuous,
theme_light,
)
class DataAvailabilityError(Exception):
    """Signal that the supplied data are insufficient to build a visualization."""
def plot_NA(placeholder_text: str = "Not enough data available to visualize.", placeholder_size: int = 14):
    """Build an empty, axis-free plot that shows a single text message.

    Use this in place of a real chart when there is not enough data to visualize.

    Args:
        placeholder_text (str, optional): Message drawn at the origin of the canvas.
            Defaults to "Not enough data available to visualize.".
        placeholder_size (int, optional): Text size passed to `annotate`. Defaults to 14.

    Returns:
        ggplot: Placeholder plot with ticks, tick labels, grid, axis labels, and title blanked out.
    """
    message_layer = annotate("text", x=0, y=0, label=placeholder_text, size=placeholder_size)
    # Hide every axis decoration so only the message is visible.
    bare_theme = theme(
        axis_ticks=element_blank(),
        axis_text=element_blank(),
        panel_grid=element_blank(),
    )
    blank_labels = labs(x="", y="", title="")
    return ggplot() + message_layer + bare_theme + blank_labels
def generate_rule_axis_label(rule_types: list | None = None):
"""Generate axis label for rules, accounting for rule type ("all", "3f1-significant", or "other-significant").
"""
categories = ""
if (rule_types is None) or ("all" in rule_types):
pass
elif all(True if cat in rule_types else False for cat in ("3f1-significant", "other-significant")):
categories = "significant"
elif ("3f1-significant" in rule_types) and ("other-significant" not in rule_types):
categories = "Section 3(f)(1) Significant"
elif ("3f1-significant" not in rule_types) and ("other-significant" in rule_types):
categories = "Other Significant"
return f"Number of {categories} rules".replace(" ", " ")
def plot_agency(
    df: DataFrame,
    group_col: str = "acronym",
    value_col: str = "rules",
    color: str = "#033C5A",
    rule_types: list | None = None,
    title: str | None = None,
):
    """Plot rules by agency as a horizontal bar chart.

    Args:
        df (DataFrame): Input data.
        group_col (str, optional): Column on which the data are grouped. Defaults to "acronym".
        value_col (str, optional): Column of values to be plotted. Defaults to "rules".
        color (str, optional): Color of values in plot. Defaults to "#033C5A" ([GW Blue](https://communications.gwu.edu/visual-identity/color-palette)).
        rule_types (list | None, optional): One or more rule types to include in plot. Accepts "all", "3f1-significant", or "other-significant". Used only to build the axis label. Defaults to None.
        title (str | None, optional): Plot title. Defaults to None, which uses "Rules Published by Agency".

    Returns:
        ggplot: Plotted data.
    """
    # After coord_flip() the discrete axis is vertical and its first limit is drawn at
    # the bottom; reversing the row order puts the first row of `df` at the top.
    order_list = df.loc[:, group_col].to_list()[::-1]
    y_lab = generate_rule_axis_label(rule_types)
    if title is None:
        title = "Rules Published by Agency"
    plot = (
        ggplot(
            df,
            aes(x=group_col, y=value_col),
        )
        + geom_col(color="#FFFFFF", fill=color)
        + coord_flip()
        + scale_x_discrete(limits=order_list)
        + labs(y=y_lab, x="", title=title)
        + theme_light()
    )
    return plot
def plot_month(
    df: DataFrame,
    group_cols: tuple = ("publication_year", "publication_month"),
    value_col: str = "rules",
    color: str = "#033C5A",
    title: str | None = None,
    y_lab: str = "",
):
    """Plot rules by month as a bar chart.

    The x axis shows "YYYY-MM" labels built from the year and month columns, in the
    row order of `df`. The input DataFrame is not modified.

    Args:
        df (DataFrame): Input data.
        group_cols (tuple, optional): Year and month columns (in that order) on which the data are grouped. Defaults to ("publication_year", "publication_month").
        value_col (str, optional): Column of values to be plotted. Defaults to "rules".
        color (str, optional): Color of values in plot. Defaults to "#033C5A" ([GW Blue](https://communications.gwu.edu/visual-identity/color-palette)).
        title (str | None, optional): Plot title. Defaults to None, which uses "Rules Published by Month".
        y_lab (str, optional): Plot y label. Defaults to "" (empty string).

    Returns:
        ggplot: Plotted data.
    """
    # Left-pad the month to two digits so labels read e.g. "2023-01".
    year_month = (
        df[group_cols[0]].astype(str)
        + "-"
        + df[group_cols[1]].astype(str).str.pad(2, fillchar="0")
    )
    # Add the label column on a copy so the caller's DataFrame is left untouched
    # (and no SettingWithCopy issues arise when `df` is a slice).
    df = df.assign(ym=year_month)
    order_list = df.loc[:, "ym"].to_list()
    if title is None:
        title = "Rules Published by Month"
    plot = (
        ggplot(
            df,
            aes(x="ym", y=value_col),
        )
        + geom_col(color="#FFFFFF", fill=color)
        + scale_x_discrete(limits=order_list)
        + labs(y=y_lab, x="", title=title)
        + theme_light()
    )
    return plot
def plot_day(
    df: DataFrame,
    group_col: str = "publication_date",
    value_col: str = "rules",
    color: str = "#033C5A",
    title: str | None = None,
    y_lab: str = "",
):
    """Draw a line chart of rules published per day.

    X-axis ticks are spaced weekly when the data span fewer than 61 days, every two
    weeks for 61-90 days, and monthly otherwise. The y axis starts at zero.

    Args:
        df (DataFrame): Input data.
        group_col (str, optional): Date column on which the data are grouped. Defaults to "publication_date".
        value_col (str, optional): Column of values to be plotted. Defaults to "rules".
        color (str, optional): Line color. Defaults to "#033C5A" ([GW Blue](https://communications.gwu.edu/visual-identity/color-palette)).
        title (str | None, optional): Plot title. Defaults to None, which uses "Rules Published by Date".
        y_lab (str, optional): Plot y label. Defaults to "" (empty string).

    Returns:
        ggplot: Plotted data.
    """
    dates = df.loc[:, group_col]
    span_days = (dates.max() - dates.min()).days
    # Widen the tick spacing as the covered date range grows.
    if 0 <= span_days < 61:
        freq = "1 week"
    elif 61 <= span_days < 91:
        freq = "2 weeks"
    else:
        freq = "1 month"

    y_upper = df.loc[:, value_col].max()
    plot_title = "Rules Published by Date" if title is None else title

    return (
        ggplot(df, aes(x=group_col, y=value_col))
        + geom_line(group=1, color=color)
        + scale_x_datetime(date_breaks=freq, date_labels="%m-%d")
        + scale_y_continuous(limits=(0, y_upper), expand=(0, 0, 0.1, 0))
        + labs(y=y_lab, x="", title=plot_title)
        + theme_light()
    )
def plot_week(
    df: DataFrame,
    group_col: str = "week_of",
    value_col: str = "rules",
    color: str = "#033C5A",
    title: str | None = None,
    y_lab: str = "",
    show_significant: bool = False,
):
    """Draw a line chart of rules published per week.

    Every row of `df` is treated as one week. Tick marks are thinned as the number of
    weeks grows: every week below 8 weeks, every 2nd for 8-15, every 3rd for 16-23,
    every 4th for 24-31, and every 5th from 32 on. Tick labels use "MM-DD".

    Args:
        df (DataFrame): Input data.
        group_col (str, optional): Week-start column on which the data are grouped; values must provide `strftime`. Defaults to "week_of".
        value_col (str, optional): Column of values to be plotted. Defaults to "rules".
        color (str, optional): Line color. Defaults to "#033C5A" ([GW Blue](https://communications.gwu.edu/visual-identity/color-palette)).
        title (str | None, optional): Plot title. Defaults to None, which uses "Rules Published by Week".
        y_lab (str, optional): Plot y label. Defaults to "" (empty string).
        show_significant (bool, optional): Accepted for interface compatibility; not used by this function. Defaults to False.

    Returns:
        ggplot: Plotted data.
    """
    week_starts = df[group_col].to_list()
    y_upper = df.loc[:, value_col].max()

    # One extra step per additional 8 weeks, capped at every 5th week.
    step = min(len(week_starts) // 8 + 1, 5)
    breaks = week_starts[::step]
    tick_labels = [week.strftime("%m-%d") for week in breaks]

    plot_title = "Rules Published by Week" if title is None else title

    return (
        ggplot(df, aes(x=group_col, y=value_col))
        + geom_line(group=1, color=color)
        + scale_x_datetime(breaks=breaks, labels=tick_labels)
        + scale_y_continuous(limits=(0, y_upper), expand=(0, 0, 0.1, 0))
        + labs(y=y_lab, x="", title=plot_title)
        + theme_light()
    )
def plot_tf(df: DataFrame, frequency: str, rule_types: list | None = None, **kwargs) -> ggplot:
    """Plot rules over time at the requested frequency.

    Args:
        df (DataFrame): Input data.
        frequency (str): Frequency of time for aggregating rules. Accepts "monthly", "daily", or "weekly".
        rule_types (list | None, optional): One or more rule types to include in plot. Accepts "all", "3f1-significant", or "other-significant". Used to build the y-axis label. Defaults to None.
        **kwargs: Extra keyword arguments forwarded to the frequency-specific plotting
            function (`plot_month`, `plot_day`, or `plot_week`). An explicit `y_lab` here
            takes precedence over the label derived from `rule_types`.

    Raises:
        ValueError: Frequency parameter received invalid value.

    Returns:
        ggplot: Plotted data.
    """
    freq_options = {
        "monthly": plot_month,
        "daily": plot_day,
        "weekly": plot_week,
    }
    plot_freq = freq_options.get(frequency)
    if plot_freq is None:
        raise ValueError(f"Frequency must be one of: {', '.join(freq_options.keys())}")
    # Derive the y label from rule_types unless the caller passed one explicitly;
    # supplying y_lab both ways would otherwise raise "multiple values for keyword argument".
    kwargs.setdefault("y_lab", generate_rule_axis_label(rule_types))
    return plot_freq(df, **kwargs)