Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| import json | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import altair as alt | |
| import pandas as pd | |
| from omegaconf import OmegaConf | |
| from pydantic import BaseModel | |
class PlotEntropiesConfig(BaseModel):
    """Options for one plotting run, validated from the merged config.

    ``main()`` builds this from the config file given on the command line
    merged with any CLI overrides. Unknown keys are rejected, so a misspelled
    option fails validation instead of being silently ignored.
    """

    # Pydantic v2 style configuration. The nested ``class Config`` form is
    # deprecated in v2 (which this file already requires, since it calls
    # ``model_validate_json``) and emits a deprecation warning.
    model_config = {"extra": "forbid"}

    # Path of the JSON file that is parsed as ``PlotEntropiesData``.
    # The key is required but may be null.
    data_path: str | None
    # Destination the rendered chart is saved to.
    chart_path: str
    # Optional JSON file whose "score" list replaces the "entropies" column.
    score_override_path: str | None = None
    # Optional threshold that takes precedence over the one in the data file.
    threshold_override: float | None = None
class PlotEntropiesData(BaseModel):
    """Schema of the JSON payload that ``main()`` reads from ``data_path``.

    Unknown keys are rejected during validation.
    """

    # Pydantic v2 style configuration; replaces the deprecated nested
    # ``class Config`` (this file already relies on the v2-only
    # ``model_validate_json``).
    model_config = {"extra": "forbid"}

    # Text stored alongside the data; not read by ``main()`` in this file.
    text: str
    # Threshold used when the run config does not supply ``threshold_override``.
    threshold: float = 1.335442066192627
    # JSON-serialized DataFrame. ``main()`` reads its "position", "tokens" and
    # "entropies" columns. The key is required but may be null.
    dataframe_json: str | None
def _load_plot_config(argv):
    """Build the validated run config from ``argv`` (config path + CLI overrides)."""
    if len(argv) < 2:
        prog = argv[0] if argv else "<script>"
        raise SystemExit(f"usage: {prog} CONFIG_PATH [key=value ...]")
    file_config = OmegaConf.load(argv[1])
    # Omit program name and config file name
    cli_conf = OmegaConf.from_cli(argv[2:])
    conf_dict = OmegaConf.to_container(
        OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
    )
    return PlotEntropiesConfig(**conf_dict)


def _load_plot_data(plot_config):
    """Read the data file and return ``(dataframe, threshold_from_file)``."""
    # Local import: keeps this block self-contained without touching the
    # file-level import list.
    from io import StringIO

    if plot_config.data_path is None:
        raise ValueError("data_path must be set in order to plot entropies")
    with open(plot_config.data_path, encoding="utf-8") as f:
        plot_data = PlotEntropiesData.model_validate_json(f.read())
    if plot_data.dataframe_json is None:
        raise ValueError(
            f"{plot_config.data_path} contains no dataframe_json to plot"
        )
    # pandas >= 2.1 deprecates passing literal JSON text to read_json, so
    # hand it a file-like object instead.
    df = pd.read_json(StringIO(plot_data.dataframe_json))
    return df, plot_data.threshold


def _apply_score_override(df, score_override_path):
    """Replace the "entropies" column with the "score" list from a JSON file."""
    with open(score_override_path, encoding="utf-8") as f:
        scores = json.load(f)["score"]
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if len(scores) != len(df):
        raise ValueError(
            f"{score_override_path} has {len(scores)} scores "
            f"but the dataframe has {len(df)} rows"
        )
    df["entropies"] = scores


def _build_chart(df, threshold, width=1200, height=150):
    """Layer the entropy line, the threshold rule and the per-start rules."""
    x_axis = alt.Axis(
        # Tick values look like "007|tok"; display only the part after "|".
        labelExpr="split(datum.label, '|')[1]",
        grid=False,
        labelOverlap=False,
        labelAngle=0,
    )
    base = alt.Chart(df).properties(width=width, height=height)
    points = base.mark_line(point=True).encode(
        x=alt.X("position_with_token:O", title=None, axis=x_axis),
        y=alt.Y(
            "entropies",
            title="Entropy of Next Byte",
        ),
    )
    # Horizontal dashed line at the threshold value.
    rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
        y=alt.datum(threshold),
    )
    # One vertical dashed line for every row flagged in the "start" column.
    patch_rules = (
        alt.Chart(df[df["start"]])
        .properties(width=width, height=height)
        .mark_rule(color="#474747", strokeDash=[4, 2])
        .encode(x=alt.X("position_with_token:O", axis=x_axis))
    )
    chart = patch_rules + rule + points
    return chart.configure_axis(labelFontSize=15, titleFontSize=15)


def main():
    """Render the entropy chart described by the config file in ``sys.argv[1]``.

    Remaining command-line arguments are OmegaConf ``key=value`` overrides
    merged on top of the config file.

    Raises:
        SystemExit: if no config path is given on the command line.
        ValueError: if ``data_path`` or the stored dataframe is null, or if
            the score override length does not match the dataframe.
    """
    plot_config = _load_plot_config(sys.argv)
    df, file_threshold = _load_plot_data(plot_config)
    print("LEN", len(df))

    if plot_config.threshold_override is None:
        threshold = file_threshold
    else:
        threshold = plot_config.threshold_override

    if plot_config.score_override_path is not None:
        _apply_score_override(df, plot_config.score_override_path)

    # A row is a start when the previous row's entropy exceeded the
    # threshold; the first row always is. ``shift`` keeps a uniform bool
    # column and also works on an empty frame, unlike prepending to a list.
    df["start"] = (df["entropies"] > threshold).shift(1, fill_value=True)

    # The zero-padded position prefix keeps labels distinct when tokens
    # repeat; the axis labelExpr strips it again for display.
    # NOTE(review): a token containing "|" would be cut at its first "|" by
    # the labelExpr split - confirm that cannot occur in the data.
    df["position_with_token"] = [
        f"{str(position).zfill(3)}|{token}"
        for position, token in zip(df["position"], df["tokens"])
    ]
    print(df)

    chart = _build_chart(df, threshold)

    path = Path(plot_config.chart_path)
    # parents=True so a nested, not-yet-existing output directory is created
    # instead of raising FileNotFoundError.
    path.parent.mkdir(parents=True, exist_ok=True)
    chart.save(path)


if __name__ == "__main__":
    main()