|
|
| """
|
| Inference script for Vortex models.
|
| Supports both CUDA and MPS backends.
|
| """
|
|
|
| import argparse
|
| import sys
|
| from pathlib import Path
|
|
|
| import torch
|
|
|
| from configs.vortex_7b_config import VORTEX_7B_CONFIG
|
| from configs.vortex_13b_config import VORTEX_13B_CONFIG
|
|
|
| from models.vortex_model import VortexModel
|
| from tokenizer.vortex_tokenizer import VortexScienceTokenizer
|
| from inference.cuda_optimize import optimize_for_cuda, profile_model
|
| from inference.mps_optimize import optimize_for_mps, profile_model_mps
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line arguments for Vortex inference.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` as before; accepting
            an explicit list keeps existing callers working while making the
            parser testable without patching ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Run inference with Vortex model")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to trained model checkpoint")
    parser.add_argument("--config", type=str, default=None,
                        help="Path to model config (if not in checkpoint)")
    parser.add_argument("--tokenizer_path", type=str, default=None,
                        help="Path to tokenizer")
    parser.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
                        help="Model size for config")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "mps", "cpu"],
                        help="Device to run on")
    parser.add_argument("--use_mps", action="store_true",
                        help="Use MPS backend (Apple Silicon)")
    # NOTE: `None` used to be listed in `choices`, but argparse only compares
    # the provided CLI *string* against choices, so None could never match;
    # the default=None already covers the "no quantization" case.
    parser.add_argument("--quantization", type=str, choices=["int8", "int4"], default=None,
                        help="Apply quantization (CUDA only)")
    parser.add_argument("--flash_attention", action="store_true",
                        help="Use Flash Attention 2 (CUDA only)")
    parser.add_argument("--torch_compile", action="store_true",
                        help="Use torch.compile")
    parser.add_argument("--prompt", type=str, default=None,
                        help="Input prompt for generation")
    parser.add_argument("--interactive", action="store_true",
                        help="Run in interactive mode")
    parser.add_argument("--max_new_tokens", type=int, default=100,
                        help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Top-p sampling")
    parser.add_argument("--profile", action="store_true",
                        help="Profile performance")
    return parser.parse_args(argv)
|
|
|
|
|
def load_model(args):
    """Load a Vortex checkpoint and apply backend-specific optimizations.

    Args:
        args: Parsed CLI namespace; reads ``config``, ``model_size``,
            ``model_path``, ``device``, ``use_mps``, ``flash_attention``,
            ``torch_compile`` and ``quantization``.

    Returns:
        Tuple ``(model, config)`` with the model moved to the target device
        and switched to eval mode.
    """
    # Single deferred import (the original repeated it in both branches);
    # kept function-local to avoid a hard module-level dependency.
    from configuration_vortex import VortexConfig

    if args.config:
        config = VortexConfig.from_pretrained(args.config)
    else:
        config_dict = VORTEX_7B_CONFIG if args.model_size == "7b" else VORTEX_13B_CONFIG
        config = VortexConfig(**config_dict)

    print("Creating model...")
    model = VortexModel(config.to_dict())

    print(f"Loading checkpoint from {args.model_path}")
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects and can
    # execute code -- only load checkpoints from trusted sources.
    checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=False)
    # Accept both full training checkpoints and bare state dicts.
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)
    print("Model loaded")

    device = torch.device(args.device)
    if args.use_mps or args.device == "mps":
        print("Optimizing for MPS...")
        model = optimize_for_mps(model, config.to_dict(), use_sdpa=True)
    elif args.device == "cuda":
        # BUG FIX: the CUDA branch previously also ran for --device cpu,
        # applying CUDA-only optimizations (quantization, flash attention)
        # to a CPU run. Plain CPU inference now skips optimization entirely.
        print("Optimizing for CUDA...")
        model = optimize_for_cuda(
            model,
            config.to_dict(),
            use_flash_attention=args.flash_attention,
            use_torch_compile=args.torch_compile,
            quantization=args.quantization,
        )

    model = model.to(device)
    model.eval()

    return model, config
|
|
|
|
|
def load_tokenizer(args):
    """Load the Vortex tokenizer, falling back to a no-op dummy.

    Looks for the tokenizer at ``args.tokenizer_path`` or, if not given,
    next to the model checkpoint. When no tokenizer file exists, a dummy
    tokenizer is returned so the rest of the pipeline can still run.
    """
    tokenizer_path = args.tokenizer_path
    if not tokenizer_path:
        # Default location: alongside the model checkpoint.
        tokenizer_path = Path(args.model_path).parent / "vortex_tokenizer.json"

    if tokenizer_path and Path(tokenizer_path).exists():
        from tokenization_vortex import VortexTokenizer
        # BUG FIX: `model_dir` was only defined when --tokenizer_path was
        # omitted, so a user-supplied path raised NameError here. Derive the
        # directory from the tokenizer file itself instead.
        tokenizer = VortexTokenizer.from_pretrained(str(Path(tokenizer_path).parent))
    else:
        print("Warning: No tokenizer found, using dummy tokenizer")

        class DummyTokenizer:
            # Special-token ids so generate_text's pad_token_id/eos_token_id
            # lookups don't raise AttributeError; values are placeholders.
            pad_token_id = 0
            eos_token_id = 2

            def __call__(self, text, **kwargs):
                return {"input_ids": torch.tensor([[1, 2, 3]])}

            def decode(self, ids, **kwargs):
                return "dummy"

        tokenizer = DummyTokenizer()

    return tokenizer
|
|
|
|
|
def generate_text(model, tokenizer, prompt, args):
    """Generate a completion for ``prompt`` and return the decoded text.

    Uses the model's own ``generate`` method when available; otherwise falls
    back to simple temperature-scaled multinomial sampling, one token at a
    time, stopping at EOS.
    """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=False,
        truncation=True,
        # Leave room for the new tokens within the model's context window.
        max_length=model.config.max_seq_len - args.max_new_tokens,
    )
    input_ids = inputs["input_ids"].to(next(model.parameters()).device)

    with torch.no_grad():
        if hasattr(model, 'generate'):
            output_ids = model.generate(
                input_ids,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
        else:
            # Sampling fallback for models without a generate() method.
            # NOTE: top_p is not applied on this path (matches prior intent).
            for _ in range(args.max_new_tokens):
                outputs = model(input_ids)
                next_token_logits = outputs["logits"][:, -1, :]
                next_token = torch.multinomial(
                    torch.softmax(next_token_logits / args.temperature, dim=-1),
                    num_samples=1,
                )
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                if next_token.item() == tokenizer.eos_token_id:
                    break
            # BUG FIX: `output_ids` was never assigned in this branch, so
            # decoding below raised NameError. The accumulated sequence
            # (prompt + sampled tokens) is the output.
            output_ids = input_ids

    generated = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)
    return generated
|
|
|
|
|
def main():
    """CLI entry point: load model/tokenizer, then profile or generate."""
    args = parse_args()

    model, config = load_model(args)
    tokenizer = load_tokenizer(args)

    # Hoist the device lookup; it is used for both the banner and profiling.
    device = next(model.parameters()).device
    print(f"Model loaded on {device}")
    print(f"Model parameters: {model.get_num_params():,}")

    if args.profile:
        print("Profiling...")
        dummy_input = torch.randint(0, config.vocab_size, (1, 128)).to(device)
        on_mps = args.use_mps or args.device == "mps"
        stats = profile_model_mps(model, dummy_input) if on_mps else profile_model(model, dummy_input)
        print("Profile results:")
        for name, value in stats.items():
            print(f"  {name}: {value:.4f}")
        return

    if args.interactive:
        print("Interactive mode. Type 'quit' to exit.")
        while True:
            prompt = input("\nPrompt: ")
            if prompt.lower() == "quit":
                break
            response = generate_text(model, tokenizer, prompt, args)
            print(f"\nResponse: {response}")
    elif args.prompt:
        response = generate_text(model, tokenizer, args.prompt, args)
        print(f"Response: {response}")
    else:
        print("No prompt provided. Use --prompt or --interactive.")
|
|
|
|
|
# Standard script entry-point guard: run only when executed directly.
if __name__ == "__main__":

    main()
|
|
|