| """ |
| VicAI Text Generation |
| Interactive text generation and sampling utilities. |
| """ |
|
|
import argparse
import time

import torch

from model import create_vicai_5b
from tokenizer import ByteLevelBPETokenizer
|
|
|
|
| def generate_interactive( |
| model, |
| tokenizer, |
| device, |
| max_new_tokens: int = 256, |
| temperature: float = 0.8, |
| top_k: int = 50, |
| top_p: float = 0.9, |
| repetition_penalty: float = 1.1, |
| ): |
| """Interactive text generation loop.""" |
| print("\n" + "=" * 60) |
| print("VicAI Interactive Generation") |
| print("=" * 60) |
| print("Commands:") |
| print(" /quit - Exit the program") |
| print(" /config - Show current generation settings") |
| print(" /temp X - Set temperature (0.1 - 2.0)") |
| print(" /topk X - Set top-k (1 - 100)") |
| print(" /topp X - Set top-p (0.0 - 1.0)") |
| print(" /reppen X - Set repetition penalty (1.0 - 2.0)") |
| print(" /maxlen X - Set max new tokens") |
| print("=" * 60 + "\n") |
| |
| |
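    # Mutable copy of the sampling settings so slash commands can adjust them mid-session.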
| settings = { |
| 'temperature': temperature, |
| 'top_k': top_k, |
| 'top_p': top_p, |
| 'repetition_penalty': repetition_penalty, |
| 'max_new_tokens': max_new_tokens, |
| } |
| |
| while True: |
| try: |
| |
| prompt = input("\nPrompt: ").strip() |
| |
| |
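            # Handle slash commands before treating input as a generation prompt.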
| if prompt == '/quit': |
| print("Goodbye!") |
| break |
| |
| if prompt == '/config': |
| print("\nCurrent settings:") |
| for key, value in settings.items(): |
| print(f" {key}: {value}") |
| continue |
| |
| if prompt.startswith('/temp '): |
| try: |
| settings['temperature'] = float(prompt.split()[1]) |
| print(f"Temperature set to {settings['temperature']}") |
| except (ValueError, IndexError): |
| print("Invalid temperature value") |
| continue |
| |
| if prompt.startswith('/topk '): |
| try: |
| settings['top_k'] = int(prompt.split()[1]) |
| print(f"Top-k set to {settings['top_k']}") |
| except (ValueError, IndexError): |
| print("Invalid top-k value") |
| continue |
| |
| if prompt.startswith('/topp '): |
| try: |
| settings['top_p'] = float(prompt.split()[1]) |
| print(f"Top-p set to {settings['top_p']}") |
| except (ValueError, IndexError): |
| print("Invalid top-p value") |
| continue |
| |
| if prompt.startswith('/reppen '): |
| try: |
| settings['repetition_penalty'] = float(prompt.split()[1]) |
| print(f"Repetition penalty set to {settings['repetition_penalty']}") |
| except (ValueError, IndexError): |
| print("Invalid repetition penalty value") |
| continue |
| |
| if prompt.startswith('/maxlen '): |
| try: |
| settings['max_new_tokens'] = int(prompt.split()[1]) |
| print(f"Max new tokens set to {settings['max_new_tokens']}") |
| except (ValueError, IndexError): |
| print("Invalid max new tokens value") |
| continue |
| |
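            # Catch typo'd slash commands rather than sending them to the model.
            if prompt.startswith('/'):
                print(f"Unknown command: {prompt.split()[0]}")
                continue
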
| if not prompt: |
| continue |
| |
| |
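            # Encode the prompt as a (1, seq_len) batch of token ids on the target device.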
| input_ids = torch.tensor([tokenizer.encode(prompt)], device=device) |
| |
| |
| print("\nGenerating...") |
| with torch.no_grad(): |
| output_ids = model.generate( |
| input_ids, |
| max_new_tokens=settings['max_new_tokens'], |
| temperature=settings['temperature'], |
| top_k=settings['top_k'], |
| top_p=settings['top_p'], |
| repetition_penalty=settings['repetition_penalty'], |
| eos_token_id=tokenizer.eos_token_id, |
| ) |
| |
| |
| generated_text = tokenizer.decode(output_ids[0].tolist()) |
| |
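            # Strip the echoed prompt so only the continuation is printed; the prompt is
            # re-decoded from its token ids so both strings went through the same round trip.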
| prompt_text = tokenizer.decode(input_ids[0].tolist()) |
| if generated_text.startswith(prompt_text): |
| generated_text = generated_text[len(prompt_text):].strip() |
| |
| print("\n" + "-" * 60) |
| print("Generated:") |
| print("-" * 60) |
| print(generated_text) |
| print("-" * 60) |
| |
| |
| num_tokens = output_ids.shape[1] - input_ids.shape[1] |
| print(f"\nTokens generated: {num_tokens}") |
| |
| except KeyboardInterrupt: |
| print("\n\nInterrupted by user. Type /quit to exit.") |
| except Exception as e: |
| print(f"\nError: {e}") |
|
|
|
|
| def generate_batch( |
| model, |
| tokenizer, |
| prompts: list, |
| device, |
| max_new_tokens: int = 256, |
| temperature: float = 0.8, |
| top_k: int = 50, |
| top_p: float = 0.9, |
| ): |
| """Generate completions for multiple prompts.""" |
| results = [] |
| |
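    # Prompts are processed sequentially; no padding or true batching is done here.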
| for prompt in prompts: |
| input_ids = torch.tensor([tokenizer.encode(prompt)], device=device) |
| |
| with torch.no_grad(): |
| output_ids = model.generate( |
| input_ids, |
| max_new_tokens=max_new_tokens, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| eos_token_id=tokenizer.eos_token_id, |
| ) |
| |
| generated_text = tokenizer.decode(output_ids[0].tolist()) |
| prompt_text = tokenizer.decode(input_ids[0].tolist()) |
| |
| if generated_text.startswith(prompt_text): |
| generated_text = generated_text[len(prompt_text):].strip() |
| |
| results.append({ |
| 'prompt': prompt, |
| 'completion': generated_text, |
| }) |
| |
| return results |
|
|
|
|
| def benchmark_generation( |
| model, |
| tokenizer, |
| device, |
| num_runs: int = 10, |
| max_new_tokens: int = 128, |
| prompt: str = "The future of artificial intelligence is", |
| ): |
| """Benchmark generation speed.""" |
| import time |
| |
| print(f"\nBenchmarking generation ({num_runs} runs)...") |
| |
| input_ids = torch.tensor([tokenizer.encode(prompt)], device=device) |
| |
| |
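    # Warm-up run so one-time setup costs (e.g. CUDA kernel caching) don't skew timings.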
| with torch.no_grad(): |
| _ = model.generate(input_ids, max_new_tokens=10) |
| |
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    times = []
    tokens_generated = []
| |
| for i in range(num_runs): |
        start = time.perf_counter()
| |
| with torch.no_grad(): |
| output = model.generate( |
| input_ids, |
| max_new_tokens=max_new_tokens, |
| temperature=1.0, |
| ) |
| |
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
| |
| num_tokens = output.shape[1] - input_ids.shape[1] |
| times.append(elapsed) |
| tokens_generated.append(num_tokens) |
| |
| print(f" Run {i+1}: {num_tokens} tokens in {elapsed:.2f}s ({num_tokens/elapsed:.1f} tok/s)") |
| |
| avg_time = sum(times) / len(times) |
| avg_tokens = sum(tokens_generated) / len(tokens_generated) |
| avg_speed = avg_tokens / avg_time |
| |
| print(f"\nAverage: {avg_tokens:.1f} tokens in {avg_time:.2f}s ({avg_speed:.1f} tok/s)") |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description='Generate text with VicAI') |
| |
| parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint') |
| parser.add_argument('--tokenizer', type=str, default='tokenizer.pkl', help='Path to tokenizer') |
| parser.add_argument('--prompt', type=str, default=None, help='Single prompt to generate from') |
| parser.add_argument('--interactive', action='store_true', help='Interactive mode') |
| parser.add_argument('--max-new-tokens', type=int, default=256, help='Maximum tokens to generate') |
| parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature') |
| parser.add_argument('--top-k', type=int, default=50, help='Top-k sampling') |
| parser.add_argument('--top-p', type=float, default=0.9, help='Top-p (nucleus) sampling') |
| parser.add_argument('--repetition-penalty', type=float, default=1.1, help='Repetition penalty') |
| parser.add_argument('--benchmark', action='store_true', help='Run generation benchmark') |
| parser.add_argument('--device', type=str, default='cuda', help='Device to use') |
| |
| args = parser.parse_args() |
| |
| |
    if args.device.startswith('cuda') and not torch.cuda.is_available():
        print("CUDA requested but not available; falling back to CPU")
        args.device = 'cpu'
    device = torch.device(args.device)
    print(f"Using device: {device}")
| |
| |
| print(f"Loading tokenizer from {args.tokenizer}...") |
| |
| tokenizer = ByteLevelBPETokenizer() |
| tokenizer.load(args.tokenizer) |
| print(f"Tokenizer loaded: {len(tokenizer)} tokens") |
| |
| |
| print(f"Loading model from {args.checkpoint}...") |
| checkpoint = torch.load(args.checkpoint, map_location=device) |
| |
| |
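    # Rebuild the architecture; the vocab size must match the tokenizer used in training.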
| model = create_vicai_5b(vocab_size=len(tokenizer)) |
| |
| |
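    # Accept either a full training checkpoint ({'model': state_dict, ...}) or a bare state dict.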
| state_dict = checkpoint.get('model', checkpoint) |
| model.load_state_dict(state_dict) |
| model = model.to(device) |
| model.eval() |
| |
| print(f"Model loaded: ~{model.get_num_params() / 1e9:.2f}B parameters") |
| |
| |
| if args.benchmark: |
| benchmark_generation(model, tokenizer, device) |
| return |
| |
| |
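    # Fall back to interactive mode when no prompt is supplied.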
| if args.interactive or args.prompt is None: |
| generate_interactive( |
| model, |
| tokenizer, |
| device, |
| max_new_tokens=args.max_new_tokens, |
| temperature=args.temperature, |
| top_k=args.top_k, |
| top_p=args.top_p, |
| repetition_penalty=args.repetition_penalty, |
| ) |
| else: |
| |
| print(f"\nPrompt: {args.prompt}") |
| print("-" * 60) |
| |
| input_ids = torch.tensor([tokenizer.encode(args.prompt)], device=device) |
| |
| with torch.no_grad(): |
| output_ids = model.generate( |
| input_ids, |
| max_new_tokens=args.max_new_tokens, |
| temperature=args.temperature, |
| top_k=args.top_k, |
| top_p=args.top_p, |
| repetition_penalty=args.repetition_penalty, |
| eos_token_id=tokenizer.eos_token_id, |
| ) |
| |
| generated_text = tokenizer.decode(output_ids[0].tolist()) |
| prompt_text = tokenizer.decode(input_ids[0].tolist()) |
| |
| if generated_text.startswith(prompt_text): |
| generated_text = generated_text[len(prompt_text):].strip() |
| |
| print(generated_text) |
| print("-" * 60) |
| |
| num_tokens = output_ids.shape[1] - input_ids.shape[1] |
| print(f"\nGenerated {num_tokens} tokens") |
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|