Spaces:
Runtime error
Runtime error
File size: 2,391 Bytes
c6ad8e3 c58cbbc c6ad8e3 c58cbbc 35d97c8 c58cbbc c6ad8e3 3e4f32c c6ad8e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import json
import os
import numpy as np
import torch
from PIL import Image
import open_clip
class CLIPTransform:
    """Thin wrapper around an open_clip model that returns L2-normalized embeddings.

    Text prompts and images are encoded with the same CLIP model, so their
    embeddings can be compared with a dot product (cosine similarity, since
    every returned row has unit L2 norm).

    Args:
        clip_model: open_clip architecture name. Defaults to ``"ViT-L-14"``
            (about 1.71 GB of weights). Other configurations used previously
            were ``("ViT-H-14", "laion2B-s32B-b79K")`` and
            ``("ViT-B-32", "laion2b_s34b_b79k")``.
        pretrained: name of the pretrained weight set for ``clip_model``.
        device: torch device string. When ``None``, ``"cuda:0"`` is used if
            CUDA is available, otherwise ``"cpu"``.
    """

    def __init__(self, clip_model="ViT-L-14", pretrained="datacomp_xl_s13b_b90k", device=None):
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device
        self._clip_model = clip_model
        self._pretrained = pretrained
        # The third return value is the inference-time preprocessing transform;
        # the second (training-time augmentation transform) is not needed here.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self._clip_model, pretrained=self._pretrained, device=self.device
        )
        # open_clip hands the model back in training mode; switch to eval mode
        # so train-only layers (dropout / stochastic depth / batch-norm stats)
        # behave deterministically during inference.
        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer(self._clip_model)
        print("using device", self.device)

    @staticmethod
    def _l2_normalize(embeddings):
        """Return ``embeddings`` scaled to unit L2 norm along the last dimension."""
        return embeddings / embeddings.norm(dim=-1, keepdim=True)

    def text_to_embeddings(self, prompts):
        """Encode one prompt or a sequence of prompts into normalized text embeddings.

        Args:
            prompts: a single string (including ``str`` subclasses) or a list
                of strings.

        Returns:
            A tensor on ``self.device`` with one unit-norm row per prompt.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(prompts, str):
            prompts = [prompts]
        tokens = self.tokenizer(prompts).to(self.device)
        with torch.no_grad():
            return self._l2_normalize(self.model.encode_text(tokens))

    def image_to_embeddings(self, input_im):
        """Encode an image given as an array (anything ``PIL.Image.fromarray`` accepts).

        Returns:
            A ``(1, D)`` tensor on ``self.device`` with unit L2 norm.
        """
        return self.pil_image_to_embeddings(Image.fromarray(input_im))

    def pil_image_to_embeddings(self, input_im):
        """Encode a single PIL image into a ``(1, D)`` normalized embedding."""
        # preprocess yields one image tensor; add the batch dimension.
        batch = self.preprocess(input_im).unsqueeze(0)
        return self.preprocessed_image_to_emdeddings(batch)

    def preprocessed_image_to_emdeddings(self, prepro):
        """Encode image tensors that have already gone through ``self.preprocess``.

        Args:
            prepro: a batch ``(N, C, H, W)`` of preprocessed images, or a single
                unbatched ``(C, H, W)`` image. It is moved to ``self.device``
                if it lives elsewhere.

        Returns:
            An ``(N, D)`` tensor (``N == 1`` for unbatched input) with
            unit-norm rows.
        """
        if prepro.dim() == 3:
            prepro = prepro.unsqueeze(0)
        prepro = prepro.to(self.device)
        with torch.no_grad():
            return self._l2_normalize(self.model.encode_image(prepro))

    # Correctly-spelled alias; the original (misspelled) name is kept for
    # existing callers.
    preprocessed_image_to_embeddings = preprocessed_image_to_emdeddings
|