# File 2: models/clip_feedback.py
import torch
import torch.nn.functional as F
import clip

class CLIPFeedback:
    """Scores and refines generated images against a text prompt using CLIP."""

    def __init__(self, device='cuda'):
        self.device = device
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        self.model.eval()
        # Freeze CLIP's weights: refine_latent optimizes only the latent,
        # so gradients should never accumulate in the model itself.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def get_text_embeddings(self, text):
        # Tokenize and encode a single prompt; the embedding serves as a
        # fixed target, so no gradients are tracked.
        text_input = clip.tokenize([text]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_input)
        return text_features

    def get_image_embeddings(self, images):
        # Encode preprocessed images for scoring only. Because of no_grad,
        # this method must not be used where gradients need to flow
        # (see refine_latent below).
        with torch.no_grad():
            image_features = self.model.encode_image(images)
        return image_features

    def calculate_similarity(self, images, text):
        # Cosine similarity between each image embedding and the (broadcast)
        # text embedding; returns one score per image.
        text_features = self.get_text_embeddings(text)
        image_features = self.get_image_embeddings(images)
        similarity = F.cosine_similarity(image_features, text_features)
        return similarity

    def refine_latent(self, latent, decoder, text, steps=5, lr=0.01):
        # Gradient-based refinement: nudge the latent so the decoded image
        # moves closer to the text prompt in CLIP space. Assumes the decoder
        # output is already sized and normalized for CLIP's image encoder.
        latent = latent.clone().detach().requires_grad_(True)
        optimizer = torch.optim.Adam([latent], lr=lr)

        text_features = self.get_text_embeddings(text)

        for _ in range(steps):
            generated = decoder(latent)
            # Encode with gradients enabled so the loss can backpropagate
            # through CLIP's image encoder into the latent; reusing
            # get_image_embeddings here would cut the graph with no_grad.
            image_features = self.model.encode_image(generated)
            loss = -F.cosine_similarity(image_features, text_features).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return latent.detach()
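

# Example usage (a minimal sketch, not part of the original file): the
# DummyDecoder below is a hypothetical stand-in for whatever generator maps
# latents to images in this project; any module that produces CLIP-ready
# 3x224x224 images would slot in the same way.
if __name__ == "__main__":
    import torch.nn as nn

    class DummyDecoder(nn.Module):
        # Hypothetical decoder: projects a latent vector to a 224x224 image.
        def __init__(self, latent_dim=128):
            super().__init__()
            self.fc = nn.Linear(latent_dim, 3 * 224 * 224)

        def forward(self, z):
            return self.fc(z).view(-1, 3, 224, 224)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    feedback = CLIPFeedback(device=device)
    decoder = DummyDecoder().to(device)

    prompt = "a red sports car"
    latent = torch.randn(1, 128, device=device)
    refined = feedback.refine_latent(latent, decoder, prompt, steps=5)
    with torch.no_grad():
        score = feedback.calculate_similarity(decoder(refined), prompt)
    print(f"CLIP similarity after refinement: {score.item():.3f}")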