Ubuntu commited on
Commit
c2fbc3b
1 Parent(s): 5d0eb5a

add custom handler

Browse files
Files changed (3) hide show
  1. handler.py +48 -0
  2. requirements.txt +6 -0
  3. test_handler.py +13 -0
handler.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import os
3
+ import requests
4
+ from flask import Flask, Response, request, jsonify
5
+ from segment_anything import SamPredictor, sam_model_registry
6
+
7
+ class EndpointHandler():
8
+ def __init__(self, path=""):
9
+ # Preload all the elements you are going to need at inference.
10
+ model_type = "vit_b"
11
+ # prefix = "/opt/ml/model"
12
+ model_path = "tf_model.h5"
13
+ # model_checkpoint_path = os.path.join(prefix, "sam_vit_h_4b8939.pth")
14
+
15
+ sam = sam_model_registry[model_type](checkpoint=model_path)
16
+ predictor = SamPredictor(sam)
17
+
18
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
19
+ """
20
+ data args:
21
+ inputs (:obj: `str` | `PIL.Image` | `np.array`)
22
+ kwargs
23
+ Return:
24
+ A :obj:`list` | `dict`: will be serialized and returned
25
+ """
26
+
27
+ inputs = data.pop("inputs", data)
28
+ image_url = inputs.pop("imageUrl", none)
29
+
30
+ if not image_url:
31
+ return jsonify({"error": "image_url not provided"}), 400
32
+
33
+ try:
34
+ response = requests.get(image_url)
35
+ response.raise_for_status()
36
+ image = response.content
37
+ except requests.RequestException as e:
38
+ return jsonify({"error": f"Error downloading image: {str(e)}"}), 500
39
+
40
+
41
+ predictor.set_image(image)
42
+
43
+ image_embedding = predictor.get_image_embedding().cpu().numpy().toList()
44
+
45
+ return jsonify(image_embedding)
46
+
47
+ # pseudo
48
+ # self.model(input)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # -f https://download.pytorch.org/whl/cu117/torch_stable.html
2
+ # torch
3
+ # torchvision
4
+ git+https://github.com/ara-vardanyan/segment-anything.git
5
+ flask
6
+ requests
test_handler.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+
3
+ # init handler
4
+ my_handler = EndpointHandler(path=".")
5
+
6
+ # prepare sample payload
7
+ payload = {"inputs": "I am quite excited how this will turn out", "imageUrl": "https://res.cloudinary.com/dvfgdnfzd/image/upload/v1693510414/nvae1t0lvgzavfkgb45j.png"}
8
+
9
+ # test the handler
10
+ payload=my_handler(payload)
11
+
12
+ # show results
13
+ print("payload", payload)