| | from typing import Dict, List, Any |
| | import os |
| | import requests |
| | from flask import Flask, Response, request, jsonify |
| | from segment_anything import SamPredictor, sam_model_registry |
| |
|
class EndpointHandler:
    """Inference handler that returns Segment Anything (SAM) image embeddings.

    The handler downloads the image referenced by ``imageUrl`` in the request
    payload, feeds it to a ``SamPredictor`` and returns the image embedding as
    nested Python lists so the result is JSON-serializable.
    """

    # Seconds to wait for the image host before giving up; without a timeout
    # ``requests.get`` can block the worker indefinitely.
    download_timeout = 30

    def __init__(self, path="", model_type="vit_b", checkpoint="models/tf_model.h5"):
        """Build the SAM model and wrap it in a predictor.

        Args:
            path: Base directory the checkpoint is resolved against. The
                default ``""`` keeps the previous behaviour of resolving the
                checkpoint relative to the current working directory.
            model_type: Key into ``sam_model_registry`` (default ``"vit_b"``).
            checkpoint: Checkpoint file, relative to ``path`` unless absolute.
        """
        # NOTE(review): a file named ``tf_model.h5`` looks like a
        # TensorFlow/Keras checkpoint, while segment_anything builders
        # presumably expect a PyTorch SAM checkpoint — confirm the file.
        model_path = os.path.join(path, checkpoint)
        sam = sam_model_registry[model_type](checkpoint=model_path)
        self.predictor = SamPredictor(sam)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Compute the SAM image embedding for an image given by URL.

        Args:
            data: Request payload. The URL is read from
                ``data["inputs"]["imageUrl"]``; when ``"inputs"`` is absent the
                payload itself is searched for ``"imageUrl"``. The payload is
                not modified.

        Returns:
            The image embedding as nested lists on success, or a dict
            ``{"error": message}`` when the URL is missing/invalid or the
            download fails. Plain Python objects are returned (not Flask
            responses) because the caller serializes the result, and
            ``flask.jsonify`` only works inside a Flask application context.
        """
        # Read without popping so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        image_url = inputs.get("imageUrl") if isinstance(inputs, dict) else None

        if not image_url or not isinstance(image_url, str):
            return {"error": "imageUrl not provided"}

        # SECURITY NOTE(review): the URL comes straight from the request, so
        # this will fetch arbitrary hosts (SSRF risk) — consider an allow-list.
        try:
            response = requests.get(image_url, timeout=self.download_timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            return {"error": f"Error downloading image: {str(e)}"}
        image = response.content

        # NOTE(review): SamPredictor.set_image is documented to expect a decoded
        # HxWx3 uint8 image array, but ``response.content`` is the raw encoded
        # file bytes. A decode step (e.g. Pillow:
        # np.asarray(Image.open(io.BytesIO(image)).convert("RGB"))) is still
        # needed here — TODO once an image decoder is an approved dependency.
        self.predictor.set_image(image)

        embedding = self.predictor.get_image_embedding()
        # Move to CPU and convert to lists so the result is JSON-serializable.
        return embedding.cpu().numpy().tolist()
| |
|