"""
Split long conversations based on a max token length.

Usage: python3 -m fastchat.data.split_long_conversation \
    --in-file sharegpt_clean.json \
    --out-file sharegpt_split.json \
    --model-name-or-path <model-name>
"""
|
import argparse |
|
from concurrent.futures import ProcessPoolExecutor |
|
import json |
|
from typing import Dict, Sequence, Optional |
|
|
|
import transformers |
|
from tqdm import tqdm |
|
|
|
|
|
def make_sample(sample, start_idx, end_idx):
    # Each chunk must consist of whole (human, gpt) pairs, so the slice
    # length has to be even.
    assert (end_idx - start_idx) % 2 == 0
    return {
        "id": sample["id"] + "_" + str(start_idx),
        "model": sample.get("model", ""),
        "conversations": sample["conversations"][start_idx:end_idx],
    }

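# For illustration (values hypothetical): calling make_sample on a record with
# id "abc123" and start_idx=0, end_idx=2 yields a record with id "abc123_0"
# that carries just the first human/gpt exchange.
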
# Populated by split_all() before the worker pool is created.
tokenizer = max_length = None

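# Workers see these globals through inherited module state, which assumes the
# "fork" start method (the Linux default); under "spawn" each worker would
# re-import this module and the globals would still be None.
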
def split_one_sample(sample):
    """Split one conversation into chunks whose token count fits max_length."""
    tokenized_lens = []
    conversations = sample["conversations"]
    # Truncate to an even number of turns so chunks are whole (human, gpt) pairs.
    conversations = conversations[: len(conversations) // 2 * 2]
    for c in conversations:
        # The +6 is presumably headroom per turn for role/separator tokens.
        length = len(tokenizer(c["value"]).input_ids) + 6
        tokenized_lens.append(length)

    if len(conversations) % 2 != 0 or len(conversations) < 2:
        return []

    start_idx = 0
    cur_len = 0
    new_samples = []
    for i in range(0, len(conversations), 2):
        tmp_len = tokenized_lens[i] + tokenized_lens[i + 1]
        if cur_len + tmp_len > max_length:
            # Adding this pair would overflow: close the current chunk here.
            new_samples.append(make_sample(sample, start_idx, i))
            start_idx = i
            cur_len = 0
        if i == len(conversations) - 2:
            # A plain `if` (not `elif`) so a final pair that overflows is still
            # emitted as its own chunk rather than silently dropped.
            new_samples.append(make_sample(sample, start_idx, i + 2))

        cur_len += tmp_len

    return new_samples

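# Worked example (hypothetical lengths): with max_length = 2048 and per-pair
# token counts [900, 900, 900], the loop packs pairs greedily:
#   i=0: 0 + 900    <= 2048 -> keep accumulating
#   i=2: 900 + 900  <= 2048 -> keep accumulating
#   i=4: 1800 + 900 > 2048  -> emit turns [0, 4), start a new chunk, then emit
#        the final pair [4, 6) since i is the last pair index.
# Result: two samples, "<id>_0" and "<id>_4".
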
def worker(input_data):
    # Split every sample in one chunk; runs inside a worker process.
    result = []
    for sample in input_data:
        result.extend(split_one_sample(sample))
    return result

def split_all(content, begin, end, tokenizer_, max_length_):
    """
    Keep as many rounds of conversation as possible within the max token
    length constraint.
    """
    global tokenizer, max_length
    tokenizer = tokenizer_
    max_length = max_length_

    content = content[begin:end]
    new_content = []

    # Split the work into chunks of 1000 samples and fan them out to a pool of
    # worker processes; tqdm tracks per-chunk progress.
    chunks = [content[i : i + 1000] for i in range(0, len(content), 1000)]
    with ProcessPoolExecutor() as executor:
        for result in tqdm(executor.map(worker, chunks), total=len(chunks)):
            new_content.extend(result)

    return new_content

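# Note: content[begin:end] accepts None for either bound (the argparse
# defaults), which slices the whole list. The chunk size of 1000 is a
# throughput knob: larger chunks mean fewer inter-process handoffs, smaller
# ones smooth out the progress bar and load balancing.
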
def filter_invalid_roles(content):
    """Drop empty samples and samples whose turns do not alternate human/gpt."""
    roles = ["human", "gpt"]
    new_content = []
    for c in content:
        if len(c["conversations"]) <= 0:
            continue

        valid = True
        for j, s in enumerate(c["conversations"]):
            if s["from"] != roles[j % 2]:
                valid = False
                break

        if valid:
            new_content.append(c)

    return new_content

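# For example, a sample whose conversations start with a "gpt" turn, or that
# repeats two "human" turns in a row, is filtered out; this also discards the
# empty chunks that split_one_sample can emit when a single pair exceeds
# max_length on its own.
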
def main(args):
    with open(args.in_file, "r") as fin:
        content = json.load(fin)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.max_length,
        padding_side="right",
        use_fast=False,
    )
    new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length)
    new_content = filter_invalid_roles(new_content)

    print(f"#in: {len(content)}, #out: {len(new_content)}")
    with open(args.out_file, "w") as fout:
        json.dump(new_content, fout, indent=2, ensure_ascii=False)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-file", type=str, required=True)
    parser.add_argument("--out-file", type=str, default="sharegpt_split.json")
    parser.add_argument("--begin", type=int)
    parser.add_argument("--end", type=int)
    parser.add_argument("--model-name-or-path", type=str, required=True)
    parser.add_argument("--max-length", type=int, default=2048)
    args = parser.parse_args()
    main(args)
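# Programmatic use (a sketch; the model name below is a placeholder, and the
# input must follow the ShareGPT-style schema noted at the top of this file):
#
#   import json
#   import transformers
#   from fastchat.data.split_long_conversation import split_all, filter_invalid_roles
#
#   tok = transformers.AutoTokenizer.from_pretrained("<model-name>", use_fast=False)
#   content = json.load(open("sharegpt_clean.json"))
#   out = filter_invalid_roles(split_all(content, None, None, tok, 2048))
#   json.dump(out, open("sharegpt_split.json", "w"), indent=2, ensure_ascii=False)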