# melt/fastchat/data/split_long_conversation.py
# Uploader: martinakaduc, "Upload folder using huggingface_hub" (commit f3305db, verified)
"""
Split long conversations into chunks that each fit within a maximum token length.
Usage: python3 -m fastchat.data.split_long_conversation \
--in-file sharegpt_clean.json \
--out-file sharegpt_split.json \
--model-name-or-path <model-name>
"""
import argparse
from concurrent.futures import ProcessPoolExecutor
import json
from typing import Dict, Sequence, Optional
import transformers
from tqdm import tqdm
def make_sample(sample, start_idx, end_idx):
    """Build a new sample from the messages ``sample["conversations"][start_idx:end_idx]``.

    The span must contain an even number of messages (whole rounds). The new
    id is the original id suffixed with ``"_" + str(start_idx)``. ``model`` is
    copied from the source sample, defaulting to ``""`` when absent.
    """
    span_is_whole_rounds = (end_idx - start_idx) % 2 == 0
    assert span_is_whole_rounds
    id_suffix = "_" + str(start_idx)
    return dict(
        id=sample["id"] + id_suffix,
        model=sample.get("model", ""),
        conversations=sample["conversations"][start_idx:end_idx],
    )
# Module-level state read by split_one_sample; split_all assigns both before
# handing work to the process pool.
tokenizer = None
max_length = None
def split_one_sample(sample):
    """Split one sample into consecutive chunks that fit within ``max_length`` tokens.

    Reads the module-level ``tokenizer`` and ``max_length``. A trailing
    unpaired message is dropped so only whole two-message rounds remain.
    Rounds are then packed greedily, in order, into chunks whose estimated
    token count stays within ``max_length``. A single round that alone exceeds
    ``max_length`` is still emitted as its own chunk rather than discarded.

    Args:
        sample: dict with ``"id"`` and ``"conversations"`` (a list of messages,
            each carrying a ``"value"`` text field); ``"model"`` is optional.

    Returns:
        A list of samples built by ``make_sample``, covering every round
        exactly once. The list is empty when the sample has no complete round.
    """
    conversations = sample["conversations"]
    # Truncate to an even length so messages pair up into whole rounds.
    conversations = conversations[: len(conversations) // 2 * 2]
    if not conversations:
        return []

    # Estimated token cost of each message. The +6 presumably budgets for the
    # role/separator tokens added around each message -- TODO confirm.
    tokenized_lens = [
        len(tokenizer(c["value"]).input_ids) + 6 for c in conversations
    ]

    new_samples = []
    start_idx = 0
    cur_len = 0
    for i in range(0, len(conversations), 2):
        round_len = tokenized_lens[i] + tokenized_lens[i + 1]
        # Close the open chunk before this round would overflow it. The
        # ``i > start_idx`` guard prevents emitting an empty chunk (which would
        # also share its id with the following chunk) when the chunk's first
        # round is already too long on its own.
        if i > start_idx and cur_len + round_len > max_length:
            new_samples.append(make_sample(sample, start_idx, i))
            start_idx = i
            cur_len = 0
        cur_len += round_len

    # The open chunk always holds at least one round here, including the case
    # where the final round itself started a new chunk, so always flush it.
    new_samples.append(make_sample(sample, start_idx, len(conversations)))
    return new_samples
def worker(input_data):
    """Split every sample in ``input_data`` and return all resulting chunks as one flat list, in order."""
    return [
        chunk
        for sample in input_data
        for chunk in split_one_sample(sample)
    ]
def _init_worker(tokenizer_, max_length_):
    """Install the tokenizer and token budget as this module's globals.

    Called in the parent process and, through ``ProcessPoolExecutor``'s
    ``initializer``, in every worker process. Workers therefore see the values
    even under the ``spawn``/``forkserver`` start methods, where they do not
    inherit the parent's globals.
    """
    global tokenizer, max_length
    tokenizer = tokenizer_
    max_length = max_length_


def split_all(content, begin, end, tokenizer_, max_length_, chunk_size=1000):
    """Keep the maximum round of conversations within the max token length constraint.

    Splits every sample in ``content[begin:end]`` into chunks of at most
    ``max_length_`` estimated tokens, fanning the work out over a process pool.

    Args:
        content: list of samples.
        begin, end: slice bounds applied to ``content``; ``None`` is open-ended.
        tokenizer_: callable returning an object with ``.input_ids``.
        max_length_: token budget per output chunk.
        chunk_size: number of samples sent to a worker per task.

    Returns:
        The split samples, in the same order as the input.

    Raises:
        ValueError: if ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    # Also set the globals in the parent, as before, for any in-process callers.
    _init_worker(tokenizer_, max_length_)
    content = content[begin:end]
    batches = [
        content[i : i + chunk_size] for i in range(0, len(content), chunk_size)
    ]
    new_content = []
    with ProcessPoolExecutor(
        initializer=_init_worker, initargs=(tokenizer_, max_length_)
    ) as executor:
        # executor.map yields results in submission order, preserving input order.
        for result in tqdm(executor.map(worker, batches), total=len(batches)):
            new_content.extend(result)
    return new_content
def filter_invalid_roles(content):
    """Keep only samples whose messages strictly alternate ``"human"`` / ``"gpt"``.

    A sample is kept when its ``conversations`` list is non-empty, its first
    message is from ``"human"``, and the ``"from"`` field then alternates
    ``"human"``, ``"gpt"``, ``"human"``, ... Input order is preserved.
    """
    roles = ("human", "gpt")  # expected speaker at even / odd positions

    def _alternates(messages):
        # One-line purpose: True when every message matches its expected role.
        return all(m["from"] == roles[j % 2] for j, m in enumerate(messages))

    return [
        c
        for c in content
        if c["conversations"] and _alternates(c["conversations"])
    ]
def main(args):
    """Load, split, filter, and save conversations as directed by ``args``.

    Reads the JSON list at ``args.in_file``, splits ``content[args.begin:args.end]``
    so each chunk fits ``args.max_length`` tokens, drops chunks whose roles do
    not alternate human/gpt, and writes the result to ``args.out_file``.
    """
    # Explicit UTF-8 on both ends: the output is written with ensure_ascii=False,
    # which can fail under a non-UTF-8 locale default encoding (e.g. cp1252).
    with open(args.in_file, "r", encoding="utf-8") as f:
        content = json.load(f)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.max_length,
        padding_side="right",
        use_fast=False,
    )
    new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length)
    new_content = filter_invalid_roles(new_content)
    print(f"#in: {len(content)}, #out: {len(new_content)}")
    # The with-block guarantees the output is flushed and closed.
    with open(args.out_file, "w", encoding="utf-8") as f:
        json.dump(new_content, f, indent=2, ensure_ascii=False)
if __name__ == "__main__":
    # Command-line entry point; each flag becomes an ``args`` attribute read by main().
    parser = argparse.ArgumentParser()
    # Input JSON file holding a list of samples with "id" and "conversations".
    parser.add_argument("--in-file", type=str, required=True)
    parser.add_argument("--out-file", type=str, default="sharegpt_split.json")
    # Optional slice bounds applied to the loaded list; when omitted they are
    # None, so the slice is open-ended on that side.
    parser.add_argument("--begin", type=int)
    parser.add_argument("--end", type=int)
    parser.add_argument("--model-name-or-path", type=str, required=True)
    # Token budget per output chunk; also passed to the tokenizer as model_max_length.
    parser.add_argument("--max-length", type=int, default=2048)
    args = parser.parse_args()
    main(args)