import time
from io import BytesIO

import numpy as np
import requests
import torch
from PIL import Image

import ray
from ray import serve

from clip_retrieval.load_clip import load_clip, get_tokenizer
# from clip_retrieval.clip_client import ClipClient, Modality

class CLIPModel:
    def __init__(self):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self._test_image_url = "https://static.wixstatic.com/media/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg/v1/fill/w_454,h_333,fp_0.50_0.50,q_90/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg"
        self._clip_model = "ViT-L/14"
        self._clip_model_id = "laion5B-L-14"
        # Load the CLIP model (JIT-compiled) together with its image preprocessing pipeline
        self.model, self.preprocess = load_clip(self._clip_model, use_jit=True, device=self.device)
        self.tokenizer = get_tokenizer(self._clip_model)
        print("using device", self.device)

    def text_to_embeddings(self, prompt):
        # Tokenize the prompt and encode it into a normalized CLIP text embedding
        text = self.tokenizer([prompt]).to(self.device)
        with torch.no_grad():
            prompt_embeddings = self.model.encode_text(text)
            prompt_embeddings /= prompt_embeddings.norm(dim=-1, keepdim=True)
        return prompt_embeddings

    def image_to_embeddings(self, input_im):
        # Convert a numpy image array into a normalized CLIP image embedding
        input_im = Image.fromarray(input_im)
        prepro = self.preprocess(input_im).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_embeddings = self.model.encode_image(prepro)
            image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
        return image_embeddings

    def preprocessed_image_to_embeddings(self, prepro):
        # Encode an already preprocessed image tensor into a normalized CLIP embedding
        with torch.no_grad():
            image_embeddings = self.model.encode_image(prepro)
            image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
        return image_embeddings
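
    # Note: serve.create_backend below registers this class as a Ray Serve backend,
    # but the original file defines no request handler. A minimal sketch of one is
    # added here, assuming the legacy Ray Serve API passes a Starlette request whose
    # JSON body carries a "prompt" field (matching the request at the bottom of the file).
    async def __call__(self, request):
        body = await request.json()
        prompt_embeddings = self.text_to_embeddings(body["prompt"])
        # Tensors are not JSON-serializable, so return a plain nested list
        return {"embeddings": prompt_embeddings.cpu().numpy().tolist()}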

    # simple regression test
    def regression_test(self):
        text_embeddings = self.text_to_embeddings("Howdy!")
        print("text embeddings", text_embeddings)
        # download the test image and convert it to a numpy RGB array
        response = requests.get(self._test_image_url)
        input_image = Image.open(BytesIO(response.content))
        input_image = input_image.convert('RGB')
        input_image = np.array(input_image)
        image_embeddings = self.image_to_embeddings(input_image)
        print("image embeddings", image_embeddings)
        # encode the same image again, this time from an already preprocessed tensor
        input_im = Image.fromarray(input_image)
        prepro = self.preprocess(input_im).unsqueeze(0).to(self.device)
        image_embeddings = self.preprocessed_image_to_embeddings(prepro)
        print("image embeddings", image_embeddings)

# regression test
test_instance = CLIPModel()
test_instance.regression_test()
ray.init()
serve.start()
# Register the model with Ray Serve (legacy 1.x create_backend/create_endpoint API)
serve.create_backend("clip_model", CLIPModel)
serve.create_endpoint("clip_model", backend="clip_model", route="/clip_model")
# You can now call the endpoint with your input
input_prompt = "Howdy!"
response = requests.get("http://localhost:8000/clip_model", json={"prompt": input_prompt})
print(response.json())
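
# On newer Ray releases (2.x) the legacy create_backend/create_endpoint API was removed.
# A minimal sketch of the equivalent, assuming a Ray 2.x version where deployments are
# bound and launched with serve.run and route_prefix is accepted there:
#
#   clip_deployment = serve.deployment(CLIPModel)
#   serve.run(clip_deployment.bind(), route_prefix="/clip_model")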