import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO

import numpy as np
import requests
import torch
from PIL import Image

from clip_retrieval.load_clip import load_clip, get_tokenizer

class ClipAppClient:
    """
    Client for generating embeddings with the OpenAI CLIP model. Images are
    preprocessed locally; embedding requests are sent to the HTTP service
    configured via the HTTP_ADDRESS environment variable.

    Example:
        app_client = ClipAppClient()
        test_image_url = "https://example.com/image.jpg"
        preprocessed_image = app_client.preprocess_image(test_image_url)
        text = "A beautiful landscape"
        text_embeddings = app_client.text_to_embedding(text)
        image_embeddings = app_client.image_url_to_embedding(test_image_url)
        preprocessed_image_embeddings = app_client.preprocessed_image_to_embedding(preprocessed_image)
    """

    def __init__(self, clip_model="ViT-L/14", device=None):
        # Alternative model: clip_model="open_clip:ViT-H-14"
        self.clip_model = clip_model
        self.device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
        print("using device", self.device)
        # load_clip returns (model, preprocess); only the preprocessing
        # transform is needed on the client side.
        _, self.preprocess = load_clip(clip_model, use_jit=True, device=self.device)
        self.tokenizer = get_tokenizer(clip_model)

    def preprocess_image(self, image_url):
        """
        Preprocess an image from a local file path or a URL.

        :param image_url: str, path or URL of the image to preprocess
        :return: torch.Tensor, preprocessed image batch of shape (1, C, H, W)
        """
        if os.path.isfile(image_url):
            input_image = Image.open(image_url).convert('RGB')
        else:
            response = requests.get(image_url)
            response.raise_for_status()
            input_image = Image.open(BytesIO(response.content)).convert('RGB')
        # Apply the CLIP preprocessing transform and add a batch dimension.
        return self.preprocess(input_image).unsqueeze(0).cpu()

    def text_to_embedding(self, text):
        """
        Convert text to an embedding using the OpenAI CLIP model.

        :param text: str, text to embed
        :return: torch.Tensor, text embedding
        """
        # The service expects multipart form data; the text is sent as a
        # file-like field.
        payload = {
            "text": ('str', text, 'application/octet-stream'),
        }
        url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
        response = requests.post(url, files=payload)
        # The response body is a JSON-encoded list of floats.
        embeddings = torch.tensor(json.loads(response.text))
        return embeddings

    def image_url_to_embedding(self, image_url):
        """
        Convert an image URL to an embedding using the OpenAI CLIP model.

        :param image_url: str, URL of the image to embed
        :return: torch.Tensor, image embedding
        """
        payload = {
            "image_url": ('str', image_url, 'application/octet-stream'),
        }
        url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
        response = requests.post(url, files=payload)
        # The response body is a JSON-encoded list of floats.
        embeddings = torch.tensor(json.loads(response.text))
        return embeddings
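
    # A convenience helper, not part of the original per-item API: embeds
    # several image URLs concurrently with the ThreadPoolExecutor imported
    # above. A hedged sketch; it assumes the embedding service tolerates
    # parallel requests. Results are returned in input order.
    def image_urls_to_embeddings(self, image_urls, max_workers=8):
        embeddings = [None] * len(image_urls)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Map each future back to its position in the input list.
            futures = {
                executor.submit(self.image_url_to_embedding, url): i
                for i, url in enumerate(image_urls)
            }
            for future in as_completed(futures):
                embeddings[futures[future]] = future.result()
        return embeddings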

    def preprocessed_image_to_embedding(self, image):
        """
        Convert a preprocessed image tensor to an embedding using the OpenAI
        CLIP model.

        :param image: torch.Tensor, preprocessed image
        :return: torch.Tensor, image embedding
        """
        # Serialize the tensor as raw bytes plus enough metadata (shape,
        # dtype) for the server to reconstruct it.
        key = "preprocessed_image"
        data_bytes = image.numpy().tobytes()
        shape_bytes = np.array(image.shape).tobytes()
        dtype_bytes = str(image.dtype).encode()
        payload = {
            key: ('tensor', data_bytes, 'application/octet-stream'),
            'shape': ('shape', shape_bytes, 'application/octet-stream'),
            'dtype': ('dtype', dtype_bytes, 'application/octet-stream'),
        }
        url = os.environ.get("HTTP_ADDRESS", "http://127.0.0.1:8000/")
        response = requests.post(url, files=payload)
        # The response body is a JSON-encoded list of floats.
        embeddings = torch.tensor(json.loads(response.text))
        return embeddings
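
# Minimal usage sketch (assumptions: a compatible embedding service is
# running at HTTP_ADDRESS or http://127.0.0.1:8000/, and the image URL is a
# placeholder). CLIP text and image embeddings live in a shared space and
# are commonly compared with cosine similarity, as shown here.
if __name__ == "__main__":
    app_client = ClipAppClient()
    text_emb = app_client.text_to_embedding("A beautiful landscape")
    image_emb = app_client.image_url_to_embedding("https://example.com/image.jpg")
    # Flatten so the comparison works whether the service returns a 1-D
    # vector or a (1, D) batch.
    similarity = torch.nn.functional.cosine_similarity(
        text_emb.flatten(), image_emb.flatten(), dim=0
    )
    print("text/image cosine similarity:", similarity.item())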