# creative-help / test_model.py
# Author: roemmele
# Enabled control of generation parameters; created README.md (commit c886682)
#!/usr/bin/env python3
import argparse
import json
import os
import random
import sys
def set_seed(seed: int):
    """Seed every available RNG source (random, numpy, torch) for reproducibility.

    numpy and torch are optional dependencies: each is seeded only when it
    is importable, and silently skipped otherwise.
    """
    random.seed(seed)
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        np.random.seed(seed)
    try:
        import torch
    except ImportError:
        pass
    else:
        torch.manual_seed(seed)
        # Seed every CUDA device as well when GPUs are present.
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
def main():
    """Load a converted RNNLM model and run a battery of generation tests.

    Reads prompts from a JSON file (a list of strings), registers the custom
    RNNLM model classes with transformers, builds the custom text-generation
    pipeline, and exercises several generation-parameter combinations
    (token/sentence limits, greedy vs. sampled decoding, temperature).

    Exits with status 1 when the model directory or prompts file is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", "-m", default=".", help="Path to converted model")
    # default=None so the script-directory fallback below is reachable
    # (previously default="test_prompts.json" made that branch dead code
    # and resolved relative to the CWD, contradicting the help text).
    parser.add_argument(
        "--prompts", "-p", default=None,
        help="Path to JSON file with list of prompt strings (default: test_prompts.json next to this script)")
    # Help text previously claimed "default: None, non-deterministic",
    # contradicting default=0; corrected to match the actual default.
    parser.add_argument(
        "--seed", "-s", type=int, default=0,
        help="Random seed for reproducible generation (default: 0)")
    parser.add_argument(
        "--max_new_tokens", type=int, default=None,
        help="Max tokens to generate (default: 50)")
    parser.add_argument(
        "--max_new_sents", type=int, default=None,
        help="Max sentences in decoded output (default: pipeline default)")
    args = parser.parse_args()
    if args.seed is not None:
        set_seed(args.seed)
        print(f"Random seed set to {args.seed} for reproducibility")
    if not os.path.isdir(args.model_path):
        print(f"Error: Model path {args.model_path} does not exist.")
        sys.exit(1)
    # Fall back to the prompts file shipped alongside this script.
    prompts_path = args.prompts
    if prompts_path is None:
        prompts_path = os.path.join(os.path.dirname(
            os.path.abspath(__file__)), "test_prompts.json")
    if not os.path.isfile(prompts_path):
        print(f"Error: Prompts file {prompts_path} does not exist.")
        sys.exit(1)
    print("Loading model and tokenizer...")
    from transformers import AutoModelForCausalLM
    # Register custom model and load tokenizer directly (AutoTokenizer doesn't know RNNLMTokenizer)
    model_path = os.path.abspath(args.model_path)
    from rnnlm_model import (
        RNNLMConfig,
        RNNLMForCausalLM,
        RNNLMTokenizer,
        RNNLMTextGenerationPipeline,
    )
    from transformers import AutoConfig
    AutoConfig.register("rnnlm", RNNLMConfig)
    AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True)
    tokenizer = RNNLMTokenizer.from_pretrained(model_path)
    print("Creating RNNLMTextGenerationPipeline (with entity adaptation)...")
    pipe = RNNLMTextGenerationPipeline(
        model=model,
        tokenizer=tokenizer,
    )
    with open(prompts_path) as f:
        test_prompts = json.load(f)
    # Baseline generation settings shared by every test; individual tests
    # override single keys via dict merging below.
    base_kwargs = dict(
        max_new_tokens=args.max_new_tokens if args.max_new_tokens is not None else 50,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
    )
    if args.max_new_sents is not None:
        base_kwargs["max_new_sents"] = args.max_new_sents

    def run_tests(kwargs):
        # Generate once per prompt with the given kwargs and echo the result.
        for i, prompt in enumerate(test_prompts):
            print(f"\n [{i + 1}/{len(test_prompts)}]")
            print(f" PROMPT: ``{prompt}``")
            output = pipe(prompt, **kwargs)
            print(f" GENERATED: ``{output[0]['generated_text']}``")

    # Test 1: Basic generation with default params
    print("\n--- Test 1: Basic generation (default params) ---")
    run_tests(base_kwargs)
    # Test 2: max_new_tokens=20
    print("\n--- Test 2: max_new_tokens=20 ---")
    short_kwargs = {**base_kwargs, "max_new_tokens": 20}
    run_tests(short_kwargs)
    # Test 3: max_new_sents=2
    print("\n--- Test 3: max_new_sents=2 ---")
    sents_kwargs = {**base_kwargs, "max_new_sents": 2}
    run_tests(sents_kwargs)
    # Test 4: max_new_sents=1
    print("\n--- Test 4: max_new_sents=1 ---")
    sents1_kwargs = {**base_kwargs, "max_new_sents": 1}
    run_tests(sents1_kwargs)
    # Test 5: do_sample=False (greedy decoding)
    print("\n--- Test 5: do_sample=False ---")
    greedy_kwargs = {**base_kwargs, "do_sample": False}
    run_tests(greedy_kwargs)
    # Test 6: temperature=0.3
    print("\n--- Test 6: temperature=0.3 ---")
    low_temp_kwargs = {**base_kwargs, "temperature": 0.3}
    run_tests(low_temp_kwargs)


if __name__ == "__main__":
    main()