BrickGPTFork / handler.py
from typing import Dict, List, Any
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path=""):
        # Get the Hugging Face token for gated model access
        hf_token = os.getenv("HF_TOKEN")

        # Load the model and tokenizer with authentication
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            token=hf_token
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
            token=hf_token
        )

        # Set a pad token if the tokenizer does not define one
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Simple handler that mimics local LLM behavior for RemoteLLM.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Handle the different input formats that RemoteLLM sends
        if isinstance(inputs, dict) and "messages" in inputs:
            messages = inputs["messages"]
        elif isinstance(inputs, list):
            messages = inputs
        else:
            # Fallback: treat the input as plain text
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": str(inputs)}
            ]
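        # For illustration only (hypothetical request bodies, contents made up),
        # the three accepted shapes handled above look like:
        #   {"inputs": {"messages": [{"role": "user", "content": "..."}]}}
        #   {"inputs": [{"role": "user", "content": "..."}]}
        #   {"inputs": "..."}
        # with an optional {"parameters": {...}} block alongside "inputs".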
        # Check whether this is a continuation (the messages contain an assistant turn)
        has_assistant = any(msg.get("role") == "assistant" for msg in messages)

        # Apply the chat template exactly like BrickGPT does locally
        if has_assistant:
            prompt = self.tokenizer.apply_chat_template(
                messages,
                continue_final_message=True,
                return_tensors='pt'
            )
        else:
            prompt = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors='pt'
            )

        # Move to the model's device; no padding is used, so the attention mask is all ones
        input_ids = prompt.to(self.model.device)
        attention_mask = torch.ones_like(input_ids)
        # Generation parameters - use the BrickGPT defaults unless overridden
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 10),
            "temperature": parameters.get("temperature", 0.6),
            "top_k": parameters.get("top_k", 20),
            "top_p": parameters.get("top_p", 1.0),
            "pad_token_id": self.tokenizer.pad_token_id,
            "do_sample": True,
            "num_return_sequences": 1,
            "return_dict_in_generate": True,
        }
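        # Illustrative override (values are made up): a request carrying
        # {"parameters": {"max_new_tokens": 64, "temperature": 0.8}} would
        # replace the sampling defaults above on a per-call basis.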
        # Generate
        with torch.no_grad():
            output_dict = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                **generation_params
            )

        # Extract only the newly generated tokens
        input_length = input_ids.shape[1]
        result_ids = output_dict['sequences'][0][input_length:]

        # CRITICAL: decode exactly like the local LLM does (no skip_special_tokens
        # argument), so special tokens are preserved in the output
        generated_text = self.tokenizer.decode(result_ids)

        # Return in the format RemoteLLM expects
        return [{"generated_text": generated_text}]
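# --- Optional local smoke test ---------------------------------------------
# A minimal sketch, assuming the repository files (weights, tokenizer, chat
# template) are in the current directory and HF_TOKEN is exported if the base
# model is gated. The prompt text and parameter values are illustrative only;
# this block is not part of the Inference Endpoints contract.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    sample_request = {
        "inputs": {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello"}
            ]
        },
        "parameters": {"max_new_tokens": 10, "temperature": 0.6},
    }
    print(handler(sample_request))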