# Provenance: HuggingFace commit 3dea709 (verified), "Add text register
# FastText classifier with training scripts" — author: oneryalcin.
"""
Prepare balanced FastText training data from TurkuNLP/register_oscar dataset.
Downloads English shards, extracts labeled documents, and creates a balanced
training set by oversampling minority classes and undersampling majority classes
to the median class size.
Requirements:
pip install huggingface_hub
Usage:
# Download shards first:
for i in $(seq 0 9); do
hf download TurkuNLP/register_oscar \
$(printf "en/en_%05d.jsonl.gz" $i) \
--repo-type dataset --local-dir ./data
done
# Then run:
python prepare_data.py --data-dir ./data/en --output-dir ./prepared
"""
import json
import gzip
import re
import random
import glob
import argparse
from collections import Counter, defaultdict
from pathlib import Path
# Two-letter register codes from the TurkuNLP/register_oscar taxonomy mapped
# to human-readable names; used only for pretty-printing distributions.
REGISTER_LABELS = {
    "IN": "Informational",
    "NA": "Narrative",
    "OP": "Opinion",
    "IP": "Persuasion",
    "HI": "HowTo",
    "ID": "Discussion",
    "SP": "Spoken",
    "LY": "Lyrical",
}
def clean_text(text: str, max_words: int = 500) -> str:
    """Normalize whitespace in *text* and cap the result at *max_words* words."""
    collapsed = re.sub(r"\s+", " ", text).strip()
    return " ".join(collapsed.split()[:max_words])
def main():
    """Build balanced FastText train/test files from register_oscar shards.

    Reads every ``*.jsonl.gz`` shard in ``--data-dir``, keeps labeled
    documents that pass the minimum-length filter, holds out a per-class
    test slice, then balances the training pool to the median class size
    (oversampling small classes with replacement, undersampling large
    ones). Writes ``train.txt`` and ``test.txt`` in FastText format
    (``__label__X ... text``) to ``--output-dir``.

    Raises:
        FileNotFoundError: if no ``.jsonl.gz`` shards are found.
    """
    parser = argparse.ArgumentParser(description="Prepare balanced FastText training data")
    parser.add_argument("--data-dir", default="./data/en", help="Directory with .jsonl.gz shards")
    parser.add_argument("--output-dir", default="./prepared", help="Output directory for train/test files")
    parser.add_argument("--max-words", type=int, default=500, help="Max words per document")
    parser.add_argument("--min-text-len", type=int, default=50, help="Min character length to keep")
    parser.add_argument("--test-ratio", type=float, default=0.1, help="Fraction held out for test")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    random.seed(args.seed)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Collect all labeled docs grouped by primary (first) label.
    by_label = defaultdict(list)
    total = 0
    skipped_nolabel = 0
    skipped_short = 0
    shard_files = sorted(glob.glob(f"{args.data_dir}/*.jsonl.gz"))
    if not shard_files:
        raise FileNotFoundError(f"No .jsonl.gz files found in {args.data_dir}")
    print(f"Found {len(shard_files)} shard(s)")
    for shard_file in shard_files:
        print(f" Processing {Path(shard_file).name}...")
        # FIX: decode explicitly as UTF-8 — text-mode gzip.open otherwise
        # uses the locale encoding, which breaks on non-UTF-8 systems.
        with gzip.open(shard_file, "rt", encoding="utf-8") as f:
            for line in f:
                d = json.loads(line)
                labels = d.get("labels", [])
                text = d.get("text", "")
                if not labels:
                    skipped_nolabel += 1
                    continue
                if len(text) < args.min_text_len:
                    skipped_short += 1
                    continue
                cleaned = clean_text(text, args.max_words)
                if not cleaned:
                    continue
                # FastText multi-label format: all labels prefixed, then text.
                label_str = " ".join(f"__label__{l}" for l in labels)
                ft_line = f"{label_str} {cleaned}\n"
                primary = labels[0]
                by_label[primary].append(ft_line)
                total += 1
    print(f"\nTotal labeled docs: {total}")
    print(f"Skipped (no label): {skipped_nolabel}")
    print(f"Skipped (too short): {skipped_short}")

    # Raw distribution
    print("\nRaw distribution:")
    for label in sorted(by_label.keys()):
        name = REGISTER_LABELS.get(label, label)
        print(f" {label} ({name}): {len(by_label[label])}")

    # Balance target: median class size (upper median for even counts).
    sizes = {k: len(v) for k, v in by_label.items()}
    sorted_sizes = sorted(sizes.values())
    median_size = sorted_sizes[len(sorted_sizes) // 2]
    target = median_size
    print(f"\nBalancing target (median): {target}")

    train_lines = []
    test_lines = []
    for label, lines in by_label.items():
        random.shuffle(lines)
        # FIX: honor --test-ratio (previously hard-coded to 10%), keep the
        # floor of 50 held-out docs per class, but cap the slice so every
        # class retains at least one training doc — otherwise
        # random.choices() below would raise IndexError on an empty pool.
        n_test = max(int(len(lines) * args.test_ratio), 50)
        n_test = max(min(n_test, len(lines) - 1), 0)
        test_pool = lines[:n_test]
        train_pool = lines[n_test:]
        test_lines.extend(test_pool)
        n_train = len(train_pool)
        if n_train >= target:
            # Undersample majority classes without replacement.
            sampled = random.sample(train_pool, target)
            train_lines.extend(sampled)
            print(f" {label}: {n_train} -> {target} (undersampled)")
        else:
            # Oversample minority classes with replacement up to target.
            train_lines.extend(train_pool)
            n_needed = target - n_train
            oversampled = random.choices(train_pool, k=n_needed)
            train_lines.extend(oversampled)
            print(f" {label}: {n_train} -> {target} (oversampled +{n_needed})")
    random.shuffle(train_lines)
    random.shuffle(test_lines)

    train_path = output_dir / "train.txt"
    test_path = output_dir / "test.txt"
    # FIX: write UTF-8 explicitly (default encoding is locale-dependent).
    with open(train_path, "w", encoding="utf-8") as f:
        f.writelines(train_lines)
    with open(test_path, "w", encoding="utf-8") as f:
        f.writelines(test_lines)
    print(f"\nTrain: {len(train_lines)} -> {train_path}")
    print(f"Test: {len(test_lines)} -> {test_path}")

    # Verify balance by re-counting every label tag in the train lines.
    c = Counter()
    for line in train_lines:
        for tok in line.split():
            if tok.startswith("__label__"):
                c[tok] += 1
    print("\nFinal train label distribution:")
    for l, cnt in c.most_common():
        name = REGISTER_LABELS.get(l.replace("__label__", ""), l)
        print(f" {l} ({name}): {cnt}")
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()