sohojoe committed on
Commit
1aa1aec
1 Parent(s): e4bcc80

base API on http, checking in so we can track the difference

Browse files
experimental/clip_api_app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File name: clip_api_app.py
2
+ import json
3
+ import os
4
+ import numpy as np
5
+ import torch
6
+ from starlette.requests import Request
7
+ from PIL import Image
8
+ import ray
9
+ from ray import serve
10
+ from clip_retrieval.load_clip import load_clip, get_tokenizer
11
+ # from clip_retrieval.clip_client import ClipClient, Modality
12
+
13
@serve.deployment(num_replicas=6, ray_actor_options={"num_cpus": .2, "num_gpus": 0.1})
class CLIPTransform:
    """Ray Serve deployment that returns L2-normalised CLIP embeddings over HTTP.

    The JSON body of a request selects the input kind by key:

    * ``"text"``               - a prompt string to tokenize and encode.
    * ``"image_url"``          - a URL that is downloaded and encoded as an image.
    * ``"preprocessed_image"`` - nested lists that are turned into a tensor and
      passed straight to the image encoder.

    The reply is the embedding tensor converted to nested Python lists.
    """

    def __init__(self):
        # os.environ["OMP_NUM_THREADS"] = "20"
        # torch.set_num_threads(20)
        # Load the model once per replica, on the GPU when one is visible.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self._clip_model = "ViT-L/14"
        self._clip_model_id = "laion5B-L-14"
        self.model, self.preprocess = load_clip(self._clip_model, use_jit=True, device=self.device)
        self.tokenizer = get_tokenizer(self._clip_model)

        print ("using device", self.device)

    def text_to_embeddings(self, prompt):
        """Tokenize ``prompt`` and return its unit-length text embedding tensor."""
        text = self.tokenizer([prompt]).to(self.device)
        with torch.no_grad():
            prompt_embeddings = self.model.encode_text(text)
            # Normalise along the feature axis so each row has unit L2 norm.
            prompt_embeddings /= prompt_embeddings.norm(dim=-1, keepdim=True)
        return prompt_embeddings

    def image_to_embeddings(self, input_im):
        """Return the unit-length image embedding for an array image.

        ``input_im`` must be something ``PIL.Image.fromarray`` accepts; the HTTP
        handler passes an RGB numpy array.
        """
        pil_image = Image.fromarray(input_im)
        # Add a leading batch axis before handing the tensor to the encoder.
        prepro = self.preprocess(pil_image).unsqueeze(0).to(self.device)
        return self.preprocessed_image_to_emdeddings(prepro)

    def preprocessed_image_to_emdeddings(self, prepro):
        """Encode an already-preprocessed image tensor and L2-normalise the result.

        The misspelled method name is kept because it is part of the class interface.
        """
        with torch.no_grad():
            image_embeddings = self.model.encode_image(prepro)
            image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
        return image_embeddings

    async def __call__(self, http_request: Request) -> list:
        """Handle one HTTP request and return the embeddings as nested lists.

        Raises:
            ValueError: if the JSON body has none of the supported keys.
            requests.HTTPError: if downloading ``image_url`` fails with a 4xx/5xx.
        """
        request = await http_request.json()
        # Switch on whether the caller sent text, an image URL or a tensor.
        if "text" in request:
            embeddings = self.text_to_embeddings(request["text"])
        elif "image_url" in request:
            # The original tested for key "image" but read "image_url", and fed
            # the URL string to BytesIO without ever fetching it.
            import requests
            from io import BytesIO

            # NOTE(review): this is a blocking download inside an async handler.
            response = requests.get(request["image_url"], timeout=30)
            response.raise_for_status()
            with Image.open(BytesIO(response.content)) as downloaded:
                input_image = np.array(downloaded.convert('RGB'))
            embeddings = self.image_to_embeddings(input_image)
        elif "preprocessed_image" in request:
            # Create the torch tensor on the model's device.
            prepro = torch.tensor(request["preprocessed_image"]).to(self.device)
            embeddings = self.preprocessed_image_to_emdeddings(prepro)
        else:
            raise ValueError("Invalid request")
        return embeddings.cpu().numpy().tolist()
74
+
75
+ # Bind CLIPTransform with no constructor arguments; presumably this is the application object handed to Ray Serve - TODO confirm.
+ deployment_graph = CLIPTransform.bind()
experimental/clip_api_app_client.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File name: clip_api_app_client.py
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ import json
4
+ import os
5
+ import requests
6
+ from concurrent.futures import ThreadPoolExecutor, as_completed
7
+ import time
8
+
9
+ test_image_url = "https://static.wixstatic.com/media/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg/v1/fill/w_454,h_333,fp_0.50_0.50,q_90/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg"
10
+ english_text = (
11
+ "It was the best of times, it was the worst of times, it was the age "
12
+ "of wisdom, it was the age of foolishness, it was the epoch of belief"
13
+ )
14
+
15
+
16
def send_text_request(number, timeout=60):
    """POST the sample sentence to the embedding service and return the raw reply.

    Args:
        number: Caller-chosen tag, returned unchanged so results gathered out of
            order can be matched back to their request.
        timeout: Seconds to wait for the server before ``requests`` gives up.

    Returns:
        A ``(number, body)`` tuple, where ``body`` is the response text.

    Raises:
        requests.HTTPError: if the service answers with a 4xx/5xx status.
    """
    # Called `payload`, not `json`, so the imported `json` module is not shadowed.
    payload = {"text": english_text}
    url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
    response = requests.post(url, json=payload, timeout=timeout)
    # Fail here with a clear HTTP error instead of a JSON decode error later.
    response.raise_for_status()
    return number, response.text
22
+
23
def process_text(numbers, max_workers=10):
    """Send one text request per entry of ``numbers`` from a thread pool.

    Each reply is parsed as JSON and reported as ``"<number> : <length of the
    first row>"``. Lines are printed in completion order, not submission order.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = []
        for number in numbers:
            pending.append(pool.submit(send_text_request, number))

        for finished in as_completed(pending):
            tag, raw_body = finished.result()
            rows = json.loads(raw_body)
            print (f"{tag} : {len(rows[0])}")
30
+
31
+ # def process_text(numbers, max_workers=10):
32
+ # for n in numbers:
33
+ # n_result, result = send_text_request(n)
34
+ # result = json.loads(result)
35
+ # print (f"{n_result} : {len(result[0])}")
36
+
37
if __name__ == "__main__":
    # Benchmark: issue a fixed number of text requests, then report the mean
    # latency per call and the overall throughput.
    # n_calls = 100000
    n_calls = 10000
    numbers = list(range(n_calls))

    started = time.monotonic()
    process_text(numbers)
    elapsed = time.monotonic() - started

    print(f"Average time taken: {elapsed / n_calls * 1000:.2f} ms")
    print(f"Number of calls per second: {n_calls / elapsed:.2f}")
49
+