# File 2: models/clip_feedback.py
import torch
import torch.nn.functional as F  # was missing: F.cosine_similarity is used below

import clip


class CLIPFeedback:
    """CLIP-based similarity scoring and latent refinement.

    Wraps a frozen ViT-B/32 CLIP model to (a) score image/text agreement via
    cosine similarity of CLIP embeddings and (b) refine a generator latent by
    gradient ascent on that similarity.
    """

    def __init__(self, device='cuda'):
        """Load the CLIP ViT-B/32 model onto *device* (downloads on first use).

        Args:
            device: torch device string for the CLIP model.
        """
        self.device = device
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        self.model.eval()
        # Freeze CLIP weights: we only ever optimize external latents, and
        # leaving requires_grad on would accumulate grads into the model
        # during refine_latent's backward pass.
        for p in self.model.parameters():
            p.requires_grad_(False)

    def get_text_embeddings(self, text):
        """Return CLIP text features for a single string *text* (shape [1, D])."""
        text_input = clip.tokenize([text]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_input)
        return text_features

    def get_image_embeddings(self, images):
        """Return CLIP image features for a batch of images (no gradients).

        NOTE(review): assumes *images* are already CLIP-preprocessed
        (224x224, normalized) tensors on the right device — confirm at callers.
        """
        with torch.no_grad():
            image_features = self.model.encode_image(images)
        return image_features

    def calculate_similarity(self, images, text):
        """Cosine similarity between each image's embedding and the text embedding.

        Returns a 1-D tensor of per-image similarities in [-1, 1]; the single
        text embedding broadcasts across the image batch.
        """
        text_features = self.get_text_embeddings(text)
        image_features = self.get_image_embeddings(images)
        similarity = F.cosine_similarity(image_features, text_features)
        return similarity

    def refine_latent(self, latent, decoder, text, steps=5, lr=0.01):
        """Gradient-ascend *latent* so decoder(latent) better matches *text*.

        Args:
            latent: starting latent tensor (not modified in place).
            decoder: differentiable callable mapping latent -> image batch
                in CLIP's expected input format.
            text: target caption.
            steps: number of Adam steps.
            lr: Adam learning rate.

        Returns:
            The refined latent, detached from the graph.
        """
        latent = latent.clone().detach().requires_grad_(True)
        optimizer = torch.optim.Adam([latent], lr=lr)
        # Text target is constant; compute once, without gradients.
        text_features = self.get_text_embeddings(text)
        for _ in range(steps):
            generated = decoder(latent)
            # BUG FIX: the original called get_image_embeddings here, whose
            # torch.no_grad() severed the graph — loss.backward() then had no
            # path to `latent` and the loop optimized nothing. Encode with
            # gradients enabled so the similarity loss reaches the latent.
            image_features = self.model.encode_image(generated)
            # Negative similarity: minimizing it maximizes image/text agreement.
            loss = -F.cosine_similarity(image_features, text_features).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return latent.detach()