project_charles / clip_transform.py
sohojoe's picture
create prototype for is someone is there
3e4f32c
raw
history blame
2.39 kB
import json
import os
import numpy as np
import torch
from PIL import Image
import open_clip
class CLIPTransform:
    """Wrap an open_clip model to embed text and images into CLIP space.

    Every public ``*_embeddings`` method runs the model without gradient
    tracking and returns embeddings L2-normalised along the last dimension,
    so text and image embeddings can be compared with a plain dot product.
    """

    def __init__(self):
        # Thread tuning that was tried previously (kept for reference):
        # os.environ["OMP_NUM_THREADS"] = "20"; torch.set_num_threads(20)
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # NOTE: Apple MPS ("mps") was considered as a CPU fallback but is
        # currently disabled.

        # Model choice. Alternatives previously used:
        #   ViT-H-14 / 'laion2B-s32B-b79K'
        #   ViT-B-32 / 'laion2b_s34b_b79k'
        # Current choice: ViT-L/14 (~1.71 GB of weights).
        self._clip_model = "ViT-L-14"
        self._pretrained = "datacomp_xl_s13b_b90k"
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self._clip_model, pretrained=self._pretrained, device=self.device
        )
        # The open_clip README notes the model comes back in train mode,
        # which affects models with BatchNorm or stochastic depth; this class
        # only does inference, so switch to eval mode.
        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer(self._clip_model)
        print("using device", self.device)

    @staticmethod
    def _l2_normalize(embeddings):
        """Divide each embedding (last dimension) by its L2 norm, in place."""
        embeddings /= embeddings.norm(dim=-1, keepdim=True)
        return embeddings

    def text_to_embeddings(self, prompts):
        """Encode text prompt(s) into normalised CLIP text embeddings.

        Args:
            prompts: a single string or a list of strings. A single string
                is wrapped in a one-element list.

        Returns:
            Tensor of text embeddings, L2-normalised along the last
            dimension (presumably shaped (len(prompts), embed_dim)).
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        tokens = self.tokenizer(prompts).to(self.device)
        with torch.no_grad():
            prompt_embeddings = self.model.encode_text(tokens)
            return self._l2_normalize(prompt_embeddings)

    def image_to_embeddings(self, input_im):
        """Encode one image given as an array (or a PIL image).

        Args:
            input_im: anything ``PIL.Image.fromarray`` accepts (e.g. a numpy
                image array), or an existing ``PIL.Image.Image``, which is
                used as-is.

        Returns:
            Normalised image embedding with a leading batch dimension of 1.
        """
        if not isinstance(input_im, Image.Image):
            input_im = Image.fromarray(input_im)
        return self.pil_image_to_embeddings(input_im)

    def pil_image_to_embeddings(self, input_im):
        """Preprocess a single PIL image and encode it.

        The model's own preprocessing transform is applied, a batch
        dimension of 1 is added, and the result is moved to ``self.device``.
        """
        prepro = self.preprocess(input_im).unsqueeze(0).to(self.device)
        return self.preprocessed_image_to_embeddings(prepro)

    def preprocessed_image_to_embeddings(self, prepro):
        """Encode an already-preprocessed, batched image tensor.

        Args:
            prepro: tensor already transformed by ``self.preprocess`` and
                batched; assumed to already be on ``self.device``.

        Returns:
            Image embeddings L2-normalised along the last dimension.
        """
        with torch.no_grad():
            image_embeddings = self.model.encode_image(prepro)
            return self._l2_normalize(image_embeddings)

    # Backward-compatible alias for the original (misspelled) method name.
    preprocessed_image_to_emdeddings = preprocessed_image_to_embeddings