|
|
|
import torch
import torch.nn.functional as F

import clip
|
|
|
class CLIPFeedback:
    """Score images against a text prompt with CLIP and refine latents toward it.

    Wraps a frozen ViT-B/32 CLIP model. The embedding/similarity helpers are
    inference-only (no gradients); ``refine_latent`` runs gradient ascent on a
    latent code so that a decoder's output better matches a text prompt.
    """

    def __init__(self, device='cuda'):
        """Load CLIP ViT-B/32 onto *device* and put it in eval mode.

        Args:
            device: torch device string the model and inputs live on.
        """
        self.device = device
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        self.model.eval()

    def get_text_embeddings(self, text):
        """Encode a single text prompt into CLIP text features.

        Args:
            text: prompt string (tokenized as a batch of one).

        Returns:
            Text feature tensor of shape (1, feature_dim), detached from autograd.
        """
        text_input = clip.tokenize([text]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_input)
        return text_features

    def get_image_embeddings(self, images):
        """Encode a batch of preprocessed images into CLIP image features.

        Gradient-free: intended for scoring only, not for optimization.

        Args:
            images: image tensor already preprocessed for CLIP
                    (assumed (batch, 3, 224, 224) — confirm with caller).

        Returns:
            Image feature tensor of shape (batch, feature_dim), no gradients.
        """
        with torch.no_grad():
            image_features = self.model.encode_image(images)
        return image_features

    def calculate_similarity(self, images, text):
        """Cosine similarity between each image's embedding and the text embedding.

        Args:
            images: preprocessed image batch.
            text: prompt string.

        Returns:
            1-D tensor of per-image cosine similarities in [-1, 1].
        """
        text_features = self.get_text_embeddings(text)
        image_features = self.get_image_embeddings(images)
        similarity = F.cosine_similarity(image_features, text_features)
        return similarity

    def refine_latent(self, latent, decoder, text, steps=5, lr=0.01):
        """Gradient-ascend a latent so the decoded image matches *text* better.

        BUG FIX: the previous version encoded the decoded images through
        ``get_image_embeddings``, whose ``torch.no_grad()`` block severed the
        autograd graph — ``loss.backward()`` raised a RuntimeError and no
        refinement could happen. Images are now encoded with gradients enabled.

        Args:
            latent: starting latent tensor (left unmodified; a detached copy
                    is optimized).
            decoder: callable mapping a latent to a CLIP-ready image batch.
            text: target prompt string.
            steps: number of Adam update steps.
            lr: Adam learning rate.

        Returns:
            The refined latent, detached from the autograd graph.
        """
        # Optimize a fresh leaf tensor; never mutate the caller's latent.
        latent = latent.clone().detach().requires_grad_(True)
        optimizer = torch.optim.Adam([latent], lr=lr)

        # The text target is constant across steps; the no-grad helper is fine here.
        text_features = self.get_text_embeddings(text)

        for _ in range(steps):
            generated = decoder(latent)
            # Encode WITH gradients so the loss can backpropagate to `latent`.
            image_features = self.model.encode_image(generated)
            # Negative similarity: minimizing it maximizes image/text alignment.
            loss = -F.cosine_similarity(image_features, text_features).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return latent.detach()