import numpy as np | |
import onnxruntime as ort | |
import torch | |
from huggingface_hub import hf_hub_download | |
from PIL import Image | |
class YOLOSegmentationModel:
    """ONNX Runtime wrapper around the ``astro-yolo11m-seg`` segmentation model.

    The weights are fetched from the ``rayh/astro-seg`` Hugging Face Hub
    repository when an instance is created, and inference runs on the CPU
    execution provider.
    """

    def __init__(self):
        """Download the ONNX weights and open a CPU inference session."""
        # Download and load the ONNX model from Hugging Face Hub
        model_path = hf_hub_download(repo_id="rayh/astro-seg", filename="astro-yolo11m-seg.onnx")
        self.session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

    @staticmethod
    def _to_nchw(pixels):
        """Turn an ``(H, W, C)`` array of 0-255 values into a model input batch.

        Returns a C-contiguous ``float32`` array of shape ``(1, C, H, W)``
        with values scaled to ``[0, 1]``.
        """
        scaled = np.asarray(pixels, dtype=np.float32) / 255.0
        channels_first = np.transpose(scaled, (2, 0, 1))
        return np.ascontiguousarray(channels_first[np.newaxis, ...])

    def preprocess(self, image: "Image.Image"):
        """Convert a PIL image into the tensor fed to the ONNX session.

        The image is converted to RGB, resized to the spatial size declared
        by the model's first input when that size is static, scaled to
        ``[0, 1]`` and laid out as ``(1, 3, H, W)`` float32.

        NOTE(review): the resize is a plain stretch, not an aspect-preserving
        letterbox, so returned coordinates refer to the resized image.
        """
        rgb = image.convert("RGB")

        # ONNX Runtime reports symbolic (dynamic) dimensions as strings or
        # None; only resize when both height and width are concrete ints.
        declared = self.session.get_inputs()[0].shape
        if declared is not None and len(declared) == 4:
            height, width = declared[2], declared[3]
            if isinstance(height, int) and isinstance(width, int):
                if rgb.size != (width, height):  # PIL sizes are (width, height)
                    rgb = rgb.resize((width, height))

        return self._to_nchw(np.array(rgb))

    def predict(self, image: "Image.Image"):
        """Run the model on ``image`` and return the raw ONNX output list.

        The outputs are returned exactly as produced by the session; decoding
        them into bounding boxes / masks is left to the caller.
        """
        input_tensor = self.preprocess(image)
        # Ask the session for its input name rather than hard-coding it.
        input_name = self.session.get_inputs()[0].name
        return self.session.run(None, {input_name: input_tensor})
# Shared model instance, created once when this module is imported.
model = YOLOSegmentationModel()


def predict(image: Image.Image):
    """Entry point for the HF Inference API, which expects a ``predict`` function.

    Delegates to the shared module-level ``model`` and returns its outputs
    unchanged.
    """
    outputs = model.predict(image)
    return outputs