project_charles / clip_transform.py
sohojoe's picture
video not really work that well
c6ad8e3
raw
history blame
2.12 kB
import json
import os
import numpy as np
import torch
from PIL import Image
from clip_retrieval.load_clip import load_clip, get_tokenizer
# from clip_retrieval.clip_client import ClipClient, Modality
class CLIPTransform:
    """Encode text prompts and images into L2-normalised CLIP embeddings.

    Wraps a model/preprocessor pair obtained from ``load_clip`` and a tokenizer
    from ``get_tokenizer``; every public method returns embeddings scaled to
    unit length along the last dimension, computed without autograd.
    """

    def __init__(self, clip_model="open_clip:ViT-H-14", device=None):
        """Load the CLIP model, its image preprocessor and its tokenizer.

        Args:
            clip_model: Model identifier passed to ``load_clip`` and
                ``get_tokenizer``. Defaults to the previously hard-coded
                ``"open_clip:ViT-H-14"`` (``"open_clip:ViT-L-14"`` and
                ``"ViT-L/14"`` were earlier alternatives in this file).
            device: Torch device string to run on. When ``None`` (default),
                ``"cuda:0"`` is used if CUDA is available, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device
        self._clip_model = clip_model
        self.model, self.preprocess = load_clip(
            self._clip_model, use_jit=True, device=self.device
        )
        self.tokenizer = get_tokenizer(self._clip_model)
        print("using device", self.device)

    @staticmethod
    def _normalize(embeddings):
        """Return *embeddings* divided by their L2 norm along the last axis."""
        return embeddings / embeddings.norm(dim=-1, keepdim=True)

    def text_to_embeddings(self, prompts):
        """Embed one or more text prompts.

        Args:
            prompts: A single prompt string or a list of prompt strings.

        Returns:
            Tensor of unit-norm text embeddings, one row per prompt.
        """
        # The tokenizer is fed a list of prompts; wrap a lone string (including
        # str subclasses) so it is encoded as one prompt, not per character.
        if isinstance(prompts, str):
            prompts = [prompts]
        tokens = self.tokenizer(prompts).to(self.device)
        with torch.no_grad():
            return self._normalize(self.model.encode_text(tokens))

    def image_to_embeddings(self, input_im):
        """Embed a single image.

        Args:
            input_im: A ``PIL.Image.Image``, or an array accepted by
                ``PIL.Image.fromarray`` (presumably an HxWxC uint8 frame —
                confirm with callers).

        Returns:
            Unit-norm image embedding for a batch of one image (the
            preprocessed image is unsqueezed to batch size 1).
        """
        if not isinstance(input_im, Image.Image):
            input_im = Image.fromarray(input_im)
        prepro = self.preprocess(input_im).unsqueeze(0)
        return self.preprocessed_image_to_emdeddings(prepro)

    def preprocessed_image_to_emdeddings(self, prepro):
        """Embed an already-preprocessed, batched image tensor.

        Args:
            prepro: Image tensor already run through ``self.preprocess`` and
                batched; it is moved to ``self.device`` before encoding.

        Returns:
            Tensor of unit-norm image embeddings, one row per batch item.
        """
        prepro = prepro.to(self.device)
        with torch.no_grad():
            return self._normalize(self.model.encode_image(prepro))

    # Correctly spelled alias; the misspelled name above is kept so existing
    # callers keep working.
    preprocessed_image_to_embeddings = preprocessed_image_to_emdeddings