# projectai/clip_feedback.py
# Author: Matthew Frazer (commit f277b08, "Create clip_feedback.py")
# File 2: models/clip_feedback.py
import clip
import torch
import torch.nn.functional as F
class CLIPFeedback:
    """Score and refine generated images against a text prompt using CLIP.

    Wraps a frozen ViT-B/32 CLIP model to (a) embed text and images,
    (b) compute cosine similarity between the two, and (c) gradient-ascend
    a latent vector so that ``decoder(latent)`` better matches a prompt.
    """

    def __init__(self, device='cuda'):
        """Load the CLIP ViT-B/32 model onto *device* and put it in eval mode.

        Args:
            device: torch device string the model and inputs live on.
        """
        self.device = device
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        self.model.eval()

    def get_text_embeddings(self, text):
        """Return the CLIP embedding for a single string *text*.

        Computed under ``no_grad`` — text embeddings are treated as fixed
        optimization targets. Returns a tensor of shape (1, embed_dim).
        """
        text_input = clip.tokenize([text]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_input)
        return text_features

    def get_image_embeddings(self, images):
        """Return CLIP embeddings for a batch of images.

        Runs under ``no_grad``: use for scoring only, never inside an
        optimization loop (see ``refine_latent``, which encodes with
        gradients enabled). *images* is assumed to already be preprocessed
        to CLIP's expected resolution/normalization — TODO confirm callers.
        """
        with torch.no_grad():
            image_features = self.model.encode_image(images)
        return image_features

    def calculate_similarity(self, images, text):
        """Cosine similarity between each image embedding and the text embedding."""
        text_features = self.get_text_embeddings(text)
        image_features = self.get_image_embeddings(images)
        similarity = F.cosine_similarity(image_features, text_features)
        return similarity

    def refine_latent(self, latent, decoder, text, steps=5, lr=0.01):
        """Gradient-optimize *latent* so ``decoder(latent)`` matches *text*.

        Bug fixed: the original routed images through
        ``get_image_embeddings``, whose ``torch.no_grad()`` context detached
        the graph — ``loss.backward()`` then fails because the loss has no
        grad_fn, and the latent was never actually optimized. Images are now
        encoded directly with gradients enabled.

        Args:
            latent: starting latent tensor (cloned; the input is not mutated).
            decoder: callable mapping a latent to an image batch CLIP can encode.
            text: prompt string to optimize toward.
            steps: number of Adam steps.
            lr: Adam learning rate.

        Returns:
            The refined latent, detached from the autograd graph.
        """
        latent = latent.clone().detach().requires_grad_(True)
        optimizer = torch.optim.Adam([latent], lr=lr)
        # Text target is fixed; embed it once outside the loop.
        text_features = self.get_text_embeddings(text)
        for _ in range(steps):
            generated = decoder(latent)
            # Encode WITH gradients so the loss can backpropagate to the latent.
            image_features = self.model.encode_image(generated)
            # Negative similarity: minimizing it maximizes CLIP alignment.
            loss = -F.cosine_similarity(image_features, text_features).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return latent.detach()