File size: 2,391 Bytes
c6ad8e3
 
 
 
 
c58cbbc
c6ad8e3
 
 
 
 
 
 
 
 
c58cbbc
 
 
 
 
 
 
 
 
 
 
 
 
35d97c8
c58cbbc
c6ad8e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e4f32c
 
 
 
 
 
 
c6ad8e3
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import json
import os
import numpy as np
import torch
from PIL import Image
import open_clip

class CLIPTransform:
    """Wrap an OpenCLIP model to produce L2-normalised text and image embeddings.

    The default checkpoint is ``ViT-L-14`` with ``datacomp_xl_s13b_b90k``
    weights. Other combinations previously tried here include
    ``ViT-H-14`` / ``laion2B-s32B-b79K`` and ``ViT-B-32`` / ``laion2b_s34b_b79k``;
    pass them via the constructor arguments.
    """

    def __init__(self, clip_model="ViT-L-14", pretrained="datacomp_xl_s13b_b90k", device=None):
        """Load the OpenCLIP model, its preprocessing transform and its tokenizer.

        Args:
            clip_model: OpenCLIP architecture name.
            pretrained: Pretrained-weights tag to load for ``clip_model``.
            device: Torch device (string or ``torch.device``). Defaults to
                ``"cuda:0"`` when CUDA is available, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device

        self._clip_model = clip_model
        self._pretrained = pretrained

        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self._clip_model, pretrained=self._pretrained, device=self.device
        )
        # Per the OpenCLIP README, models are created in train mode; switch to
        # eval so dropout / stochastic-depth style layers are inactive at inference.
        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer(self._clip_model)

        print("using device", self.device)

    @staticmethod
    def _normalize(embeddings):
        """Return ``embeddings`` scaled to unit L2 norm along the last dimension."""
        return embeddings / embeddings.norm(dim=-1, keepdim=True)

    def text_to_embeddings(self, prompts):
        """Encode one prompt or a list of prompts into unit-norm text embeddings.

        Args:
            prompts: A single string or a list of strings.

        Returns:
            A tensor with one normalised embedding per prompt (a single string
            is treated as a one-element list).
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        tokens = self.tokenizer(prompts).to(self.device)
        with torch.no_grad():
            return self._normalize(self.model.encode_text(tokens))

    def image_to_embeddings(self, input_im):
        """Encode an image given as an array (or an existing PIL image).

        Args:
            input_im: An array accepted by ``PIL.Image.fromarray`` (presumably an
                HxWxC uint8 image — confirm with callers), or a ``PIL.Image.Image``.

        Returns:
            A tensor holding one normalised image embedding (batch of size 1).
        """
        if not isinstance(input_im, Image.Image):
            input_im = Image.fromarray(input_im)
        return self.pil_image_to_embeddings(input_im)

    def pil_image_to_embeddings(self, input_im):
        """Preprocess a PIL image with the model's transform and encode it.

        Returns:
            A tensor holding one normalised image embedding (batch of size 1).
        """
        batch = self.preprocess(input_im).unsqueeze(0)  # add the batch dimension
        return self.preprocessed_image_to_embeddings(batch)

    def preprocessed_image_to_embeddings(self, prepro):
        """Encode an already-preprocessed image batch into unit-norm embeddings.

        Args:
            prepro: A batched tensor produced by ``self.preprocess`` (plus a batch
                dimension). It is moved to ``self.device`` if needed.

        Returns:
            A tensor with one normalised embedding per batch entry.
        """
        with torch.no_grad():
            return self._normalize(self.model.encode_image(prepro.to(self.device)))

    # Backward-compatible alias for the original (misspelled) method name.
    preprocessed_image_to_emdeddings = preprocessed_image_to_embeddings