# pathcosmos's picture
# Upload folder using huggingface_hub (#29)
# 5b1ff4d
"""
ppl_task.py — Sliding-window perplexity evaluation task.
Top-level functions for ProcessPoolExecutor (spawn) compatibility:
- eval_ppl_single(val_file, device, model=None) -> dict
- eval_ppl_multi(val_files, device) -> list[dict]
"""
from __future__ import annotations
import math
import sys
import time
from pathlib import Path
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
# Directory three levels above this file; all default paths hang off it.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent

# Put that directory on sys.path (once) so the function-level
# `from model.transformer import LLM` import can be resolved.
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

# Checkpoint and tokenizer locations; each can be overridden through an
# environment variable (EVAL_CHECKPOINT / EVAL_TOKENIZER).
_DEFAULT_CHECKPOINT = str(
    _PROJECT_ROOT.joinpath("checkpoints", "korean_3b_fp8_run1", "checkpoint-0057000")
)
CHECKPOINT = os.getenv("EVAL_CHECKPOINT", _DEFAULT_CHECKPOINT)
TOKENIZER_PATH = os.getenv(
    "EVAL_TOKENIZER",
    str(_PROJECT_ROOT.joinpath("tokenizer", "korean_sp", "tokenizer.json")),
)

# Validation files are looked up relative to this directory.
DATA_DIR = _PROJECT_ROOT.joinpath("data")

