project_charles / clip_transform.py
sohojoe's picture
create prototype for is someone is there
3e4f32c
raw
history blame
No virus
2.39 kB
import json
import os
import numpy as np
import torch
from PIL import Image
import open_clip
class CLIPTransform:
    """Produce unit-length CLIP embeddings for text prompts and images.

    Wraps an OpenCLIP model together with its image preprocessing transform
    and tokenizer. Every public ``*_to_embeddings`` method returns a tensor
    whose rows are normalised along the last dimension, so a dot product
    between a text row and an image row is their cosine similarity.
    """

    def __init__(self, clip_model="ViT-L-14", pretrained="datacomp_xl_s13b_b90k", device=None):
        """Load the model, preprocessing transform and tokenizer.

        Args:
            clip_model: OpenCLIP architecture name. The default is ViT-L/14
                (about 1.71 GB). Other pairs that were tried in this file:
                "ViT-H-14" with "laion2B-s32B-b79K" and
                "ViT-B-32" with "laion2b_s34b_b79k".
            pretrained: OpenCLIP pretrained-weights tag matching ``clip_model``.
            device: Device to run on. When ``None``, "cuda:0" is used if CUDA
                is available, otherwise "cpu".
        """
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device
        self._clip_model = clip_model
        self._pretrained = pretrained

        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self._clip_model,
            pretrained=self._pretrained,
            device=self.device,
        )
        # OpenCLIP hands the model back in training mode; this class only does
        # inference, so switch to eval mode to keep outputs deterministic.
        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer(self._clip_model)
        print("using device", self.device)

    @staticmethod
    def _normalize(embeddings):
        """Return ``embeddings`` scaled to unit norm along the last dimension."""
        # Out-of-place division so the tensor returned by the model is not mutated.
        return embeddings / embeddings.norm(dim=-1, keepdim=True)

    def text_to_embeddings(self, prompts):
        """Embed one prompt or a list of prompts.

        Args:
            prompts: A single string or a list of strings. A single string is
                treated as a one-element list.

        Returns:
            Normalised text embeddings, one row per prompt.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        tokens = self.tokenizer(prompts).to(self.device)
        with torch.no_grad():
            return self._normalize(self.model.encode_text(tokens))

    def image_to_embeddings(self, input_im):
        """Embed an image given as an array accepted by ``PIL.Image.fromarray``.

        Returns:
            Normalised image embeddings with a leading batch dimension of 1.
        """
        return self.pil_image_to_embeddings(Image.fromarray(input_im))

    def pil_image_to_embeddings(self, input_im):
        """Embed a single image after running it through ``self.preprocess``.

        Returns:
            Normalised image embeddings with a leading batch dimension of 1.
        """
        # The preprocessing transform yields one image; add the batch dimension.
        batch = self.preprocess(input_im).unsqueeze(0)
        return self.preprocessed_image_to_emdeddings(batch)

    def preprocessed_image_to_emdeddings(self, prepro):
        """Embed an already preprocessed, already batched image tensor.

        The tensor is moved to ``self.device`` first (a no-op when it is
        already there), matching the other image entry points.

        Returns:
            Normalised image embeddings, one row per batch element.
        """
        prepro = prepro.to(self.device)
        with torch.no_grad():
            return self._normalize(self.model.encode_image(prepro))