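"""Tag near-duplicate conversations by prompt frequency.

Reads a JSON list of conversation records (each expected to carry a
"conversation_a" list of {"role", "content"} turns), concatenates the user
turns of each record into a key truncated to 10,000 characters, and counts how
often each key occurs. Keys whose count exceeds the --percentile quantile of
the count distribution are treated as high-frequency. Every record receives a
"dedup_tag" dict of the form {"high_freq": bool, "sampled": bool}; within each
high-frequency group only a fixed-size random sample keeps sampled=True. The
tagged records are written to <output_dir>/dedup.json.
"""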
import argparse
import ast
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import rcParams
from tqdm import tqdm

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output")
    parser.add_argument("--model", type=str, default=None)  # parsed but unused below
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--percentile", type=float, default=0.9999)
    args = parser.parse_args()
    output_dir = args.output_dir
    input_file = args.input_file

    # Load the conversation records: a JSON list of rows, each with a
    # "conversation_a" list of {"role", "content"} turns.
    with open(input_file) as f:
        data = json.load(f)

    os.makedirs(output_dir, exist_ok=True)

    # Concatenate the user turns of each conversation into a single string,
    # truncated to 10,000 characters, to serve as the deduplication key.
    all_convs_new = []
    convs = []
    for row in data:
        conv = ""
        for turns in row["conversation_a"]:
            if turns["role"] == "user":
                conv += f"{turns['content']}\n"

        convs.append(conv[:10000])
        row["post_process_conv"] = conv[:10000]
        all_convs_new.append(row)

    df = pd.DataFrame(all_convs_new)
    print("Number of conversations: ", len(df))

    # Count how many times each post-processed prompt string appears.
    prompt_counts = df["post_process_conv"].value_counts()

    top_prompts = prompt_counts.head(20)
    print(top_prompts)

    # Prompts whose count exceeds this percentile of the count distribution
    # are treated as high-frequency (duplicate-heavy) prompts.
    percentile_cutoff = prompt_counts.quantile(args.percentile)
    print(f"{args.percentile * 100} percentile count: {percentile_cutoff}")

    high_frequency_prompts = prompt_counts[prompt_counts > percentile_cutoff].index
    print(
        f"Number of high frequency prompts: {len(high_frequency_prompts)}/{len(prompt_counts)}"
    )

    # Every row starts untagged (high_freq=False, sampled=True). An object
    # array is used so each element can hold a tag dict; df has a default
    # RangeIndex, so its indices can address positions in the array directly.
    dedup_tags = np.array(
        [{"high_freq": False, "sampled": True} for _ in range(len(df))]
    )
    high_freq_groups = df.groupby("post_process_conv")
    for prompt in tqdm(high_frequency_prompts):
        # Mark every row of a high-frequency prompt as not sampled, then keep
        # a random sample of int(percentile_cutoff) rows as sampled=True.
        df_high_freq = high_freq_groups.get_group(prompt)
        sampled_indices = df_high_freq.sample(
            n=int(percentile_cutoff), random_state=42
        ).index
        dedup_tags[df_high_freq.index] = {"high_freq": True, "sampled": False}
        dedup_tags[sampled_indices] = {"high_freq": True, "sampled": True}

    df["dedup_tag"] = dedup_tags

    # Drop the helper column and write the tagged records back out.
    df = df.drop(columns=["post_process_conv"])

    df.to_json(
        os.path.join(output_dir, "dedup.json"),
        orient="records",
        indent=4,
        force_ascii=False,
    )
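# Example invocation (the file name "dedup.py" is illustrative, not taken from
# the source):
#   python dedup.py --input_file conversations.json --output_dir output --percentile 0.9999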