import os
from typing import Any, Dict, List, Optional, Union

import requests
from flask import Flask, Response, request, jsonify
from segment_anything import SamPredictor, sam_model_registry


class EndpointHandler:
    """Inference handler that downloads an image and returns its SAM image embedding."""

    # Key into ``sam_model_registry``; selects the SAM architecture to build
    # before the checkpoint weights are loaded into it.
    MODEL_TYPE = "vit_b"
    # Checkpoint location, relative to ``path`` or (fallback) the working directory.
    MODEL_RELATIVE_PATH = "models/tf_model.h5"
    # Seconds to wait on the image host; without a timeout ``requests.get``
    # can block the worker forever.
    REQUEST_TIMEOUT = 30

    def __init__(self, path: str = "") -> None:
        """Load the SAM checkpoint and wrap it in a ``SamPredictor``.

        Args:
            path: Directory holding the model files. The checkpoint is looked
                up under this directory first; if it is not there, the
                original working-directory-relative path is used.
        """
        print('current working directory', os.getcwd())
        model_path = self._resolve_checkpoint(path)
        sam = sam_model_registry[self.MODEL_TYPE](checkpoint=model_path)
        self.predictor = SamPredictor(sam)

    @classmethod
    def _resolve_checkpoint(cls, path: str) -> str:
        """Return the checkpoint path under ``path`` if it exists, else the cwd-relative one."""
        candidate = os.path.join(path, cls.MODEL_RELATIVE_PATH)
        if os.path.exists(candidate):
            return candidate
        return cls.MODEL_RELATIVE_PATH

    @staticmethod
    def _extract_image_url(data: Dict[str, Any]) -> Optional[str]:
        """Pull the image URL out of the request payload without mutating it.

        Accepts ``{"inputs": {"imageUrl": ...}}``, ``{"imageUrl": ...}``
        (no ``"inputs"`` key), or ``{"inputs": "<url>"}``. Returns None when
        no URL can be found.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict):
            return inputs.get("imageUrl")
        return None

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Any]]:
        """Download the image referenced by the payload and embed it with SAM.

        Args:
            data: Request payload; see ``_extract_image_url`` for accepted shapes.

        Returns:
            The image embedding as nested Python lists, or ``{"error": message}``
            when the URL is missing or the download fails. Plain lists/dicts are
            returned (not ``jsonify`` responses) because ``jsonify`` requires a
            Flask application context, which this handler never sets up, and the
            serving layer serializes the return value itself.
        """
        image_url = self._extract_image_url(data)
        if not image_url:
            return {"error": "image_url not provided"}

        try:
            response = requests.get(image_url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
        except requests.RequestException as e:
            return {"error": f"Error downloading image: {e}"}
        image = response.content

        # NOTE(review): SamPredictor.set_image expects a decoded HxWx3 uint8
        # numpy array, but ``image`` holds the raw encoded bytes of the
        # download, so this call will fail. Decode first, e.g.
        # np.array(PIL.Image.open(io.BytesIO(image)).convert("RGB")) — left
        # out here to avoid adding a new dependency.
        self.predictor.set_image(image)
        return self.predictor.get_image_embedding().cpu().numpy().tolist()