sam-vit-base / handler.py
aradootle's picture
moved model to see if itll fix the bug
1810c2f
raw history blame
No virus
1.32 kB
from typing import Dict, List, Any
import os
import requests
from flask import Flask, Response, request, jsonify
from segment_anything import SamPredictor, sam_model_registry
class EndpointHandler():
    """Custom inference handler that returns a SAM image embedding for an image URL.

    The handler is constructed once (loading the ``vit_b`` SAM checkpoint) and
    is then called once per request with the decoded request payload.
    """

    # Seconds to wait on the image host; without a timeout a stalled server
    # would block the worker indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the SAM checkpoint and build the predictor used at inference.

        Args:
            path: Directory of the model repository. The checkpoint is read
                from ``models/tf_model.h5`` beneath it. With the default
                ``""`` the path stays relative to the current working
                directory, exactly as before.
        """
        model_type = "vit_b"
        # Resolve the checkpoint against the repository directory handed to the
        # handler instead of relying on the process working directory.
        model_path = os.path.join(path, "models/tf_model.h5")
        sam = sam_model_registry[model_type](checkpoint=model_path)
        self.predictor = SamPredictor(sam)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Download the requested image and return its SAM embedding.

        Args:
            data: Request payload. The image URL is read from
                ``data["inputs"]["imageUrl"]``, or from ``data["imageUrl"]``
                when there is no ``"inputs"`` key.

        Returns:
            The image embedding as nested lists of floats on success, or a
            dict ``{"error": <message>}`` when the URL is missing or the
            download fails. Plain Python objects are returned (no Flask
            ``jsonify``) because the handler runs outside any Flask
            application context and its return value is serialized by the
            caller.
        """
        inputs = data.get("inputs", data)
        # Anything other than a dict cannot carry an "imageUrl" key.
        image_url = inputs.get("imageUrl") if isinstance(inputs, dict) else None
        if not image_url:
            return {"error": "image_url not provided"}

        try:
            response = requests.get(image_url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
        except requests.RequestException as e:
            return {"error": f"Error downloading image: {str(e)}"}
        image = response.content

        # NOTE(review): `image` is the raw response body (bytes), while
        # SamPredictor.set_image is documented to take a decoded HxWxC uint8
        # array. Decoding needs an image library this file does not import;
        # confirm and add the decode step before relying on this path.
        self.predictor.set_image(image)
        embedding = self.predictor.get_image_embedding()
        return embedding.cpu().numpy().tolist()