| | """Quick test of model quality with diverse prompts.""" |
| |
|
| | import os, sys, time, torch |
| | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| | from model.config import ModelConfig |
| | from model.transformer import Transformer |
| | from model.data import get_tokenizer |
| |
|
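# Prefer the DPO checkpoint when it exists; otherwise fall back to SFT.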
DPO_CKPT = "/jfs/deepak-kumar/checkpoints_dpo/dpo_final.pt"
SFT_CKPT = "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt"
CHECKPOINT = DPO_CKPT if os.path.exists(DPO_CKPT) else SFT_CKPT
DEVICE = "cuda:0"

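# Chat-template markers used during fine-tuning.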
USER_START = "<|user|>\n"
ASST_START = "<|assistant|>\n"
TURN_END = "\n<|end|>\n"

TEST_PROMPTS = [
    "Hi! How are you?",
    "What is photosynthesis?",
    "Explain gravity to a 5-year-old.",
    "Write a short poem about the ocean.",
    "What are the three states of matter?",
    "How does a computer work?",
    "What is the capital of France and why is it famous?",
    "Give me 3 tips for learning a new language.",
    "What is machine learning in simple terms?",
]


@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=256,
             temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.15):
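    """Sample a completion one token at a time with top-k/top-p filtering
    and a repetition penalty.

    No KV cache: the full sequence is re-run through the model on every
    step, which is simple but quadratic in the generated length.
    """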
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    input_ids = torch.tensor([input_ids], dtype=torch.long, device=DEVICE)
    generated = []
    eos_id = tokenizer.eos_token_id

    # "<|end|>" and "<|user|>" should each encode to a single id once they
    # are registered as special tokens; guard against multi-token encodings
    # so we never stop on a common subtoken by mistake.
    end_token_ids = tokenizer.encode("<|end|>", add_special_tokens=False)
    end_id = end_token_ids[0] if len(end_token_ids) == 1 else None
    user_token_ids = tokenizer.encode("<|user|>", add_special_tokens=False)
    user_id = user_token_ids[0] if len(user_token_ids) == 1 else None

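    # Stop generation at EOS, the end-of-turn marker, or a new user turn.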
    stop_ids = set()
    if eos_id is not None:
        stop_ids.add(eos_id)
    if end_id is not None:
        stop_ids.add(end_id)
    if user_id is not None:
        stop_ids.add(user_id)

    for _ in range(max_new_tokens):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, _ = model(input_ids)

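        # Only the last position's logits matter for the next token; cast to
        # float32 so the filtering and softmax below are numerically stable.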
        logits = logits[:, -1, :].float()

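        # Repetition penalty: dampen tokens we have already emitted by
        # dividing positive logits and multiplying negative ones.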
        if repetition_penalty != 1.0 and generated:
            for tid in set(generated):
                if logits[0, tid] > 0:
                    logits[0, tid] /= repetition_penalty
                else:
                    logits[0, tid] *= repetition_penalty

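        # Temperature scaling; the floor avoids division by zero.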
        logits = logits / max(temperature, 1e-5)

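        # Top-k filtering: mask everything below the k-th largest logit.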
        if top_k > 0:
            topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < topk_vals[:, -1:]] = float("-inf")

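        # Top-p (nucleus) filtering: keep the smallest set of tokens whose
        # cumulative probability exceeds top_p. Masking on the cumulative
        # mass *before* each token keeps the first token past the threshold.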
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            cumulative = torch.cumsum(sorted_probs, dim=-1)
            remove = cumulative - sorted_probs > top_p
            sorted_logits[remove] = float("-inf")
            # Scatter the filtered logits back to their original positions.
            logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)

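        # Sample the next token from the filtered, renormalized distribution.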
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, 1)
        token_id = next_token.item()

        if token_id in stop_ids:
            break

        generated.append(token_id)
        input_ids = torch.cat([input_ids, next_token], dim=1)

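        # Safety valve: stop before the sequence outgrows the context window.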
        if input_ids.size(1) > 2048:
            break

    return tokenizer.decode(generated, skip_special_tokens=True)


def main():
    ckpt_name = "DPO" if "dpo" in CHECKPOINT else "SFT"
    print("=" * 70)
    print(f" {ckpt_name} MODEL TEST")
    print("=" * 70)

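    # Ensure the chat-template tokens exist so the vocab size matches training.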
    tokenizer = get_tokenizer()
    special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
    vocab = tokenizer.get_vocab()
    new_tokens = [t for t in special_tokens if t not in vocab]
    if new_tokens:
        tokenizer.add_tokens(new_tokens, special_tokens=True)

    config = ModelConfig()
    config.vocab_size = len(tokenizer)
    model = Transformer(config)

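    # Load weights on CPU first, then move the model to the GPU in bf16.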
| | print("") |
| | print("Loading checkpoint: " + CHECKPOINT) |
| | ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only=False) |
| | model.load_state_dict(ckpt["model"]) |
| | step = ckpt.get("step", "?") |
| | del ckpt |
| |
|
    model = model.to(DEVICE).bfloat16().eval()
    print(f"Model loaded ({ckpt_name} step {step}, vocab {config.vocab_size})")
    mem = torch.cuda.max_memory_allocated(DEVICE) / 1e9
    print(f"GPU memory: {mem:.1f} GB")
    print("-" * 70)

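    # Wrap each question in the chat template used during fine-tuning.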
    for i, question in enumerate(TEST_PROMPTS, 1):
        prompt = USER_START + question + TURN_END + ASST_START

| | print("") |
| | print("[Test " + str(i) + "/" + str(len(TEST_PROMPTS)) + "]") |
| | print(" Q: " + question) |
| |
|
        t0 = time.time()
        response = generate(model, tokenizer, prompt)
        dt = time.time() - t0
        # Token count is approximated by re-encoding the decoded response.
        tokens = len(tokenizer.encode(response, add_special_tokens=False))

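        # decode() already skips special tokens; stripping again is defensive.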
        response = response.split("<|end|>")[0].split("<|user|>")[0].strip()

| | print(" A: " + response) |
| | tps = int(tokens / max(dt, 0.01)) |
| | print(" [" + str(tokens) + " tokens, " + str(round(dt, 1)) + "s, " + str(tps) + " tok/s]") |
| | print("-" * 70) |
| |
|
| | print("") |
| | print("Done!") |
| |
|
| |
|
if __name__ == "__main__":
    main()