sohojoe commited on
Commit
781e740
1 Parent(s): 2afa949

Update clip_app.py

Browse files
Files changed (1) hide show
  1. experimental/clip_app.py +59 -0
experimental/clip_app.py CHANGED
@@ -46,6 +46,65 @@ class CLIPTransform:
46
  return(image_embeddings)
47
 
48
  async def __call__(self, http_request: Request) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  request = await http_request.json()
50
  # print(type(request))
51
  # print(str(request))
 
46
  return(image_embeddings)
47
 
48
  async def __call__(self, http_request: Request) -> str:
49
+ form_data = await http_request.form()
50
+
51
+ embeddings = None
52
+ if "text" in form_data:
53
+ prompt = (await form_data["text"].read()).decode()
54
+ print (type(prompt))
55
+ print (str(prompt))
56
+ embeddings = self.text_to_embeddings(prompt)
57
+ elif "image_url" in form_data:
58
+ image_url = (await form_data["image_url"].read()).decode()
59
+ # download image from url
60
+ import requests
61
+ from io import BytesIO
62
+ image_bytes = requests.get(image_url).content
63
+ input_image = Image.open(BytesIO(image_bytes))
64
+ input_image = input_image.convert('RGB')
65
+ input_image = np.array(input_image)
66
+ embeddings = self.image_to_embeddings(input_image)
67
+ elif "preprocessed_image" in form_data:
68
+ tensor_bytes = await form_data["preprocessed_image"].read()
69
+ shape_bytes = await form_data["shape"].read()
70
+ dtype_bytes = await form_data["dtype"].read()
71
+
72
+ # Convert bytes back to original form
73
+ dtype_mapping = {
74
+ "torch.float32": torch.float32,
75
+ "torch.float64": torch.float64,
76
+ "torch.float16": torch.float16,
77
+ "torch.uint8": torch.uint8,
78
+ "torch.int8": torch.int8,
79
+ "torch.int16": torch.int16,
80
+ "torch.int32": torch.int32,
81
+ "torch.int64": torch.int64,
82
+ torch.float32: np.float32,
83
+ torch.float64: np.float64,
84
+ torch.float16: np.float16,
85
+ torch.uint8: np.uint8,
86
+ torch.int8: np.int8,
87
+ torch.int16: np.int16,
88
+ torch.int32: np.int32,
89
+ torch.int64: np.int64,
90
+ # add more if needed
91
+ }
92
+ dtype_str = dtype_bytes.decode()
93
+ dtype_torch = dtype_mapping[dtype_str]
94
+ dtype_numpy = dtype_mapping[dtype_torch]
95
+ # shape = np.frombuffer(shape_bytes, dtype=np.int64)
96
+ # TODO: fix shape so it is passed nicely
97
+ shape = tuple([1, 3, 224, 224])
98
+
99
+ tensor_numpy = np.frombuffer(tensor_bytes, dtype=dtype_numpy).reshape(shape)
100
+ tensor = torch.from_numpy(tensor_numpy)
101
+ prepro = tensor.to(self.device)
102
+ embeddings = self.preprocessed_image_to_emdeddings(prepro)
103
+ else:
104
+ print ("Invalid request")
105
+ raise Exception("Invalid request")
106
+ return embeddings.cpu().numpy().tolist()
107
+
108
  request = await http_request.json()
109
  # print(type(request))
110
  # print(str(request))