DeIT-Dreamer / app.py
SoggyKiwi's picture
bro really had them the wrong way around
b222813
raw
history blame
No virus
2.5 kB
import gradio as gr
import torch
from torch.nn import BCEWithLogitsLoss
import numpy as np
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
# Load the processor and model once at import time so every request reuses them.
MODEL_NAME = 'google/vit-large-patch32-384'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)
model.to(device)
model.eval()
# Only the input pixels are optimised. Freezing the weights stops every backward
# pass from computing (and storing) a gradient for each of the model's parameters.
model.requires_grad_(False)
def get_encoder_activations(x):
    """Run the ViT encoder and return the hidden state at token position 0.

    Args:
        x: batch of preprocessed pixel values (presumably ``(batch, 3, H, W)``
            as produced by ``feature_extractor`` -- TODO confirm).

    Returns:
        The slice ``last_hidden_state[:, 0, :]`` -- one embedding per batch item,
        the vector ``process_image`` feeds into ``model.classifier``.
    """
    hidden_states = model.vit(x).last_hidden_state
    cls_position = 0  # first token of the sequence
    return hidden_states[:, cls_position, :]
def _random_target_vector(num_classes, n_targets, seed):
    """Return a float multi-hot vector of length ``num_classes`` with ``n_targets`` random ones."""
    # A private generator keeps the choice reproducible per seed without mutating
    # torch's global RNG, which is shared by concurrent Gradio requests.
    generator = torch.Generator().manual_seed(seed)
    targets = torch.zeros(num_classes)
    targets[torch.randperm(num_classes, generator=generator)[:n_targets]] = 1
    return targets


def _to_uint8_image(pixel_values):
    """Map a (1, 3, H, W) tensor normalised to [-1, 1] back to an (H, W, 3) uint8 array."""
    image = pixel_values.detach()[0].permute(1, 2, 0).cpu()
    # Round (not truncate) so float error cannot turn pixel value p into p - 1.
    image = (127.5 + image * 127.5).round().clamp(0, 255)
    return image.numpy().astype(np.uint8)


def process_image(input_image, learning_rate, iterations, n_targets, seed):
    """Optimise an image's pixels so the classifier fires on random target classes.

    ``n_targets`` classes are drawn reproducibly from ``seed`` and labelled 1;
    every other class is labelled 0. The preprocessed pixels are then updated by
    gradient descent on a summed binary cross-entropy between the model's logits
    and those labels. This pushes the chosen classes' logits up and the rest
    down. Pixels are clamped to [-1, 1] after each step.

    Args:
        input_image: PIL image from the Gradio input, or ``None``.
        learning_rate: step size applied to the raw pixel gradient.
        iterations: number of optimisation steps (cast to ``int``).
        n_targets: number of classes to maximise (cast to ``int``; values above
            the number of classes select every class).
        seed: seed for choosing the target classes (cast to ``int``).

    Returns:
        ``uint8`` array of shape ``(H, W, 3)`` at the processor's output size,
        or ``None`` when no image was supplied.
    """
    if input_image is None:
        return None

    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device).requires_grad_(True)

    targets = _random_target_vector(model.config.num_labels, int(n_targets), int(seed))
    targets = targets.to(device)
    loss_fn = BCEWithLogitsLoss(reduction='sum')

    for _ in range(int(iterations)):
        logits = model.classifier(get_encoder_activations(pixel_values)[0])
        loss = loss_fn(logits, targets)
        # Gradient w.r.t. the pixels only; model weights' .grad is never touched.
        (grad,) = torch.autograd.grad(loss, pixel_values)
        with torch.no_grad():
            # Descend the BCE loss: that raises logits labelled 1 (the targets)
            # and lowers those labelled 0. Ascending it would suppress the targets.
            pixel_values -= learning_rate * grad
            pixel_values.clamp_(-1, 1)

    return _to_uint8_image(pixel_values)
# Gradio UI: one image plus the optimisation hyper-parameters in, the dreamed image out.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Number(value=1.0, minimum=0, label="Learning Rate"),
        # precision=0 makes Gradio round these whole-number settings in the UI,
        # instead of letting int() in process_image silently truncate e.g. 2.7 -> 2.
        gr.Number(value=2, minimum=1, precision=0, label="Iterations"),
        gr.Number(
            value=250,
            minimum=1,
            maximum=1000,
            precision=0,
            label="Number of Random Target Class Activations to Maximise",
        ),
        gr.Number(value=420, minimum=0, precision=0, label="Seed"),
    ],
    outputs=[gr.Image(type="numpy", label="Dreamed Image")],
)
iface.launch()