"""Plotting helpers for rules data, built on plotnine."""

from pandas import DataFrame
from plotnine import (
    ggplot,
    aes,
    geom_col,
    geom_line,
    annotate,
    theme,
    element_blank,
    labs,
    coord_flip,
    scale_x_discrete,
    scale_x_datetime,
    scale_y_continuous,
    theme_light,
)


class DataAvailabilityError(Exception):
    """Raised when not enough data available to visualize."""


def plot_NA(
    placeholder_text: str = "Not enough data available to visualize.",
    placeholder_size: int = 14,
) -> ggplot:
    """Placeholder plot for when there is not enough data available to visualize.

    Args:
        placeholder_text (str, optional): Message drawn at the origin of the plot.
        placeholder_size (int, optional): Text size of the message. Defaults to 14.

    Returns:
        ggplot: Plot showing only the message; ticks, tick labels, grid lines,
            axis labels and title are blanked out.
    """
    return (
        ggplot()
        + annotate("text", x=0, y=0, label=placeholder_text, size=placeholder_size)
        + theme(
            axis_ticks=element_blank(),
            axis_text=element_blank(),
            panel_grid=element_blank(),
        )
        + labs(x="", y="", title="")
    )


def generate_rule_axis_label(rule_types: list | None = None) -> str:
    """Generate axis label for rules, accounting for the selected rule types.

    Args:
        rule_types (list | None, optional): Selected rule types. The values
            "all", "3f1-significant" and "other-significant" are recognized.
            Defaults to None, which is treated like "all".

    Returns:
        str: "Number of rules" when no specific category applies, otherwise
            "Number of <category> rules".
    """
    categories = ""
    if rule_types is not None and "all" not in rule_types:
        has_3f1 = "3f1-significant" in rule_types
        has_other = "other-significant" in rule_types
        if has_3f1 and has_other:
            categories = "significant"
        elif has_3f1:
            categories = "Section 3(f)(1) Significant"
        elif has_other:
            categories = "Other Significant"

    # Join only the non-empty parts so an empty category cannot leave a
    # doubled space in the middle of the label.
    return " ".join(part for part in ("Number of", categories, "rules") if part)


def plot_agency(
    df: DataFrame,
    group_col: str = "acronym",
    value_col: str = "rules",
    color: str = "#033C5A",
    rule_types: list | None = None,
) -> ggplot:
    """Plot rules by agency as a flipped (horizontal) column chart.

    Args:
        df (DataFrame): Input data.
        group_col (str, optional): Column on which the data are grouped. Defaults to "acronym".
        value_col (str, optional): Column of values to be plotted. Defaults to "rules".
        color (str, optional): Fill color of the columns. Defaults to "#033C5A".
        rule_types (list | None, optional): Rule types used to build the value-axis
            label. Defaults to None.

    Returns:
        ggplot: Plotted data.
    """
    # The category order is the reverse of the row order in df; presumably this
    # puts the first row at the top once the axes are flipped.
    order_list = df.loc[:, group_col].to_list()[::-1]
    y_lab = generate_rule_axis_label(rule_types)
    plot = (
        ggplot(df, aes(x=group_col, y=value_col))
        + geom_col(color="#FFFFFF", fill=color)
        + coord_flip()
        + scale_x_discrete(limits=order_list)
        + labs(y=y_lab, x="", title="Rules Published by Agency")
        + theme_light()
    )
    return plot


def plot_month(
    df: DataFrame,
    group_cols: tuple = ("publication_year", "publication_month"),
    value_col: str = "rules",
    color: str = "#033C5A",
    title: str | None = None,
    y_lab: str = "",
) -> ggplot:
    """Plot rules by month.

    Args:
        df (DataFrame): Input data. It is not modified.
        group_cols (tuple, optional): Year and month columns on which the data are grouped.
            Defaults to ("publication_year", "publication_month").
        value_col (str, optional): Column of values to be plotted. Defaults to "rules".
        color (str, optional): Fill color of the columns. Defaults to "#033C5A".
        title (str | None, optional): Plot title. Defaults to "Rules Published by Month".
        y_lab (str, optional): Label for the value axis. Defaults to "".

    Returns:
        ggplot: Plotted data.
    """
    year_col, month_col = group_cols[0], group_cols[1]
    year_month = (
        df[year_col].astype(str)
        + "-"
        + df[month_col].astype(str).str.pad(2, fillchar="0")
    )
    # assign() returns a new frame, so the caller's DataFrame does not gain a
    # "ym" column as a side effect of plotting.
    plot_df = df.assign(ym=year_month)
    order_list = plot_df.loc[:, "ym"].to_list()

    if title is None:
        title = "Rules Published by Month"

    plot = (
        ggplot(plot_df, aes(x="ym", y=value_col))
        + geom_col(color="#FFFFFF", fill=color)
        + scale_x_discrete(limits=order_list)
        + labs(y=y_lab, x="", title=title)
        + theme_light()
    )
    return plot


