"""Zenith-7B Inference Script for Standard GPUs"""

import argparse
import sys
from pathlib import Path

import torch

# Make the repo's sibling packages (configs/, models/, data/) importable
# when this script is run directly.
sys.path.append(str(Path(__file__).parent))

from configs.zenith_config import get_7b_config
from models.zenith_model import ZenithForCausalLM
from data.advanced_tokenizer import AdvancedTokenizer


def load_model(checkpoint_path: str, device: str = "cuda"):
    """Load the trained model and tokenizer from a checkpoint directory."""
    config = get_7b_config()

    # The tokenizer may have grown (e.g. added special tokens) since the
    # config was written, so sync the vocab size before loading weights.
    tokenizer = AdvancedTokenizer.from_pretrained(checkpoint_path)
    config.vocab_size = tokenizer.get_vocab_size()

    # device_map="auto" spreads the model across available GPUs; on CPU
    # the weights load onto the default device.
    model = ZenithForCausalLM.from_pretrained(
        checkpoint_path,
        config=config,
        device_map="auto" if device == "cuda" else None,
    )
    model.eval()

    return model, tokenizer


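# Illustrative usage (the checkpoint path below is a placeholder):
#   model, tokenizer = load_model("checkpoints/zenith-7b", device="cuda")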
def generate(
    model: ZenithForCausalLM,
    tokenizer: AdvancedTokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    do_sample: bool = True,
    stream: bool = False,
):
    """Generate text from the model.

    Returns a TextIteratorStreamer when stream=True, otherwise the
    decoded output string.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    if stream:
        # Lazy imports: only needed for the streaming path.
        from threading import Thread

        from transformers import TextIteratorStreamer

        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        )
        # Run generation in a background thread so the caller can iterate
        # over the streamer as tokens arrive. torch.no_grad() is applied to
        # the target as a decorator because a context manager entered here
        # would not propagate into the spawned thread.
        thread = Thread(
            target=torch.no_grad()(model.generate),
            kwargs=generation_kwargs,
            daemon=True,
        )
        thread.start()
        return streamer

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


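# Illustrative streaming usage (the prompt text is a placeholder):
#   for piece in generate(model, tokenizer, "Explain attention.", stream=True):
#       print(piece, end="", flush=True)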
def interactive_mode(model, tokenizer):
    """Run an interactive chat session."""
    print("=" * 60)
    print("Zenith-7B Interactive Mode")
    print("Type 'quit' to exit, 'clear' to clear history")
    print("=" * 60)

    history = []
    while True:
        try:
            user_input = input("\nYou: ").strip()
            if user_input.lower() == 'quit':
                break
            if user_input.lower() == 'clear':
                history = []
                print("History cleared.")
                continue

            # Build the prompt with a simple "User:/Assistant:" template,
            # keeping only the last four exchanges to bound prompt length.
            prompt = ""
            for user_msg, assistant_msg in history[-4:]:
                prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
            prompt += f"User: {user_input}\nAssistant:"

            print("\nZenith: ", end="", flush=True)
            response = generate(model, tokenizer, prompt, stream=True)
            full_response = ""
            for token in response:
                print(token, end="", flush=True)
                full_response += token
            print()

            history.append((user_input, full_response))

        except KeyboardInterrupt:
            print("\n\nInterrupted. Type 'quit' to exit.")
        except Exception as e:
            print(f"\nError: {e}")


def main():
    parser = argparse.ArgumentParser(description="Zenith-7B Inference")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to model checkpoint directory",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt for generation (if not provided, enters interactive mode)",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help="Maximum new tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=50,
        help="Top-k sampling",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.1,
        help="Repetition penalty (1.0 disables it)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="Device to run inference on",
    )
    parser.add_argument(
        "--stream",
        action="store_true",
        help="Stream output token by token",
    )

    args = parser.parse_args()
    print(f"Loading model from {args.checkpoint}...")
    model, tokenizer = load_model(args.checkpoint, args.device)
    print("Model loaded successfully!")

    if args.prompt:
        # Single-prompt mode: generate once and exit.
        response = generate(
            model, tokenizer, args.prompt,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            stream=args.stream,
        )
        if args.stream:
            for token in response:
                print(token, end="", flush=True)
            print()
        else:
            print(f"\nResponse: {response}")
    else:
        # No prompt given: drop into the interactive chat loop.
        interactive_mode(model, tokenizer)


if __name__ == "__main__":
    main()
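# Example invocations (the script name and paths are illustrative):
#   python inference.py --checkpoint checkpoints/zenith-7b --prompt "Hello" --stream
#   python inference.py --checkpoint checkpoints/zenith-7b        # interactive mode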