Add custom handler for Inference Endpoint deployment
3a507ec
# handler.py
from typing import Dict, Any, List
import torch
import PIL.Image
from io import BytesIO
import base64
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
# Configure logging for debugging purposes
logging.basicConfig(level=logging.INFO)

class EndpointHandler:
    def __init__(self, path=""):
        logging.info("Initializing EndpointHandler for Moondream2")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")

        # Load the model with trust_remote_code enabled.
        # 'path' points to the location of the model files inside the container.
        # Use float16 only on GPU; half precision on CPU is slow or unsupported.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

        # Move the model to the target device and switch to inference mode.
        self.model.to(self.device)
        self.model.eval()
        logging.info("Moondream2 model loaded successfully.")

    def preprocess_image(self, encoded_image: str) -> PIL.Image.Image:
        """Decode and preprocess the base64-encoded image."""
        try:
            image_data = base64.b64decode(encoded_image)
            return PIL.Image.open(BytesIO(image_data)).convert("RGB")
        except Exception as e:
            logging.error(f"Error decoding image: {e}")
            raise ValueError(f"Failed to decode image data: {e}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handles the API call. The `data` argument is a dictionary containing the payload.
        Expects a JSON payload like:
        {
            "inputs": {
                "prompt": "What's in this picture?",
                "image": "base64_encoded_image_string"
            }
        }
        """
        logging.info("Received request payload")
        inputs = data.get("inputs", {})
        prompt = inputs.get("prompt", "")
        encoded_image = inputs.get("image", "")
        if not prompt or not encoded_image:
            raise ValueError("Prompt and base64-encoded image must be provided in the 'inputs' field.")

        image = self.preprocess_image(encoded_image)
        # Embed the image once, then answer the prompt against those embeddings.
        enc_image = self.model.encode_image(image)

        logging.info(f"Running inference with prompt: {prompt}")
        with torch.no_grad():
            # Moondream2's remote code exposes answer_question(image_embeds, question, tokenizer),
            # the image-QA entry point documented on the model card. It applies the
            # "Question: ... Answer:" template internally and returns only the answer text.
            generated_answer = self.model.answer_question(
                enc_image,
                prompt,
                self.tokenizer,
            ).strip()

        logging.info(f"Inference complete. Generated text: {generated_answer}")
        return [{"generated_text": generated_answer}]
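

if __name__ == "__main__":
    # Illustrative local smoke test, not part of the deployed handler. It assumes a
    # sample image at the placeholder path "test.jpg" and model files in the current
    # directory; adjust both for your setup. It base64-encodes the image and calls the
    # handler with the same payload shape the endpoint expects.
    import json

    handler = EndpointHandler(path=".")
    with open("test.jpg", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    payload = {"inputs": {"prompt": "What's in this picture?", "image": encoded}}
    print(json.dumps(handler(payload), indent=2))

    # Against the deployed endpoint, send the same payload over HTTPS, e.g. (URL and
    # token are placeholders):
    #   curl https://<your-endpoint>.endpoints.huggingface.cloud \
    #     -H "Authorization: Bearer <hf_token>" -H "Content-Type: application/json" \
    #     -d '{"inputs": {"prompt": "What is in this picture?", "image": "<base64 string>"}}'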