def plot_day(
    df: DataFrame,
    group_col: str = "publication_date",
    value_col: str = "rules",
    color: str = "#033C5A",
    title: str | None = None,
    y_lab: str = "",
) -> ggplot:
    """Plot rules by day.

    Args:
        df (DataFrame): Input data.
        group_col (str, optional): Date column on which the data are grouped.
            Defaults to "publication_date".
        value_col (str, optional): Column of values to be plotted. Defaults to "rules".
        color (str, optional): Line color. Defaults to "#033C5A".
        title (str | None, optional): Plot title. Defaults to "Rules Published by Date".
        y_lab (str, optional): Label for the value axis. Defaults to "".

    Returns:
        ggplot: Plotted data.
    """
    min_date = df.loc[:, group_col].min()
    max_date = df.loc[:, group_col].max()
    span_days = (max_date - min_date).days

    # Wider date spans get sparser tick marks on the date axis.
    if span_days <= 60:
        freq = "1 week"
    elif span_days <= 90:
        freq = "2 weeks"
    else:
        freq = "1 month"

    max_value = df.loc[:, value_col].max()
    if title is None:
        title = "Rules Published by Date"

    plot = (
        ggplot(df, aes(x=group_col, y=value_col))
        + geom_line(group=1, color=color)
        + scale_x_datetime(date_breaks=freq, date_labels="%m-%d")
        + scale_y_continuous(limits=(0, max_value), expand=(0, 0, 0.1, 0))
        + labs(y=y_lab, x="", title=title)
        + theme_light()
    )
    return plot


def plot_week(
    df: DataFrame,
    group_col: str = "week_of",
    value_col: str = "rules",
    color: str = "#033C5A",
    title: str | None = None,
    y_lab: str = "",
    show_significant: bool = False,
) -> ggplot:
    """Plot rules by week.

    Args:
        df (DataFrame): Input data.
        group_col (str, optional): Date column identifying each week. Defaults to "week_of".
        value_col (str, optional): Column of values to be plotted. Defaults to "rules".
        color (str, optional): Line color. Defaults to "#033C5A".
        title (str | None, optional): Plot title. Defaults to "Rules Published by Week".
        y_lab (str, optional): Label for the value axis. Defaults to "".
        show_significant (bool, optional): If True, add a dashed line for the
            "other_significant" column. Defaults to False.

    Returns:
        ggplot: Plotted data.
    """
    max_value = df.loc[:, value_col].max()
    date_values = df[group_col].to_list()

    # Label every week below 8 weeks, then thin the ticks by one more step for
    # each additional 8 weeks, capped at every 5th week from 32 weeks on.
    reduce_by = min(len(date_values) // 8 + 1, 5)
    breaks = date_values[::reduce_by]

    if title is None:
        title = "Rules Published by Week"

    plot = (
        ggplot(df, aes(x=group_col, y=value_col))
        + geom_line(group=1, color=color)
        + scale_x_datetime(breaks=breaks, labels=[w.strftime("%m-%d") for w in breaks])
        + scale_y_continuous(limits=(0, max_value), expand=(0, 0, 0.1, 0))
        + labs(y=y_lab, x="", title=title)
        + theme_light()
    )

    if show_significant:
        # TODO: also draw "3f1_significant" (dotted, "#AA9868") and add a legend.
        # An earlier attempt raised "TypeError: Discrete value supplied to
        # continuous scale" for that column. Casting it to float only helps if
        # the converted frame is what the plot (or the layer's data=) actually
        # receives; rebinding the local name after ggplot() is built does nothing.
        plot = plot + geom_line(
            aes(x=group_col, y="other_significant"),
            inherit_aes=False,
            group=1,
            color="#0190DB",
            linetype="dashed",
        )

    return plot


def plot_tf(df: DataFrame, frequency: str, rule_types: list | None = None, **kwargs) -> ggplot:
    """Plot rules over time by given frequency.

    Args:
        df (DataFrame): Input data.
        frequency (str): Frequency of time for aggregating rules.
            Accepts "monthly", "daily", or "weekly".
        rule_types (list | None, optional): Rule types used to build the value-axis
            label. Defaults to None.
        **kwargs: Forwarded unchanged to the plotting function selected by frequency.

    Raises:
        ValueError: Frequency parameter received invalid value.

    Returns:
        ggplot: Plotted data.
    """
    freq_options = {
        "monthly": plot_month,
        "daily": plot_day,
        "weekly": plot_week,
    }

    plot_freq = freq_options.get(frequency)
    if plot_freq is None:
        raise ValueError(f"Frequency must be one of: {', '.join(freq_options)}")

    y_lab = generate_rule_axis_label(rule_types)
    return plot_freq(df, y_lab=y_lab, **kwargs)