""" Get stats of a dataset. Usage: python3 -m fastchat.data.get_stats --in sharegpt.json """ import argparse from concurrent.futures import ProcessPoolExecutor import json import numpy as np from tqdm import tqdm from transformers import AutoTokenizer, AutoModelForCausalLM K = 1e3 M = 1e6 def tokenize_one_sample(c): for i in range(len(c["conversations"])): v = c["conversations"][i]["value"] c["conversations"][i]["value"] = tokenizer.tokenize(v) return c def tokenize_dataset(content): processed = [] with ProcessPoolExecutor() as executor: for result in tqdm( executor.map(tokenize_one_sample, content), total=len(content) ): processed.append(result) return processed def compute_stats(content): sample_lens = [] sample_turns = [] prompt_lens = [] res_lens = [] for c in content: sample_len = 0 sample_turns.append(len(c["conversations"]) // 2) for i in range(len(c["conversations"]) // 2): p = c["conversations"][i * 2]["value"] r = c["conversations"][i * 2 + 1]["value"] turn_len = len(p) + len(r) sample_len += turn_len prompt_lens.append(len(p)) res_lens.append(len(r)) sample_lens.append(sample_len) return sample_lens, sample_turns, prompt_lens, res_lens if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--in-file", type=str) parser.add_argument( "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" ) args = parser.parse_args() content = json.load(open(args.in_file, "r")) tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) content = tokenize_dataset(content) sample_lens, sample_turns, prompt_lens, res_lens = compute_stats(content) print(f"#sequence: {len(content)/K:.2f} K") print(f"#tokens: {np.sum(sample_lens)/M:.2f} M") print(f"avg. turns: {np.mean(sample_turns):.2f}") print(f"avg. prompt length: {np.mean(prompt_lens):.2f}") print(f"avg. response length: {np.mean(res_lens):.2f}") print("\n- Histogram -") bin_edges = [0, 1024, 2048, 4096, 8192, 16384, 32768] hist = np.histogram(sample_lens, bins=bin_edges)[0] for i in range(len(hist)): print(f"L{bin_edges[i]} - {bin_edges[i+1]}: {hist[i]}")