import torch
import torch.nn as nn

import torchvision.transforms as T
import torchvision.transforms.functional as TF

import clip

class CLIP(nn.Module):
    def __init__(self, device):
        super().__init__()

        self.device = device

        self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False)
        
        # image augmentation: resize to CLIP's 224x224 input and apply CLIP's normalization stats
        self.aug = T.Compose([
            T.Resize((224, 224)),
            T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

        # self.gaussian_blur = T.GaussianBlur(15, sigma=(0.1, 10))

    
    def get_text_embeds(self, prompt):
        # tokenize the prompt and encode it with the CLIP text encoder
        text = clip.tokenize(prompt).to(self.device)
        text_z = self.clip_model.encode_text(text)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)  # L2-normalize the text features

        return text_z

    
    def train_step(self, text_z, pred_rgb):
        # resize and normalize the rendered image to match CLIP's expected input
        pred_rgb = self.aug(pred_rgb)

        image_z = self.clip_model.encode_image(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features

        # negative mean cosine similarity between image and text embeddings
        loss = -(image_z * text_z).sum(-1).mean()

        return loss
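

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): build the CLIP
# guidance, encode a prompt, and score a dummy rendered image. The prompt text,
# image resolution, and dtype handling below are assumptions for demonstration;
# in practice pred_rgb comes from a differentiable renderer and the loss is
# backpropagated into the renderer's parameters.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    guidance = CLIP(device)

    # normalized CLIP text embedding for a hypothetical prompt
    text_z = guidance.get_text_embeds('a photograph of a red rose')

    # dummy rendered RGB batch in [0, 1], shape (B, 3, H, W)
    pred_rgb = torch.rand(1, 3, 512, 512, device=device)
    # clip.load() keeps fp16 weights on CUDA, so match the model's dtype
    pred_rgb = pred_rgb.to(guidance.clip_model.dtype)

    with torch.no_grad():  # inspection only; omit when optimizing the renderer
        loss = guidance.train_step(text_z, pred_rgb)
    print('CLIP guidance loss:', loss.item())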