# multimodal-perceiver / handler.py
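"""
Custom Inference Endpoints handler for the multimodal-perceiver model.

Routes text, image, raw-array, and audio payloads through modality-specific
preprocessing, runs the ONNX sequence-classification model, and returns
label/score predictions.
"""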
import io
from typing import Any, Dict, List

import librosa
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler. This loads the tokenizer and model required
        for inference. We load the `ronai-multimodal-perceiver-tsx` model for
        multimodal input handling.
        """
        # Load the tokenizer and the ONNX Runtime model
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = ORTModelForSequenceClassification.from_pretrained(path)

        # Convenience pipeline for raw-string text inputs; __call__ below
        # invokes the model directly, since the preprocessed tensors bypass
        # the pipeline's own tokenization.
        self.pipeline = pipeline(
            "text-classification", model=self.model, tokenizer=self.tokenizer
        )

    def preprocess(self, data: Dict[str, Any]) -> Any:
        """
        Preprocess input data based on its modality.
        This handler supports text, image, raw-array, and audio inputs.
        """
        inputs = data.get("inputs", None)

        if isinstance(inputs, str):
            # Text input: tokenize into model-ready tensors
            return self.tokenizer(inputs, return_tensors="pt")
        elif isinstance(inputs, Image.Image):
            # Image input: convert the PIL image to a tensor and add a
            # batch dimension
            image = np.array(inputs)
            return torch.tensor(image).unsqueeze(0)
        elif isinstance(inputs, np.ndarray):
            # Raw array input (e.g. audio samples, point clouds): add a
            # batch dimension
            return torch.tensor(inputs).unsqueeze(0)
        elif isinstance(inputs, bytes):
            # Audio input: librosa.load expects a path or file-like object,
            # so wrap the raw bytes in a BytesIO buffer, then convert the
            # waveform to a mel spectrogram
            audio, sr = librosa.load(io.BytesIO(inputs), sr=None)
            mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=sr)
            # Add batch and channel dimensions
            return torch.tensor(mel_spectrogram).unsqueeze(0).unsqueeze(0)
        else:
            raise ValueError(
                "Unsupported input type. Must be a string (text), PIL.Image "
                "(image), np.ndarray (audio, point clouds), or bytes (audio)."
            )

    def postprocess(self, outputs: Any) -> List[Dict[str, Any]]:
        """
        Post-process the model output to a human-readable format.
        For text classification, this returns label and score.
        """
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class_id = probabilities.argmax().item()
        score = probabilities[0, predicted_class_id].item()
        return [
            {"label": self.model.config.id2label[predicted_class_id], "score": score}
        ]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handles the incoming request: preprocesses the input, runs inference,
        and returns the results.

        Args:
            data (Dict[str, Any]): The input data for inference.
                - data["inputs"] may be a string (text), a PIL.Image (image),
                  an np.ndarray (audio or point clouds), or bytes (audio).

        Returns:
            A list of dictionaries containing the model's prediction.
        """
        # Step 1: Preprocess the input data
        preprocessed_data = self.preprocess(data)

        # Step 2: Run inference. Tokenized text arrives as a dict-like
        # BatchEncoding and is unpacked into keyword arguments; other
        # modalities arrive as a single tensor, which the model is assumed
        # to accept positionally. (The pipeline is not used here because it
        # expects raw strings, not preprocessed tensors.)
        if hasattr(preprocessed_data, "keys"):
            outputs = self.model(**preprocessed_data)
        else:
            outputs = self.model(preprocessed_data)

        # Step 3: Post-process and return the predictions
        return self.postprocess(outputs)
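

# ---------------------------------------------------------------------------
# Local smoke test: a minimal sketch of how the handler can be exercised
# outside Inference Endpoints. The `path="."` and the sample sentence are
# illustrative assumptions, not part of the deployed handler; in production,
# the platform instantiates EndpointHandler with the model directory and
# calls it with the request payload.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # assumes model files in the CWD
    payload = {"inputs": "This is a test sentence."}  # hypothetical input
    print(handler(payload))  # e.g. [{"label": "...", "score": 0.98}]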