# V2L-Alpha1-Example1 / formatter.py
# openagi-agi's picture
# Upload 8 files
# 7cd7caf verified
from datasets import load_dataset
from transformers import T5Tokenizer
import pandas as pd, csv, re
from tqdm import tqdm
# -- Config ------------------------------------------------------------------
# Tunables for the formatting run; read by the module-level code below.
jsonl_path = "lmsys_chat_1m_full.jsonl"  # path to the local JSONL input file
use_subset = False  # True => keep only the first num_samples rows; False => full file (about 1 M rows per the original note)
num_samples = 500  # row cap applied when use_subset is True
max_turn_pairs = 1  # trailing user+assistant pairs used per example (1 pair = 2 turns)
max_input_tokens = 512  # token budget for the prompt; chosen to fit t5-small/base
# ----------------------------------------------------------------------------
# Tokenizer whose token count decides whether a prompt fits the budget.
tok = T5Tokenizer.from_pretrained("t5-small")

# Parse the local JSONL file; the "json" loader exposes it as the "train" split.
ds = load_dataset("json", split="train", data_files=jsonl_path)

if use_subset:
    # Never ask for more rows than the dataset actually holds.
    ds = ds.select(range(min(len(ds), num_samples)))
    print(f"πŸ” subset β†’ {len(ds)} rows")
def mostly_ascii(s: str, threshold: float = .3) -> bool:
    """Tell whether the share of non-ASCII characters in *s* is below *threshold*.

    A character counts as non-ASCII when its code point is above 127.
    An empty *s* has nothing to measure and is reported as ``False``.
    """
    total = len(s)
    if total == 0:
        return False

    non_ascii = 0
    for ch in s:
        if ord(ch) > 127:
            non_ascii += 1
    return non_ascii / total < threshold
def format_turns(conv):
    """Render each message of *conv* as a ``"Role: content"`` string.

    Every message is indexed by ``"role"`` and ``"content"``; the role is
    capitalized and the content is stripped of surrounding whitespace.
    Returns the rendered strings in the same order as *conv*.
    """
    rendered = []
    for message in conv:
        speaker = message["role"].capitalize()
        text = message["content"].strip()
        rendered.append(speaker + ": " + text)
    return rendered
def build_pair(turns, max_tokens=512):
    """Build one ``(prompt, target)`` training pair from formatted turns.

    The last ``max_turn_pairs * 2`` entries of *turns* form the window.  The
    final entry becomes the target (a leading ``"Assistant: "`` label is
    removed); the entries before it are joined with blank lines under a
    ``"chat:"`` header to form the prompt.  While the prompt tokenizes to
    more than *max_tokens* tokens, the oldest context turn is dropped, at
    most ``max_turn_pairs`` times, and the fit is re-checked after every drop.

    Args:
        turns: List of ``"Role: content"`` strings, oldest first.
        max_tokens: Maximum prompt length, counted with the module-level
            tokenizer ``tok``.

    Returns:
        A ``(prompt, target)`` tuple, or ``None`` when the window cannot be
        filled, the prompt cannot be trimmed to fit, the prompt is shorter
        than 30 characters, the target is shorter than 10 characters, or the
        combined text does not pass ``mostly_ascii``.
    """
    header = "chat:\n\n"
    label = "Assistant: "

    window = max_turn_pairs * 2
    # A window below 2 cannot hold both a context turn and a target.
    if window < 2 or len(turns) < window:
        return None

    use_turns = turns[-window:]
    context = use_turns[:-1]
    target = use_turns[-1]
    # Strip the role label only when it actually leads the turn, so an
    # "Assistant: " substring inside the message body is left untouched.
    if target.startswith(label):
        target = target[len(label):]

    # Trim whole turns from the front.  Working on the turn list (rather than
    # searching the joined string for "\n\n") keeps a message that itself
    # contains a blank line from being cut in half.
    trims_left = max_turn_pairs
    while True:
        prompt = header + "\n\n".join(context)
        if len(tok.tokenize(prompt)) <= max_tokens:
            break  # fits
        if trims_left == 0 or len(context) <= 1:
            return None  # budget spent or nothing left to drop
        context = context[1:]
        trims_left -= 1

    # Discard degenerate pairs.
    if len(prompt) < 30 or len(target) < 10:
        return None
    if not mostly_ascii(prompt + target):
        return None
    return prompt, target
# Convert every conversation into a {"source", "target"} row, skipping
# records that have no conversation list or that build_pair rejects.
rows = []
for ex in tqdm(ds, desc="formatting"):
    conv = ex.get("conversation")
    if not isinstance(conv, list):
        continue
    # Pass the configured budget instead of relying on build_pair's default.
    p = build_pair(format_turns(conv), max_tokens=max_input_tokens)
    if p:
        rows.append({"source": p[0], "target": p[1]})
kept = len(rows)
print(f"βœ… kept {kept} examples")

out_path = "chat_1turn.csv"
# Naming the columns keeps the header row even when no example was kept.
pd.DataFrame(rows, columns=["source", "target"]).to_csv(
    out_path,
    index=False,
    quoting=csv.QUOTE_ALL,  # quote every field so embedded newlines survive
    encoding="utf-8",
)
# Report the file that was actually written.
print(f"πŸ’Ύ saved β†’ {out_path}")