#!/usr/bin/env python3
"""Zenith-7B Inference Script for Standard GPUs"""
import torch
import argparse
from pathlib import Path
from typing import Optional, Dict, Any
# Add current directory to path for imports
import sys
sys.path.append(str(Path(__file__).parent))
from configs.zenith_config import get_7b_config
from models.zenith_model import ZenithForCausalLM
from data.advanced_tokenizer import AdvancedTokenizer


def load_model(checkpoint_path: str, device: str = "cuda"):
    """Load trained model from checkpoint."""
    config = get_7b_config()

    # Initialize tokenizer and sync the vocab size with the checkpoint
    tokenizer = AdvancedTokenizer.from_pretrained(checkpoint_path)
    config.vocab_size = tokenizer.get_vocab_size()

    # Load model
    model = ZenithForCausalLM.from_pretrained(
        checkpoint_path,
        config=config,
        device_map="auto" if device == "cuda" else None
    )
    model.eval()
    return model, tokenizer


def generate(
    model: ZenithForCausalLM,
    tokenizer: AdvancedTokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    do_sample: bool = True,
    stream: bool = False
):
    """Generate text from the model.

    Returns a streamer that yields text chunks when ``stream=True``;
    otherwise returns the decoded completion as a single string.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    # Shared generation arguments; pad/eos ids are passed on both paths so
    # streaming and blocking generation terminate the same way.
    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    if stream:
        # Streaming generation: run generate() on a background thread and
        # return an iterator that yields decoded text as it becomes available.
        from threading import Thread
        from transformers import TextIteratorStreamer

        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs["streamer"] = streamer
        Thread(target=model.generate, kwargs=generation_kwargs).start()
        return streamer

    # Blocking generation: produce the full sequence, then decode it in one shot.
    with torch.no_grad():
        outputs = model.generate(**generation_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
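

# Example (sketch): programmatic use of load_model() and generate() above;
# "./checkpoints/zenith-7b" is a placeholder path, not a checkpoint shipped
# with this repo.
#
#   model, tokenizer = load_model("./checkpoints/zenith-7b", device="cuda")
#   text = generate(model, tokenizer, "Explain beam search.", max_new_tokens=128)
#   print(text)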


def interactive_mode(model, tokenizer):
    """Run interactive chat session."""
    print("=" * 60)
    print("Zenith-7B Interactive Mode")
    print("Type 'quit' to exit, 'clear' to clear history")
    print("=" * 60)

    history = []
    while True:
        try:
            user_input = input("\nYou: ").strip()
            if user_input.lower() == 'quit':
                break
            if user_input.lower() == 'clear':
                history = []
                print("History cleared.")
                continue

            # Build prompt with history (keep last 4 exchanges)
            prompt = ""
            for user_msg, assistant_msg in history[-4:]:
                prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
            prompt += f"User: {user_input}\nAssistant:"

            print("\nZenith: ", end="", flush=True)
            response = generate(model, tokenizer, prompt, stream=True)
            full_response = ""
            for token in response:
                print(token, end="", flush=True)
                full_response += token
            print()

            history.append((user_input, full_response))
        except KeyboardInterrupt:
            print("\n\nInterrupted. Type 'quit' to exit.")
        except Exception as e:
            print(f"\nError: {e}")


def main():
    parser = argparse.ArgumentParser(description="Zenith-7B Inference")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to model checkpoint directory"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt for generation (if not provided, enters interactive mode)"
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help="Maximum new tokens to generate"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling"
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=50,
        help="Top-k sampling"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="Device to run inference on"
    )
    parser.add_argument(
        "--stream",
        action="store_true",
        help="Stream output token by token"
    )
    args = parser.parse_args()

    # Load model
    print(f"Loading model from {args.checkpoint}...")
    model, tokenizer = load_model(args.checkpoint, args.device)
    print("Model loaded successfully!")

    if args.prompt:
        # Single generation
        response = generate(
            model, tokenizer, args.prompt,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            stream=args.stream
        )
        if args.stream:
            for token in response:
                print(token, end="", flush=True)
            print()
        else:
            print(f"\nResponse: {response}")
    else:
        # Interactive mode
        interactive_mode(model, tokenizer)


if __name__ == "__main__":
    main()
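

# Example invocations (sketch; "./checkpoints/zenith-7b" is a placeholder
# path). All flags shown are defined in the argument parser above.
#
#   # Single prompt, streamed token by token
#   python inference.py --checkpoint ./checkpoints/zenith-7b \
#       --prompt "Write a haiku about GPUs" --stream
#
#   # Interactive chat session on CPU
#   python inference.py --checkpoint ./checkpoints/zenith-7b --device cpu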