|
|
|
from concurrent.futures import ThreadPoolExecutor |
|
import json |
|
import os |
|
import numpy as np |
|
import requests |
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
import time |
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
# Sample JPEG used by the image benchmarks: it is downloaded and preprocessed
# locally once, and also sent to the server by URL.
test_image_url = (
    "https://static.wixstatic.com/media/"
    "4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg/v1/fill/"
    "w_454,h_333,fp_0.50_0.50,q_90/"
    "4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg"
)

# Sample sentence sent by the text benchmark (implicit literal concatenation).
english_text = (
    "It was the best of times, it was the worst of times, "
    "it was the age of wisdom, it was the age of foolishness, "
    "it was the epoch of belief"
)
|
|
|
# CLIP model name passed to both load_clip() and get_tokenizer() below.
clip_model="ViT-L/14"

# NOTE(review): not referenced anywhere else in this script — possibly a
# leftover; confirm before relying on it.
clip_model_id ="laion5B-L-14"

# First GPU when CUDA is available, otherwise CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

print ("using device", device)

# Imported here (after the device report) rather than with the other imports.
from clip_retrieval.load_clip import load_clip, get_tokenizer



# Loaded at import time. Only `preprocess` is used locally (by
# preprocess_image); the embeddings themselves come from the HTTP server.
# use_jit=True presumably selects a TorchScript model — confirm against
# clip_retrieval.load_clip.
model, preprocess = load_clip(clip_model, use_jit=True, device=device)

# NOTE(review): `tokenizer` is not used elsewhere in the visible script.
tokenizer = get_tokenizer(clip_model)
|
|
|
def preprocess_image(image_url):
    """Download an image and run it through the CLIP ``preprocess`` transform.

    Args:
        image_url: URL of the image to fetch.

    Returns:
        The result of the module-level ``preprocess`` transform applied to the
        RGB image, with a leading batch axis added by ``unsqueeze(0)`` and
        moved to the CPU.

    Raises:
        requests.HTTPError: if the download returns an HTTP error status.
        requests.Timeout: if the server does not respond within 30 seconds.
    """
    import requests

    from io import BytesIO

    from PIL import Image

    # Fetch the URL we were actually given (the old code always fetched the
    # global test image, silently ignoring this argument).
    response = requests.get(image_url, timeout=30)
    # Fail loudly on 4xx/5xx instead of letting PIL choke on an error page.
    response.raise_for_status()

    # Normalise to 3-channel RGB; the transform takes the PIL image directly,
    # so no NumPy round-trip is needed.
    input_image = Image.open(BytesIO(response.content)).convert('RGB')

    return preprocess(input_image).unsqueeze(0).cpu()
|
|
|
# Fetched and preprocessed once at import time, then reused for every request
# in the "preprocessed image" benchmark so that only server time is measured.
preprocessed_image = preprocess_image(test_image_url)
|
|
|
|
|
def text_to_embedding(text):
    """POST ``text`` as a multipart ``text`` field and return the reply body.

    The server address is read from the ``HTTP_ADDRESS`` environment variable,
    falling back to ``http://127.0.0.1:8000/``.

    Returns:
        The response body as text (the caller decodes it as JSON).
    """
    form_fields = {"text": ('str', text, 'application/octet-stream')}
    endpoint = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
    return requests.post(endpoint, files=form_fields).text
|
|
|
def image_url_to_embedding(image_url):
    """Ask the server to embed the image at ``image_url``.

    The URL is sent as a multipart ``image_url`` field. The server address is
    read from ``HTTP_ADDRESS``, falling back to ``http://127.0.0.1:8000/``.

    Args:
        image_url: URL the server should fetch and embed.

    Returns:
        The response body as text (the caller decodes it as JSON).
    """
    # Send the caller's URL; previously the global test_image_url was always
    # sent and this argument was ignored.
    payload = {
        "image_url": ('str', image_url, 'application/octet-stream'),
    }
    url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
    response = requests.post(url, files=payload)
    return response.text
|
|
|
def preprocessed_image_to_embedding(image):
    """Send an already-preprocessed image tensor to the server for embedding.

    The tensor is split into three multipart fields, in this order:

    * ``preprocessed_image``: the raw element bytes of ``image.numpy()``;
    * ``shape``: the bytes of ``np.array(image.shape)``, using NumPy's default
      integer dtype. NOTE(review): that dtype is platform dependent; confirm
      the server decodes it with the same width.
    * ``dtype``: ``str(image.dtype)`` encoded as UTF-8.

    Assumes ``image`` is a CPU torch tensor, since ``.numpy()`` is called on it.
    The server address comes from ``HTTP_ADDRESS`` (default
    ``http://127.0.0.1:8000/``).

    Returns:
        The response body as text (the caller decodes it as JSON).
    """
    # (form field, filename, content) triples; dict order is preserved below.
    parts = (
        ("preprocessed_image", "tensor", image.numpy().tobytes()),
        ("shape", "shape", np.array(image.shape).tobytes()),
        ("dtype", "dtype", str(image.dtype).encode()),
    )
    files = {
        field: (filename, blob, 'application/octet-stream')
        for field, filename, blob in parts
    }
    endpoint = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
    return requests.post(endpoint, files=files).text
|
|
|
def _send_text_request(number):
    """Embed the shared sample text once; pair the reply with ``number``."""
    return number, text_to_embedding(english_text)
|
|
|
def _send_image_url_request(number):
    """Embed the sample image by URL once; pair the reply with ``number``."""
    return number, image_url_to_embedding(test_image_url)
|
|
|
def _send_preprocessed_image_request(number):
    """Embed the cached preprocessed tensor once; pair the reply with ``number``."""
    return number, preprocessed_image_to_embedding(preprocessed_image)
|
|
|
def process(numbers, send_func, max_workers=10):
    """Call ``send_func`` for every item of ``numbers`` on a thread pool.

    Each call must return ``(number, json_text)``. As calls finish (in
    completion order, not submission order), the JSON is decoded and
    ``"<number> : <len of first element>"`` is printed. Any exception raised
    by ``send_func`` or by JSON decoding propagates to the caller.

    Args:
        numbers: Identifiers, one per request.
        send_func: Callable taking one identifier.
        max_workers: Size of the thread pool.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(send_func, item) for item in numbers]
        for done in as_completed(pending):
            tag, raw_json = done.result()
            first_row = json.loads(raw_json)[0]
            print(f"{tag} : {len(first_row)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    # Requests issued per benchmark; process() runs them on a 10-thread pool.
    n_calls = 300
    numbers = list(range(n_calls))

    # (report heading, per-request function), run in this order. A single
    # table replaces three copy-pasted timing sections.
    benchmarks = [
        ("Text...", _send_text_request),
        ("Image passing url...", _send_image_url_request),
        ("Preprocessed image...", _send_preprocessed_image_request),
    ]

    for heading, send_func in benchmarks:
        # Wall-clock time for the whole batch, measured with a monotonic clock.
        start_time = time.monotonic()
        process(numbers, send_func)
        total_time = time.monotonic() - start_time

        # Averages are over the concurrent batch, i.e. effective throughput,
        # not single-request latency.
        avg_time_ms = total_time / n_calls * 1000
        calls_per_sec = n_calls / total_time

        print(heading)
        print(f" Average time taken: {avg_time_ms:.2f} ms")
        print(f" Number of calls per second: {calls_per_sec:.2f}")
|
|