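# photoguard / app.py  (Hugging Face Space by RamAnanth1; page avatar caption: "RamAnanth1's picture")
# Last commit: 5d424f4 "Update app.py"
# Page listing details: links "raw | history | blame", scan badge "No virus", file size 3.76 kB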
import gradio as gr
import os
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import numpy as np
import torch
import requests
from tqdm import tqdm
from diffusers import StableDiffusionImg2ImgPipeline
import torchvision.transforms as T
from utils import preprocess, recover_image
title = "Interactive demo: Raising the Cost of Malicious AI-Powered Image Editing"

# Tensor -> PIL.Image converter (torchvision ToPILImage).
to_pil = T.ToPILImage()

# Hub checkpoint for the img2img pipeline. Alternative Stable Diffusion v1
# checkpoints that can be substituted here:
#   "CompVis/stable-diffusion-v1-4"
#   "CompVis/stable-diffusion-v1-3"
#   "CompVis/stable-diffusion-v1-2"
#   "CompVis/stable-diffusion-v1-1"
model_id_or_path = "runwayml/stable-diffusion-v1-5"

# Load the pipeline once, at import time, from the checkpoint's "fp16"
# revision with float16 weights, and place it on the CUDA device.
# DiffusionPipeline.to() returns the pipeline itself, so the call is chained.
pipe_img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_id_or_path,
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")
def pgd(X, model, eps=0.1, step_size=0.015, iters=40, clamp_min=0, clamp_max=1, mask=None):
    """Run a PGD attack that shrinks the norm of an encoder's latent mean.

    Starting from ``X`` plus uniform noise in ``[-eps, eps]``, each iteration
    takes a signed-gradient step that decreases
    ``model(X_adv).latent_dist.mean.norm()``. It then projects the result back
    into the L-infinity ball of radius ``eps`` around ``X`` and clamps it to
    ``[clamp_min, clamp_max]``.

    Args:
        X: Tensor to perturb. The random start is created on ``X.device``.
        model: Callable whose return value exposes ``.latent_dist.mean``
            (for example a VAE ``encode`` method).
        eps: L-infinity radius of the allowed perturbation.
        step_size: Step size of the first iteration. It decays linearly with
            the iteration index towards ``step_size / 100``.
        iters: Number of PGD iterations (0 returns the noisy start).
        clamp_min: Lower bound of the valid value range.
        clamp_max: Upper bound of the valid value range.
        mask: Optional tensor multiplied in place into the adversarial tensor
            after every step.

    Returns:
        The adversarial tensor, detached from the autograd graph.
    """
    # Random start inside the eps-ball. It is created directly on X's device
    # rather than built on the CPU and copied to a hard-coded "cuda".
    noise = torch.rand(X.shape, device=X.device) * 2 * eps - eps
    X_adv = X.clone().detach() + noise
    pbar = tqdm(range(iters))
    for i in pbar:
        # Linear decay of the step size over the iterations.
        actual_step_size = step_size - (step_size - step_size / 100) / iters * i
        X_adv.requires_grad_(True)
        loss = (model(X_adv).latent_dist.mean).norm()
        pbar.set_description(f"[Running attack]: Loss {loss.item():.5f} | step size: {actual_step_size:.4}")
        grad, = torch.autograd.grad(loss, [X_adv])
        # Apply the step and the projections without recording autograd
        # history. Otherwise each iteration's ops chain onto the previous ones,
        # and the saved tensors of every iteration stay alive until the
        # returned tensor is released.
        with torch.no_grad():
            X_adv = X_adv - grad.sign() * actual_step_size
            X_adv = torch.minimum(torch.maximum(X_adv, X - eps), X + eps)
            X_adv = torch.clamp(X_adv, min=clamp_min, max=clamp_max)
            if mask is not None:
                X_adv *= mask
    return X_adv.detach()
def _img2img(prompt, image, seed, strength, guidance, num_steps):
    """Run one img2img generation with a dedicated, seeded CUDA generator.

    Using a private generator (instead of reseeding the global RNG) keeps the
    natural and adversarial runs on the same noise, even if other requests
    call the pipeline concurrently.
    """
    generator = torch.Generator(device="cuda").manual_seed(seed)
    with torch.autocast('cuda'):
        result = pipe_img2img(
            prompt=prompt,
            image=image,
            strength=strength,
            guidance_scale=guidance,
            num_inference_steps=num_steps,
            generator=generator,
        )
    return result.images[0]


def process_image(raw_image,prompt):
    """Immunize an image with PGD and compare img2img edits of both versions.

    Args:
        raw_image: PIL image from the Gradio input (``None`` if nothing was
            uploaded).
        prompt: Text prompt for the img2img pipeline.

    Returns:
        A list of four ``(image, caption)`` tuples for the Gallery output:
        the cropped source, the adversarial image, and the generations from
        each of them.

    Raises:
        ValueError: if no image was provided.
    """
    if raw_image is None:
        raise ValueError("Please upload a source image.")

    # Resize the shorter side to 512 px, then centre-crop to 512x512.
    resize = T.Resize(512)
    center_crop = T.CenterCrop(512)
    init_image = center_crop(resize(raw_image))

    with torch.autocast('cuda'):
        # NOTE(review): preprocess presumably yields a (1, 3, H, W) tensor in
        # [-1, 1] (consistent with the clamp range and the [0, 1] conversion
        # below); confirm in utils.
        X = preprocess(init_image).half().cuda()
        adv_X = pgd(X,
                    model=pipe_img2img.vae.encode,
                    clamp_min=-1,
                    clamp_max=1,
                    eps=0.06,  # Larger eps -> more visible perturbation
                    step_size=0.02,  # Keep smaller than eps
                    iters=100,  # More iterations -> stronger attack
                    )

    # Convert pixels back to the [0, 1] range. Detach first so no autograd
    # graph from the attack is kept alive while the diffusion pipelines run.
    adv_X = (adv_X.detach() / 2 + 0.5).clamp(0, 1)
    adv_image = to_pil(adv_X[0]).convert("RGB")

    # A good seed (use the commented line below to generate new images).
    SEED = 9222
    # SEED = np.random.randint(low=0, high=10000)
    # Tune these to improve generated image quality.
    STRENGTH = 0.5
    GUIDANCE = 7.5
    NUM_STEPS = 50

    # Same seed for both runs, so any difference comes from the perturbation.
    image_nat = _img2img(prompt, init_image, SEED, STRENGTH, GUIDANCE, NUM_STEPS)
    image_adv = _img2img(prompt, adv_image, SEED, STRENGTH, GUIDANCE, NUM_STEPS)
    return [(init_image,"Source Image"), (adv_image, "Adv Image"), (image_nat,"Gen. Image Nat"), (image_adv, "Gen. Image Adv")]
# Input widgets: an image (delivered to process_image as a PIL image) and a
# free-text prompt. They are created before the gallery, as in the original
# argument order.
_inputs = [gr.Image(type="pil"), gr.Textbox(label="Prompt")]

# A single gallery shows the four captioned images returned by process_image,
# laid out in a 2-column grid.
_gallery = gr.Gallery(
    label="Generated images",
    show_label=False,
    elem_id="gallery",
).style(grid=[2], height="auto")

interface = gr.Interface(
    fn=process_image,
    inputs=_inputs,
    outputs=[_gallery],
    title=title,
)
interface.launch(debug=True)