#!/usr/bin/env python
"""
Example usage script to evaluate a fine-tuned OlmoE adapter model
and demonstrate generation with adapters.
"""

import argparse

import torch
from transformers import AutoTokenizer

from modeling_olmoe import OlmoEWithAdaptersForCausalLM, OlmoConfig


def generate_text(
    model_path: str,
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.7,
    top_p: float = 0.9,
    device: str = "auto",
):
    """Generate text using a fine-tuned OlmoE adapter model."""
    # Resolve the target device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load tokenizer
    print(f"Loading model from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Load the config so adapter settings can be inspected or overridden if
    # needed; from_pretrained below also reads it internally, so this is
    # currently informational only.
    config = OlmoConfig.from_pretrained(model_path)

    # Load the adapter model: half precision on GPU, full precision on CPU
    model = OlmoEWithAdaptersForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    model = model.to(device)
    model.eval()

    # Tokenize the prompt; keep the attention mask so generate() can
    # distinguish real tokens from padding
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate with nucleus sampling
    print("\nGenerating text...\n")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad-token warning
        )

    # Decode the generated sequence (prompt + continuation)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(f"Prompt: {prompt}")
    print("\nGenerated text:")
    print("=" * 40)
    print(generated_text)
    print("=" * 40)

    return generated_text


def main():
    parser = argparse.ArgumentParser(description="Generate text with an OlmoE adapter model")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the fine-tuned model")
    parser.add_argument("--prompt", type=str, default="This is an example of", help="Prompt for text generation")
    parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p (nucleus) sampling parameter")
    parser.add_argument("--device", type=str, default="auto", help="Device to use (cuda, cpu, or auto)")

    args = parser.parse_args()

    generate_text(
        model_path=args.model_path,
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        device=args.device,
    )


if __name__ == "__main__":
    main()
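
# A minimal sketch of how this script might be invoked. The checkpoint
# directory, file name, and prompt below are illustrative placeholders,
# not values shipped with this repo:
#
#   python example_usage.py \
#       --model_path ./olmoe-adapter-checkpoint \
#       --prompt "Mixture-of-experts models are" \
#       --max_new_tokens 64 \
#       --device auto
#
# generate_text() can also be imported and called directly (assuming the
# script is saved as example_usage.py):
#
#   from example_usage import generate_text
#   text = generate_text("./olmoe-adapter-checkpoint", "Hello", max_new_tokens=32)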