"""
Helion-V2 Inference Script
Provides optimized inference with various sampling strategies.
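
Example CLI usage (illustrative; the flags are defined in main() below and the
default model id is assumed to be reachable on the Hugging Face Hub):

    # Plain completion
    python inference.py --prompt "Write a haiku about GPUs." --max-tokens 128

    # Chat-style generation with 4-bit quantization
    python inference.py --prompt "Explain KV caching." --chat-mode --load-in-4bit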
"""
import argparse
import time
from typing import Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


class HelionInference:
"""Inference wrapper for Helion-V2 model."""
def __init__(
self,
model_name: str = "DeepXR/Helion-V2",
device: str = "auto",
load_in_4bit: bool = False,
load_in_8bit: bool = False,
use_flash_attention: bool = True,
):
"""
        Initialize the Helion-V2 model for inference.

        Args:
model_name: HuggingFace model identifier
device: Device placement ('auto', 'cuda', 'cpu')
load_in_4bit: Use 4-bit quantization
load_in_8bit: Use 8-bit quantization
use_flash_attention: Enable Flash Attention 2
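
        Example (illustrative; assumes a CUDA GPU and, for quantized loading,
        that the bitsandbytes package is installed):
            >>> engine = HelionInference(load_in_4bit=True)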
"""
self.model_name = model_name
self.device = device
print(f"Loading tokenizer from {model_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Configure quantization
quantization_config = None
if load_in_4bit:
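            # QLoRA-style NF4 quantization: 4-bit weights with float16 compute
            # and nested ("double") quantization for additional memory savings.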
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
elif load_in_8bit:
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
print(f"Loading model from {model_name}...")
model_kwargs = {
"device_map": device,
"torch_dtype": torch.float16,
"quantization_config": quantization_config,
}
if use_flash_attention and not (load_in_4bit or load_in_8bit):
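            # FlashAttention-2 needs the separate flash-attn package; if it is
            # unavailable, drop this kwarg or use attn_implementation="sdpa".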
model_kwargs["attn_implementation"] = "flash_attention_2"
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
**model_kwargs
)
self.model.eval()
print("Model loaded successfully!")
def generate(
self,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.1,
do_sample: bool = True,
num_return_sequences: int = 1,
) -> List[str]:
"""
        Generate text from a prompt.

        Args:
prompt: Input text prompt
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature (higher = more random)
top_p: Nucleus sampling threshold
top_k: Top-k sampling parameter
repetition_penalty: Penalty for repeating tokens
do_sample: Use sampling vs greedy decoding
num_return_sequences: Number of sequences to generate

        Returns:
List of generated text strings
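
        Example (illustrative):
            >>> engine = HelionInference()
            >>> texts = engine.generate("List three uses of attention.",
            ...                         max_new_tokens=64)
            >>> print(texts[0])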
"""
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
start_time = time.time()
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
do_sample=do_sample,
num_return_sequences=num_return_sequences,
pad_token_id=self.tokenizer.eos_token_id,
)
generation_time = time.time() - start_time
        # Report throughput based on newly generated tokens (per returned sequence).
        tokens_generated = outputs.shape[1] - inputs["input_ids"].shape[1]
        tokens_per_second = tokens_generated / max(generation_time, 1e-6)
results = []
for output in outputs:
text = self.tokenizer.decode(output, skip_special_tokens=True)
results.append(text)
print(f"\nGeneration stats:")
print(f" Tokens generated: {tokens_generated}")
print(f" Time: {generation_time:.2f}s")
print(f" Speed: {tokens_per_second:.2f} tokens/s")
        return results

    def chat(
self,
messages: List[Dict[str, str]],
max_new_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.9,
**kwargs
) -> str:
"""
        Generate response in chat format.

        Args:
messages: List of message dicts with 'role' and 'content'
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Nucleus sampling threshold
**kwargs: Additional generation parameters

        Returns:
Generated response text
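
        Example (illustrative):
            >>> engine = HelionInference()
            >>> reply = engine.chat([
            ...     {"role": "system", "content": "You are a helpful AI assistant."},
            ...     {"role": "user", "content": "Summarize nucleus sampling."},
            ... ])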
"""
input_text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
results = self.generate(
input_text,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
**kwargs
)
        # Extract only the assistant's turn. The delimiter strings below must
        # match the markers produced by the model's chat template; the fallback
        # simply strips the prompt prefix from the decoded text.
full_text = results[0]
if "<|assistant|>" in full_text:
response = full_text.split("<|assistant|>")[-1].split("<|end|>")[0].strip()
else:
response = full_text[len(input_text):].strip()
        return response


def main():
parser = argparse.ArgumentParser(description="Helion-V2 Inference")
parser.add_argument(
"--model",
type=str,
default="DeepXR/Helion-V2",
help="Model name or path"
)
parser.add_argument(
"--prompt",
type=str,
required=True,
help="Input prompt"
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="Maximum tokens to generate"
)
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Sampling temperature"
)
parser.add_argument(
"--top-p",
type=float,
default=0.9,
help="Nucleus sampling threshold"
)
parser.add_argument(
"--top-k",
type=int,
default=50,
help="Top-k sampling"
)
parser.add_argument(
"--repetition-penalty",
type=float,
default=1.1,
help="Repetition penalty"
)
parser.add_argument(
"--load-in-4bit",
action="store_true",
help="Load model in 4-bit precision"
)
parser.add_argument(
"--load-in-8bit",
action="store_true",
help="Load model in 8-bit precision"
)
parser.add_argument(
"--device",
type=str,
default="auto",
help="Device placement"
)
parser.add_argument(
"--chat-mode",
action="store_true",
help="Use chat format"
)
args = parser.parse_args()
# Initialize model
inference = HelionInference(
model_name=args.model,
device=args.device,
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
)
# Generate response
if args.chat_mode:
messages = [
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": args.prompt}
]
response = inference.chat(
messages,
max_new_tokens=args.max_tokens,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
repetition_penalty=args.repetition_penalty,
)
print(f"\nAssistant: {response}")
else:
results = inference.generate(
args.prompt,
max_new_tokens=args.max_tokens,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
repetition_penalty=args.repetition_penalty,
)
print(f"\nGenerated text:\n{results[0]}")
if __name__ == "__main__":
main()