"""Gradio demo that "dreams" an image with a pretrained ViT classifier.

The uploaded image's pixels are optimised for a few gradient steps so that the
classifier's logits for a randomly chosen set of classes are pushed towards 1.
"""
import gradio as gr
import torch
from torch.nn import BCEWithLogitsLoss
import numpy as np
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image

MODEL_NAME = 'google/vit-large-patch32-384'

# Load model and feature extractor once, outside the request handler, so every
# call reuses them.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)
model.to(device)
model.eval()
# Only the input pixels are optimised; freezing the weights means backward()
# no longer computes (and accumulates) gradients for every parameter.
model.requires_grad_(False)


def get_encoder_activations(x):
    """Return the final-layer first-token ([CLS]) embedding from the ViT encoder.

    Args:
        x: Normalised pixel tensor as produced by ``feature_extractor``
            (presumably (batch, 3, H, W) — confirm against the processor).

    Returns:
        ``last_hidden_state[:, 0, :]``, i.e. a (batch, hidden_size) tensor.
    """
    encoder_output = model.vit(x)
    return encoder_output.last_hidden_state[:, 0, :]


def _random_targets(num_classes, n_targets, seed, target_device):
    """Build a multi-hot target vector with ``n_targets`` randomly chosen ones.

    Seeds torch's global RNG with ``seed`` so the same seed selects the same
    classes on every call.
    """
    torch.manual_seed(int(seed))
    targets = torch.zeros(num_classes)
    targets[torch.randperm(num_classes)[:int(n_targets)]] = 1
    return targets.to(target_device)


def _to_uint8_image(pixel_values):
    """Convert a (1, C, H, W) tensor in [-1, 1] to an (H, W, C) uint8 array."""
    image = pixel_values.detach().squeeze(0).permute(1, 2, 0).cpu()
    # Inverse of a mean=0.5 / std=0.5 normalisation: [-1, 1] -> [0, 255].
    image = 127.5 + image * 127.5
    return image.numpy().astype(np.uint8)


def process_image(input_image, learning_rate, iterations, n_targets, seed):
    """Optimise ``input_image`` so a random set of class logits moves towards 1.

    Args:
        input_image: PIL image from the Gradio upload, or None if empty.
        learning_rate: Step size applied to the raw pixel gradient.
        iterations: Number of gradient steps (cast to int).
        n_targets: How many randomly chosen class logits to maximise
            (cast to int).
        seed: Seed for torch's global RNG, used to pick the target classes.

    Returns:
        The dreamed image as an (H, W, C) uint8 numpy array, or None when no
        image was supplied.
    """
    if input_image is None:
        return None

    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    pixel_values.requires_grad_(True)

    targets = _random_targets(model.config.num_labels, n_targets, seed, device)
    # Summed (not averaged) BCE over all class logits; built once, not per step.
    loss_fn = BCEWithLogitsLoss(reduction='sum')

    for _ in range(int(iterations)):
        pixel_values.grad = None
        final_activations = get_encoder_activations(pixel_values)
        logits = model.classifier(final_activations[0])
        loss = loss_fn(logits, targets)
        loss.backward()
        with torch.no_grad():
            # Descend the loss: for a target of 1, d(loss)/d(logit) =
            # sigmoid(z) - 1 < 0, so stepping against the gradient raises the
            # selected logits (and lowers the non-target ones).
            pixel_values -= learning_rate * pixel_values.grad
            # Keep pixels in the normalised range; assumes the processor uses
            # mean=std=0.5 (so inputs lie in [-1, 1]) — confirm for the checkpoint.
            pixel_values.clamp_(-1, 1)

    return _to_uint8_image(pixel_values)


iface = gr.Interface(
    fn=process_image,
    # Gradio passes these positionally, so the order must match
    # process_image(input_image, learning_rate, iterations, n_targets, seed).
    inputs=[
        gr.Image(type="pil"),
        gr.Number(value=1.0, minimum=0, label="Learning Rate"),
        gr.Number(value=2, minimum=1, label="Iterations"),
        gr.Number(value=250, minimum=1, maximum=1000, label="Number of Random Target Class Activations to Maximise"),
        gr.Number(value=420, minimum=0, label="Seed"),
    ],
    outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")],
)
iface.launch()