# File name: model.py
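"""CLIP embedding service for Ray Serve.

Exposes a single HTTP endpoint that accepts multipart form data carrying one
of three inputs -- "text", "image_url", or "preprocessed_image" (plus "shape"
and "dtype") -- and returns the normalized CLIP embedding as a JSON list.
"""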
import json
import os
from io import BytesIO

import numpy as np
import requests
import torch
from starlette.requests import Request
from PIL import Image
import ray
from ray import serve
from clip_retrieval.load_clip import load_clip, get_tokenizer
# from clip_retrieval.clip_client import ClipClient, Modality


@serve.deployment(num_replicas=6, ray_actor_options={"num_cpus": 0.2, "num_gpus": 0.1})
class CLIPTransform:
    def __init__(self):
        # os.environ["OMP_NUM_THREADS"] = "20"
        # torch.set_num_threads(20)

        # Load model
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self._clip_model = "ViT-L/14"
        self._clip_model_id = "laion5B-L-14"
        self.model, self.preprocess = load_clip(self._clip_model, use_jit=True, device=self.device)
        self.tokenizer = get_tokenizer(self._clip_model)

        print("using device", self.device)

    def text_to_embeddings(self, prompt):
        text = self.tokenizer([prompt]).to(self.device)
        with torch.no_grad():
            prompt_embeddings = self.model.encode_text(text)
            prompt_embeddings /= prompt_embeddings.norm(dim=-1, keepdim=True)
        return prompt_embeddings

    def image_to_embeddings(self, input_im):
        input_im = Image.fromarray(input_im)
        prepro = self.preprocess(input_im).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_embeddings = self.model.encode_image(prepro)
            image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
        return image_embeddings

    def preprocessed_image_to_embeddings(self, prepro):
        with torch.no_grad():
            image_embeddings = self.model.encode_image(prepro)
            image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
        return image_embeddings
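
    # How a client builds the "preprocessed_image" payload consumed by
    # __call__ below (a sketch; the sending side lives outside this file,
    # and `preprocess` stands for the same transform loaded in __init__):
    #
    #   prepro = preprocess(pil_image).unsqueeze(0)              # 1x3x224x224
    #   form = {
    #       "preprocessed_image": prepro.numpy().tobytes(),
    #       "shape": np.array(prepro.shape, dtype=np.int64).tobytes(),
    #       "dtype": str(prepro.dtype).encode(),                 # e.g. "torch.float32"
    #   }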

    async def __call__(self, http_request: Request) -> list:
        form_data = await http_request.form()
        embeddings = None
        if "text" in form_data:
            prompt = (await form_data["text"].read()).decode()
            embeddings = self.text_to_embeddings(prompt)
        elif "image_url" in form_data:
            image_url = (await form_data["image_url"].read()).decode()
            # download image from url
            image_bytes = requests.get(image_url).content
            input_image = Image.open(BytesIO(image_bytes))
            input_image = input_image.convert('RGB')
            input_image = np.array(input_image)
            embeddings = self.image_to_embeddings(input_image)
        elif "preprocessed_image" in form_data:
            tensor_bytes = await form_data["preprocessed_image"].read()
            shape_bytes = await form_data["shape"].read()
            dtype_bytes = await form_data["dtype"].read()
            # Convert bytes back to original form: the dtype string maps to a
            # torch dtype, and the torch dtype maps to the matching numpy dtype.
            dtype_mapping = {
                "torch.float32": torch.float32,
                "torch.float64": torch.float64,
                "torch.float16": torch.float16,
                "torch.uint8": torch.uint8,
                "torch.int8": torch.int8,
                "torch.int16": torch.int16,
                "torch.int32": torch.int32,
                "torch.int64": torch.int64,
                torch.float32: np.float32,
                torch.float64: np.float64,
                torch.float16: np.float16,
                torch.uint8: np.uint8,
                torch.int8: np.int8,
                torch.int16: np.int16,
                torch.int32: np.int32,
                torch.int64: np.int64,
                # add more if needed
            }
            dtype_str = dtype_bytes.decode()
            dtype_torch = dtype_mapping[dtype_str]
            dtype_numpy = dtype_mapping[dtype_torch]
            # shape = np.frombuffer(shape_bytes, dtype=np.int64)
            # TODO: fix shape so it is passed nicely
            shape = tuple([1, 3, 224, 224])
            # copy() because np.frombuffer returns a read-only view and
            # torch.from_numpy warns on non-writable arrays
            tensor_numpy = np.frombuffer(tensor_bytes, dtype=dtype_numpy).reshape(shape).copy()
            tensor = torch.from_numpy(tensor_numpy)
            prepro = tensor.to(self.device)
            embeddings = self.preprocessed_image_to_embeddings(prepro)
        else:
            raise ValueError("Invalid request")
        return embeddings.cpu().numpy().tolist()


deployment_graph = CLIPTransform.bind()
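
# A minimal client sketch for this deployment, assuming Ray Serve exposes the
# graph at its default local route (http://localhost:8000/). The endpoint URL
# and the example image URL are placeholder assumptions, not part of this app.
# The handler awaits .read() on each field, so values must be sent as
# multipart file parts:
#
#   import requests
#   endpoint = "http://localhost:8000/"  # assumed default Serve route
#
#   # text -> embedding
#   resp = requests.post(endpoint, files={"text": ("text", b"a photo of a cat")})
#   text_embedding = resp.json()
#
#   # image_url -> embedding
#   resp = requests.post(endpoint, files={"image_url": ("image_url", b"https://example.com/cat.jpg")})
#   image_embedding = resp.json()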