#!/usr/bin/env python
"""
Example script that loads a fine-tuned OlmoE adapter model and
demonstrates text generation with the adapters enabled.
"""
import argparse
import torch
from transformers import AutoTokenizer
from modeling_olmoe import OlmoEWithAdaptersForCausalLM, OlmoConfig
def generate_text(
    model_path: str,
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.7,
    top_p: float = 0.9,
    device: str = "auto",
):
"""Generate text using a fine-tuned OlmoE adapter model."""
# Determine device
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load tokenizer and model
print(f"Loading model from {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load config and update with adapter settings if needed
config = OlmoConfig.from_pretrained(model_path)
# Load adapter model
model = OlmoEWithAdaptersForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model = model.to(device)
model.eval()
    # Tokenize input, keeping the attention mask so generate() can use it
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate text
    print("\nGenerating text...\n")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode the generated text
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Prompt: {prompt}")
    print("\nGenerated text:")
    print("=" * 40)
    print(generated_text)
    print("=" * 40)
    return generated_text
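
# Note: generate_text can also be imported and called directly; a minimal
# sketch (the checkpoint path here is a placeholder):
#   from generate import generate_text
#   text = generate_text("./olmoe-adapter-checkpoint", "The capital of France is")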
def main():
    parser = argparse.ArgumentParser(description="Generate text with an OlmoE adapter model")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the fine-tuned model")
    parser.add_argument("--prompt", type=str, default="This is an example of", help="Prompt for text generation")
    parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling parameter")
    parser.add_argument("--device", type=str, default="auto", help="Device to use (cuda, cpu, or auto)")
    args = parser.parse_args()

    generate_text(
        model_path=args.model_path,
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        device=args.device,
    )
if __name__ == "__main__":
    main()