from typing import List

import numpy as np
import torch
import ray
from ray import serve
from PIL import Image

from clip_retrieval.load_clip import load_clip, get_tokenizer
# from clip_retrieval.clip_client import ClipClient, Modality
# from starlette.requests import Request  # needed if the HTTP __call__ below is re-enabled

@serve.deployment(num_replicas=6, ray_actor_options={"num_cpus": 0.2, "num_gpus": 0.1})
class CLIPTransform:
    def __init__(self):
        # os.environ["OMP_NUM_THREADS"] = "20"
        # torch.set_num_threads(20)
        # Load the CLIP model and tokenizer once per replica
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self._clip_model = "ViT-L/14"
        self._clip_model_id = "laion5B-L-14"
        self.model, self.preprocess = load_clip(self._clip_model, use_jit=True, device=self.device)
        self.tokenizer = get_tokenizer(self._clip_model)
        print("using device", self.device)

    # NOTE: @serve.batch requires an async method; Serve gathers up to
    # max_batch_size concurrent calls into the `prompts` list.
    @serve.batch(max_batch_size=32)
    async def text_to_embeddings(self, prompts: List[str]) -> List[List[float]]:
        text = self.tokenizer(prompts).to(self.device)
        with torch.no_grad():
            prompt_embeddings = self.model.encode_text(text)
            prompt_embeddings /= prompt_embeddings.norm(dim=-1, keepdim=True)
            prompt_embeddings = prompt_embeddings.cpu().numpy().tolist()
        return prompt_embeddings
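
    # Usage note (illustrative, not in the original file): because of
    # @serve.batch, a client sends ONE prompt per call and Serve batches
    # concurrent calls transparently, e.g. via a deployment handle:
    #   ref = handle.text_to_embeddings.remote("a photo of a dog")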

    def image_to_embeddings(self, input_im):
        input_im = Image.fromarray(input_im)
        prepro = self.preprocess(input_im).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_embeddings = self.model.encode_image(prepro)
            image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
        return image_embeddings

    def preprocessed_image_to_embeddings(self, prepro):
        with torch.no_grad():
            image_embeddings = self.model.encode_image(prepro)
            image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
        return image_embeddings
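
    # Illustrative comparison of the two image entry points (an assumption,
    # not in the original file; `img` stands for a PIL.Image the caller has
    # loaded, and `preprocess` is the same transform returned by load_clip):
    #   image_to_embeddings(np.array(img.convert("RGB")))               # raw image path
    #   preprocessed_image_to_embeddings(preprocess(img).unsqueeze(0))  # pre-tensorized path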

    # HTTP entry point, currently disabled while the API is being debugged.
    # async def __call__(self, http_request: Request) -> str:
    #     request = await http_request.json()
    #     # print(type(request))
    #     # print(str(request))
    #     # dispatch based on whether the request carries text, an image URL,
    #     # or an already-preprocessed image tensor
    #     embeddings = None
    #     if "text" in request:
    #         prompt = request["text"]
    #         embeddings = await self.text_to_embeddings(prompt)
    #     elif "image_url" in request:
    #         image_url = request["image_url"]
    #         # download the image from the url, then decode it
    #         import requests
    #         from io import BytesIO
    #         response = requests.get(image_url)
    #         input_image = Image.open(BytesIO(response.content))
    #         input_image = input_image.convert('RGB')
    #         input_image = np.array(input_image)
    #         embeddings = self.image_to_embeddings(input_image)
    #     elif "preprocessed_image" in request:
    #         prepro = request["preprocessed_image"]
    #         # create a torch tensor on the device
    #         prepro = torch.tensor(prepro).to(self.device)
    #         embeddings = self.preprocessed_image_to_embeddings(prepro)
    #     else:
    #         raise Exception("Invalid request")
    #     if isinstance(embeddings, torch.Tensor):
    #         embeddings = embeddings.cpu().numpy().tolist()
    #     return embeddings

deployment_graph = CLIPTransform.bind()
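
# Minimal local-run sketch (an assumption, not part of the original file):
# requires a reachable Ray cluster and the CLIP weights on disk; the handle
# API below matches the Ray 2.x deployment-graph style used above.
#
# if __name__ == "__main__":
#     handle = serve.run(deployment_graph)
#     embedding = ray.get(handle.text_to_embeddings.remote("a photo of a dog"))
#     print("embedding length:", len(embedding))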