# open_instruct/get_data_stats.py
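"""Compute token-length statistics for a JSONL instruction-tuning dataset.

Each line of the input file must be a JSON object in one of two formats:
"messages" data (a list of {"role": ..., "content": ...} chat turns) or
"prompt"/"completion" data; the format is auto-detected from the first record.
Lengths are counted with a local LLaMA tokenizer (hard-coded checkpoint path below).
"""
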
import argparse
import json

import numpy as np
import pandas as pd
import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer

def get_statistics_for_messages_data(data_path):
    """Compute token-length statistics for "messages"-format data (lists of chat turns)."""
    # load dataset
    dataset = load_dataset("json", data_files={"train": data_path})
    # load the tokenizer (a local LLaMA 7B checkpoint); use_fast=False selects the slow SentencePiece tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        "/net/nfs.cirrascale/allennlp/yizhongw/hf_llama_models/7B", use_fast=False
    )
# get statistics
    num_instances = len(dataset["train"])
    num_of_turns = [len(instance["messages"]) for instance in dataset["train"]]
    # accumulate token counts per user turn, per assistant turn, and per whole instance
    # (turns with other roles, e.g. "system", are not counted toward lengths)
    user_prompt_lengths = []
    assistant_response_lengths = []
    instance_lengths = []
for instance in tqdm.tqdm(dataset["train"], desc="Processing instances"):
instance_length = 0
for message in instance["messages"]:
if message["role"] == "user":
user_prompt_lengths.append(len(tokenizer(message["content"], truncation=False, add_special_tokens=False)["input_ids"]))
instance_length += user_prompt_lengths[-1]
elif message["role"] == "assistant":
assistant_response_lengths.append(len(tokenizer(message["content"], truncation=False, add_special_tokens=False)["input_ids"]))
instance_length += assistant_response_lengths[-1]
instance_lengths.append(instance_length)
    # ids of the 100 longest instances, longest first (each record is assumed to have an "id" field)
    top_100_longest_instances = np.argsort(instance_lengths)[-100:][::-1].tolist()
    top_100_longest_instances = [dataset["train"][i]["id"] for i in top_100_longest_instances]
result = {
"num_instances": num_instances,
"turns_summary": pd.Series(num_of_turns).describe(),
"user_prompt_lengths_summary": pd.Series(user_prompt_lengths).describe(),
"assistant_response_lengths_summary": pd.Series(assistant_response_lengths).describe(),
"total_lengths_summary": pd.Series(instance_lengths).describe(),
"num_instances_with_total_length_gt_512": np.sum(np.array(instance_lengths) > 512),
"num_instances_with_total_length_gt_768": np.sum(np.array(instance_lengths) > 768),
"num_instances_with_total_length_gt_1024": np.sum(np.array(instance_lengths) > 1024),
"num_instances_with_total_length_gt_1536": np.sum(np.array(instance_lengths) > 1536),
"num_instances_with_total_length_gt_2048": np.sum(np.array(instance_lengths) > 2048),
"num_instances_with_total_length_gt_4096": np.sum(np.array(instance_lengths) > 4096),
"top_100_longest_instances": top_100_longest_instances,
}
    # convert numpy/pandas objects to JSON-serializable dicts and scalars
    for key, value in result.items():
        if isinstance(value, pd.Series):
            result[key] = value.to_dict()
        elif isinstance(value, np.ndarray):
            result[key] = value.tolist()
        elif isinstance(value, np.integer):
            result[key] = int(value)
return result
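
# For reference: a minimal sketch of one "messages"-format JSONL record that
# get_statistics_for_messages_data expects. Values are hypothetical; the code
# assumes each record carries an "id" and a "messages" list of role/content turns:
#
#   {"id": "example_0",
#    "messages": [{"role": "user", "content": "Hi!"},
#                 {"role": "assistant", "content": "Hello! How can I help?"}]}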

def get_statistics_for_prompt_completion_data(data_path):
    """Compute token-length statistics for "prompt"/"completion"-format data."""
    # load dataset
    dataset = load_dataset("json", data_files={"train": data_path})
    prompts = [instance["prompt"] for instance in dataset["train"]]
    completions = [instance["completion"] for instance in dataset["train"]]
    # tokenize prompts and completions (slow tokenizer, matching get_statistics_for_messages_data above)
    tokenizer = AutoTokenizer.from_pretrained(
        "/net/nfs.cirrascale/allennlp/yizhongw/hf_llama_models/7B", use_fast=False
    )
    tokenized_prompts = tokenizer(prompts, truncation=False, add_special_tokens=False)
    tokenized_completions = tokenizer(completions, truncation=False, add_special_tokens=False)
# get statistics
num_instances = len(dataset["train"])
prompt_lengths = [len(tokenized_prompts["input_ids"][i]) for i in range(num_instances)]
completion_lengths = [len(tokenized_completions["input_ids"][i]) for i in range(num_instances)]
prompt_completion_lengths = [prompt_lengths[i] + completion_lengths[i] for i in range(num_instances)]
result = {
"num_instances": num_instances,
"prompt_lengths_summary": pd.Series(prompt_lengths).describe(),
"completion_lengths_summary": pd.Series(completion_lengths).describe(),
"prompt_completion_lengths_summary": pd.Series(prompt_completion_lengths).describe(),
"num_instances_with_prompt_length_gt_512": np.sum(np.array(prompt_lengths) > 512),
"num_instances_with_completion_length_gt_512": np.sum(np.array(completion_lengths) > 512),
"num_instances_with_prompt_completion_length_gt_512": np.sum(np.array(prompt_completion_lengths) > 512),
"num_instances_with_completion_length_gt_768": np.sum(np.array(completion_lengths) > 768),
"num_instances_with_prompt_completion_length_gt_1024": np.sum(np.array(prompt_completion_lengths) > 1024),
}
    # convert numpy/pandas objects to JSON-serializable dicts and scalars
    for key, value in result.items():
        if isinstance(value, pd.Series):
            result[key] = value.to_dict()
        elif isinstance(value, np.ndarray):
            result[key] = value.tolist()
        elif isinstance(value, np.integer):
            result[key] = int(value)
return result
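
# Likewise, a minimal sketch (hypothetical values) of one "prompt"/"completion"
# JSONL record that get_statistics_for_prompt_completion_data expects:
#
#   {"prompt": "Translate to French: Hello", "completion": "Bonjour"}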

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, required=True, help="Path to the JSONL data file.")
    parser.add_argument("--save_path", type=str, help="Path to save the statistics.")
    args = parser.parse_args()

    # peek at the first record to decide which format the file uses
    with open(args.data_path, "r") as f:
        sample = json.loads(f.readline())
if "prompt" in sample:
statistics = get_statistics_for_prompt_completion_data(args.data_path)
elif "messages" in sample:
statistics = get_statistics_for_messages_data(args.data_path)
    else:
        raise ValueError("Invalid data format: each record must contain either a 'prompt'/'completion' pair or a 'messages' list.")
print(json.dumps(statistics, indent=4))
if args.save_path is not None:
with open(args.save_path, "w") as f:
json.dump(statistics, f, indent=4)
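
# Example invocation (paths are hypothetical):
#   python get_data_stats.py --data_path data/train.jsonl --save_path stats.json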