Update clip_app.py
Browse files- experimental/clip_app.py +59 -0
experimental/clip_app.py
CHANGED
@@ -46,6 +46,65 @@ class CLIPTransform:
|
|
46 |
return(image_embeddings)
|
47 |
|
48 |
async def __call__(self, http_request: Request) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
request = await http_request.json()
|
50 |
# print(type(request))
|
51 |
# print(str(request))
|
|
|
46 |
return(image_embeddings)
|
47 |
|
48 |
async def __call__(self, http_request: Request) -> list:
    """Serve an embedding request submitted as multipart form data.

    Exactly one of the following form fields is expected:

    * ``text``               -- raw UTF-8 text to embed.
    * ``image_url``          -- URL of an image to download and embed.
    * ``preprocessed_image`` -- raw tensor bytes, accompanied by ``shape``
      and ``dtype`` fields describing the buffer.

    Returns:
        The embeddings as a nested Python list (JSON-serializable),
        produced via ``embeddings.cpu().numpy().tolist()``.

    Raises:
        Exception: when none of the expected form fields is present.
        KeyError: when ``dtype`` names a dtype not in the table below.
    """
    form_data = await http_request.form()

    if "text" in form_data:
        prompt = (await form_data["text"].read()).decode()
        print(type(prompt))
        print(str(prompt))
        embeddings = self.text_to_embeddings(prompt)

    elif "image_url" in form_data:
        image_url = (await form_data["image_url"].read()).decode()
        # Download the image from the supplied URL.
        # NOTE(review): blocking HTTP call inside an async handler — consider
        # an async client or a thread pool if request throughput matters.
        import requests
        from io import BytesIO
        image_bytes = requests.get(image_url).content
        input_image = Image.open(BytesIO(image_bytes))
        input_image = input_image.convert('RGB')
        input_image = np.array(input_image)
        embeddings = self.image_to_embeddings(input_image)

    elif "preprocessed_image" in form_data:
        tensor_bytes = await form_data["preprocessed_image"].read()
        shape_bytes = await form_data["shape"].read()
        dtype_bytes = await form_data["dtype"].read()

        # Map the dtype string sent by the client straight to the numpy dtype
        # used to decode the raw buffer; extend this table as clients need.
        dtype_mapping = {
            "torch.float32": np.float32,
            "torch.float64": np.float64,
            "torch.float16": np.float16,
            "torch.uint8": np.uint8,
            "torch.int8": np.int8,
            "torch.int16": np.int16,
            "torch.int32": np.int32,
            "torch.int64": np.int64,
        }
        dtype_numpy = dtype_mapping[dtype_bytes.decode()]

        # shape = np.frombuffer(shape_bytes, dtype=np.int64)
        # TODO: honor the client-supplied shape (shape_bytes is currently
        # read but ignored) instead of this hard-coded CLIP input shape.
        shape = (1, 3, 224, 224)

        tensor_numpy = np.frombuffer(tensor_bytes, dtype=dtype_numpy).reshape(shape)
        prepro = torch.from_numpy(tensor_numpy).to(self.device)
        embeddings = self.preprocessed_image_to_emdeddings(prepro)

    else:
        print ("Invalid request")
        raise Exception("Invalid request")

    return embeddings.cpu().numpy().tolist()
|