# Evaluation hyper-parameters.
SEQ_LEN = 2048   # window length handed to SlidingWindowDataset
STRIDE = 512     # step between consecutive window starts
BATCH_SIZE = 32  # windows per DataLoader batch
# ---------------------------------------------------------------------------
# Shared dataset / model utilities
# ---------------------------------------------------------------------------
class SlidingWindowDataset(Dataset):
    """Sliding-window view over a 1-D token array for perplexity evaluation.

    Window ``idx`` covers ``tokens[idx * stride : idx * stride + seq_len]``.
    The last window may run past the end of the data; it is zero-padded to
    ``seq_len`` and the padding is excluded from the loss mask.

    The loss mask is laid out so that, provided ``stride < seq_len``, every
    token after the very first one is scored exactly once: window 0 scores
    all of its targets, and each later window scores only the targets that
    no earlier window contained, using the overlapping prefix purely as
    context. With ``stride >= seq_len`` the windows do not overlap and the
    first token of each window is left unscored.

    Args:
        tokens: 1-D array of token ids.
        seq_len: Window length in tokens.
        stride: Step, in tokens, between consecutive window starts.

    Raises:
        ValueError: If ``seq_len`` or ``stride`` is not positive.
    """

    def __init__(self, tokens: np.ndarray, seq_len: int, stride: int) -> None:
        if seq_len < 1 or stride < 1:
            raise ValueError(
                f"seq_len and stride must be positive, got seq_len={seq_len}, stride={stride}"
            )
        self.tokens = tokens
        self.seq_len = seq_len
        self.stride = stride

        n_tokens = len(tokens)
        if n_tokens < 2:
            # A lone token has no next-token target, so there is nothing to score.
            self.n_windows = 0
        else:
            # One window at offset 0, plus however many strided windows are
            # needed for a window to reach the final token (ceil division).
            beyond_first = max(0, n_tokens - seq_len)
            self.n_windows = 1 + (beyond_first + stride - 1) // stride

    def __len__(self) -> int:
        return self.n_windows

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(input_ids, targets, loss_mask)``, each of length ``seq_len``.

        ``targets`` holds the next-token ids, with -100 wherever there is no
        target (the last real position and all padding). ``loss_mask`` is True
        only at positions whose target should count towards the perplexity.
        """
        start = idx * self.stride
        stop = min(start + self.seq_len, len(self.tokens))
        chunk_len = stop - start

        input_ids = torch.zeros(self.seq_len, dtype=torch.long)
        targets = torch.full((self.seq_len,), -100, dtype=torch.long)
        loss_mask = torch.zeros(self.seq_len, dtype=torch.bool)

        if chunk_len > 1:
            toks = torch.from_numpy(self.tokens[start:stop].astype(np.int64))
            input_ids[:chunk_len] = toks
            targets[:chunk_len - 1] = toks[1:]

            # The previous window ended at absolute token (start - stride + seq_len),
            # i.e. local index seq_len - stride here. The position that *predicts*
            # that first unseen token is one step earlier.
            if idx == 0:
                first_scored = 0
            else:
                first_scored = max(0, self.seq_len - self.stride - 1)
            loss_mask[first_scored:chunk_len - 1] = True

        return input_ids, targets, loss_mask
def _load_model(device: str):
    """Build the FRANKENSTALLM 3B model from ``CHECKPOINT`` for inference.

    Args:
        device: Device string the model is moved to.

    Returns:
        The model, converted to bfloat16, placed on ``device`` and switched
        to eval mode.
    """
    # Project-local import, resolved at call time.
    from model.transformer import LLM  # type: ignore[import]

    llm = LLM.from_pretrained(CHECKPOINT).to(device=device, dtype=torch.bfloat16)
    llm.eval()
    return llm
def _load_tokenizer():
    """Return the Korean SentencePiece tokenizer read from ``TOKENIZER_PATH``."""
    from tokenizers import Tokenizer  # type: ignore[import]

    tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
    return tokenizer
# ---------------------------------------------------------------------------
# Main task functions (must be top-level for pickle / spawn compatibility)
# ---------------------------------------------------------------------------
def _perplexity(total_nll: float, total_count: int) -> float:
    """Return exp(mean NLL), or inf when nothing was scored or exp overflows."""
    if total_count <= 0:
        return float("inf")
    try:
        return math.exp(total_nll / total_count)
    except OverflowError:
        return float("inf")


def eval_ppl_single(val_file: str, device: str, model=None) -> dict:
    """Compute sliding-window perplexity for a single validation file.

    Args:
        val_file: Relative path under DATA_DIR, e.g. "3b_val.bin". The file is
            read as a flat array of uint16 token ids.
        device: CUDA device string with an explicit index, e.g. "cuda:0".
        model: Optional pre-loaded model. If None, loads from checkpoint.

    Returns:
        Dict with keys: name, file, n_tokens, n_eval_tokens, ppl,
        bits_per_token, avg_nll, elapsed_sec, device. When no token could be
        scored, ppl / bits_per_token / avg_nll are ``inf`` rather than a
        misleading perfect score.

    Raises:
        FileNotFoundError: If the validation file does not exist.
        ValueError: If the validation file contains no tokens.
    """
    torch.cuda.set_device(int(device.split(":")[-1]))

    data_path = DATA_DIR / val_file
    if not data_path.exists():
        raise FileNotFoundError(f"Validation file not found: {data_path}")
    name = val_file.replace("_val.bin", "").replace(".bin", "")

    # Read and validate the data before loading the model, so an unusable
    # file fails fast instead of after loading the checkpoint.
    tokens = np.fromfile(str(data_path), dtype=np.uint16)
    n_tokens = len(tokens)
    if n_tokens == 0:
        raise ValueError(f"Validation file is empty (0 tokens): {data_path}")

    if model is None:
        print(f"[PPL {device}] Loading model for {name}...")
        model = _load_model(device)

    print(f"[PPL {device}] {name}: {n_tokens:,} tokens, {n_tokens * 2 / 1e6:.1f} MB")

    ds = SlidingWindowDataset(tokens, SEQ_LEN, STRIDE)
    dl = DataLoader(
        ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )
    n_batches = len(dl)

    total_nll = 0.0
    total_count = 0
    t0 = time.time()
    with torch.inference_mode():
        for batch_idx, (inp, tgt, mask) in enumerate(dl):
            # Batches come from pinned memory, so the copies can be asynchronous.
            inp = inp.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)

            logits, _ = model(inp)
            # Per-position loss; targets of -100 hit cross_entropy's default
            # ignore_index and contribute 0.
            loss_flat = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                tgt.view(-1),
                reduction="none",
            ).view(mask.shape)

            # Only positions selected by the dataset's loss mask are counted.
            total_nll += (loss_flat * mask.float()).sum().item()
            total_count += mask.sum().item()

            if (batch_idx + 1) % 50 == 0:
                running_ppl = _perplexity(total_nll, total_count)
                elapsed = time.time() - t0
                print(
                    f"[PPL {device}] {name}: batch {batch_idx + 1}/{n_batches}, "
                    f"running PPL={running_ppl:.4f}, {elapsed:.0f}s"
                )

    if total_count > 0:
        avg_nll = total_nll / total_count
    else:
        # Nothing was scored: report "undefined" as inf (matching the running
        # log) instead of avg_nll=0, which would read as a perfect PPL of 1.0.
        print(f"[PPL {device}] WARNING {name}: no tokens were scored; reporting PPL=inf")
        avg_nll = float("inf")
    ppl = _perplexity(total_nll, total_count)
    bpt = avg_nll / math.log(2)
    elapsed = time.time() - t0

    result: dict = {
        "name": name,
        "file": val_file,
        "n_tokens": int(n_tokens),
        "n_eval_tokens": int(total_count),
        "ppl": round(ppl, 4),
        "bits_per_token": round(bpt, 4),
        "avg_nll": round(avg_nll, 6),
        "elapsed_sec": round(elapsed, 1),
        "device": device,
    }
    print(
        f"[PPL {device}] DONE {name}: PPL={ppl:.4f}, BPT={bpt:.4f}, {elapsed:.1f}s"
    )
    return result
def eval_ppl_multi(val_files: list[str], device: str) -> list[dict]:
    """Evaluate several validation files on one GPU with a single model load.

    Args:
        val_files: Relative paths under DATA_DIR, one per validation file.
        device: CUDA device string, e.g. "cuda:0".

    Returns:
        One result dict per entry of ``val_files``, in input order, as
        produced by ``eval_ppl_single``.
    """
    gpu_index = int(device.split(":")[-1])
    torch.cuda.set_device(gpu_index)

    print(f"[PPL_MULTI {device}] Loading model once for {len(val_files)} files...")
    shared_model = _load_model(device)

    # Every file reuses the same in-memory model instead of reloading it.
    return [eval_ppl_single(path, device, model=shared_model) for path in val_files]