#!/usr/bin/env python3
"""
Inference script for Vortex models.
Supports both CUDA and MPS backends.
"""
import argparse
import sys
from pathlib import Path
import torch
from configs.vortex_7b_config import VORTEX_7B_CONFIG
from configs.vortex_13b_config import VORTEX_13B_CONFIG
from models.vortex_model import VortexModel
from tokenizer.vortex_tokenizer import VortexScienceTokenizer
from inference.cuda_optimize import optimize_for_cuda, profile_model
from inference.mps_optimize import optimize_for_mps, profile_model_mps
def parse_args():
parser = argparse.ArgumentParser(description="Run inference with Vortex model")
parser.add_argument("--model_path", type=str, required=True,
help="Path to trained model checkpoint")
parser.add_argument("--config", type=str, default=None,
help="Path to model config (if not in checkpoint)")
parser.add_argument("--tokenizer_path", type=str, default=None,
help="Path to tokenizer")
parser.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
help="Model size for config")
parser.add_argument("--device", type=str, default="cuda",
choices=["cuda", "mps", "cpu"],
help="Device to run on")
parser.add_argument("--use_mps", action="store_true",
help="Use MPS backend (Apple Silicon)")
parser.add_argument("--quantization", type=str, choices=[None, "int8", "int4"], default=None,
help="Apply quantization (CUDA only)")
parser.add_argument("--flash_attention", action="store_true",
help="Use Flash Attention 2 (CUDA only)")
parser.add_argument("--torch_compile", action="store_true",
help="Use torch.compile")
parser.add_argument("--prompt", type=str, default=None,
help="Input prompt for generation")
parser.add_argument("--interactive", action="store_true",
help="Run in interactive mode")
parser.add_argument("--max_new_tokens", type=int, default=100,
help="Maximum new tokens to generate")
parser.add_argument("--temperature", type=float, default=0.8,
help="Sampling temperature")
parser.add_argument("--top_p", type=float, default=0.9,
help="Top-p sampling")
parser.add_argument("--profile", action="store_true",
help="Profile performance")
return parser.parse_args()
def load_model(args):
"""Load model with appropriate optimizations."""
    # Load config
    from configuration_vortex import VortexConfig
    if args.config:
        config = VortexConfig.from_pretrained(args.config)
    else:
        # Use the default config for the requested model size
        config_dict = VORTEX_7B_CONFIG if args.model_size == "7b" else VORTEX_13B_CONFIG
        config = VortexConfig(**config_dict)
# Create model
print("Creating model...")
model = VortexModel(config.to_dict())
# Load checkpoint
print(f"Loading checkpoint from {args.model_path}")
checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=False)
if "model_state_dict" in checkpoint:
model.load_state_dict(checkpoint["model_state_dict"])
else:
model.load_state_dict(checkpoint)
print("Model loaded")
    # Apply device-specific optimizations
    if args.use_mps or args.device == "mps":
        device = torch.device("mps")
        print("Optimizing for MPS...")
        model = optimize_for_mps(model, config.to_dict(), use_sdpa=True)
    elif args.device == "cuda":
        device = torch.device("cuda")
        print("Optimizing for CUDA...")
        model = optimize_for_cuda(
            model,
            config.to_dict(),
            use_flash_attention=args.flash_attention,
            use_torch_compile=args.torch_compile,
            quantization=args.quantization,
        )
    else:
        # CPU: no backend-specific optimizations to apply
        device = torch.device(args.device)
    model = model.to(device)
model.eval()
return model, config
def load_tokenizer(args):
"""Load tokenizer."""
tokenizer_path = args.tokenizer_path
if not tokenizer_path:
# Try to find in model directory
model_dir = Path(args.model_path).parent
tokenizer_path = model_dir / "vortex_tokenizer.json"
if tokenizer_path and Path(tokenizer_path).exists():
        from tokenization_vortex import VortexTokenizer
        # Load from the directory containing the tokenizer file (tokenizer_path
        # may have been passed directly, so don't rely on model_dir here)
        tokenizer = VortexTokenizer.from_pretrained(str(Path(tokenizer_path).parent))
else:
print("Warning: No tokenizer found, using dummy tokenizer")
        class DummyTokenizer:
            # Minimal stand-in so downstream code that reads pad/eos ids still runs
            pad_token_id = 0
            eos_token_id = None

            def __call__(self, text, **kwargs):
                return {"input_ids": torch.tensor([[1, 2, 3]])}

            def decode(self, ids, **kwargs):
                return "dummy"
tokenizer = DummyTokenizer()
return tokenizer
def generate_text(model, tokenizer, prompt, args):
"""Generate text from prompt."""
# Tokenize
inputs = tokenizer(
prompt,
return_tensors="pt",
padding=False,
truncation=True,
max_length=model.config.max_seq_len - args.max_new_tokens,
)
input_ids = inputs["input_ids"].to(next(model.parameters()).device)
# Generate
with torch.no_grad():
if hasattr(model, 'generate'):
output_ids = model.generate(
input_ids,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
top_p=args.top_p,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
)
        else:
            # Manual sampling fallback (temperature only; top-p is honored
            # only when the model provides a generate() method)
            for _ in range(args.max_new_tokens):
                outputs = model(input_ids)
                next_token_logits = outputs["logits"][:, -1, :]
                next_token = torch.multinomial(
                    torch.softmax(next_token_logits / args.temperature, dim=-1),
                    num_samples=1,
                )
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                # Stop at the end-of-sequence token if the tokenizer defines one
                eos_id = getattr(tokenizer, "eos_token_id", None)
                if eos_id is not None and next_token.item() == eos_id:
                    break
            output_ids = input_ids
    # Decode
    generated = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)
    return generated
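# The fallback loop above samples with temperature only. A nucleus (top-p)
# filter could be applied to the distribution before sampling; the helper below
# is a sketch of that idea (the name is ours, not part of the Vortex codebase)
# and is not wired into generate_text.
def top_p_filter(logits, top_p):
    """Return a renormalized distribution keeping only the smallest set of
    tokens whose cumulative probability exceeds top_p."""
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Mask tokens whose preceding cumulative mass already exceeds top_p;
    # the shift keeps at least the single most likely token.
    mask = cumulative - sorted_probs > top_p
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    filtered = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
    return filtered / filtered.sum(dim=-1, keepdim=True)
# Hypothetical usage inside the fallback loop:
#   probs = top_p_filter(next_token_logits / args.temperature, args.top_p)
#   next_token = torch.multinomial(probs, num_samples=1)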
def main():
args = parse_args()
# Load model and tokenizer
model, config = load_model(args)
tokenizer = load_tokenizer(args)
print(f"Model loaded on {next(model.parameters()).device}")
print(f"Model parameters: {model.get_num_params():,}")
# Profile if requested
if args.profile:
print("Profiling...")
dummy_input = torch.randint(0, config.vocab_size, (1, 128)).to(next(model.parameters()).device)
if args.use_mps or args.device == "mps":
stats = profile_model_mps(model, dummy_input)
else:
stats = profile_model(model, dummy_input)
print("Profile results:")
for k, v in stats.items():
print(f" {k}: {v:.4f}")
return
# Interactive mode
if args.interactive:
print("Interactive mode. Type 'quit' to exit.")
while True:
prompt = input("\nPrompt: ")
if prompt.lower() == "quit":
break
response = generate_text(model, tokenizer, prompt, args)
print(f"\nResponse: {response}")
elif args.prompt:
response = generate_text(model, tokenizer, args.prompt, args)
print(f"Response: {response}")
else:
print("No prompt provided. Use --prompt or --interactive.")
if __name__ == "__main__":
main()