# Spaces: Running on T4
import gradio as gr | |
from PIL import Image | |
import torch | |
import re | |
import os | |
import requests | |
from customization import customize_vae_decoder | |
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DDIMScheduler, EulerDiscreteScheduler | |
from torchvision import transforms | |
from attribution import MappingNetwork | |
import math | |
from typing import List | |
from PIL import Image | |
import numpy as np | |
import torch | |
# Directory holding the fine-tuned checkpoints (vae_decoder.pth,
# mapping_network.pth, decoding_network.pth) that AttributionModel loads.
# Relative path: resolved against the current working directory at load time.
PRETRAINED_MODEL_NAME_OR_PATH: str = "./checkpoints/"
def get_image_grid(images: List[Image.Image], cols: int = 3) -> Image.Image:
    """Paste equally sized images into one grid image.

    Images are placed left-to-right, top-to-bottom, ``cols`` per row. The
    canvas always spans ``cols`` columns, so when there are fewer images than
    columns the remaining cells stay black. The default of 3 columns keeps
    the original fixed three-column layout.

    Args:
        images: Non-empty list of PIL images. The first image's size is used
            as the cell size for every image.
        cols: Number of grid columns (must be >= 1).

    Returns:
        A new RGB image of size ``(cols * width, rows * height)``, where
        ``rows = ceil(len(images) / cols)``.

    Raises:
        ValueError: If ``images`` is empty or ``cols`` < 1.
    """
    if not images:
        raise ValueError("get_image_grid() requires at least one image")
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")

    # Derive the row count from the number of images so that images beyond
    # the first row are not pasted outside the canvas and silently dropped.
    rows = math.ceil(len(images) / cols)
    width, height = images[0].size
    grid_image = Image.new("RGB", (cols * width, rows * height))
    for i, img in enumerate(images):
        row, col = divmod(i, cols)
        grid_image.paste(img, (col * width, row * height))
    return grid_image
class AttributionModel:
    """Stable Diffusion 2 text-to-image with a key-conditioned (attributing) VAE decoder.

    Attributes set in ``__init__``:
        pipe: ``StableDiffusionPipeline`` for 'stabilityai/stable-diffusion-2'.
            Its own VAE produces the un-attributed decoding.
        vae: A second copy of the SD2 VAE, rewritten by ``customize_vae_decoder``
            and loaded from ``vae_decoder.pth``. Its ``decode`` additionally takes
            the mapping network's output.
        mapping_network: ``MappingNetwork`` that receives 32-element keys
            (see ``get_phis(32, 1)``).
        decoding_network: ResNet-50 whose final layer outputs 32 values. It is
            used by ``detect_key`` to estimate a key from an image.
    """

    def __init__(self):
        is_cuda = torch.cuda.is_available()

        self.pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2')
        if is_cuda:
            self.pipe = self.pipe.to("cuda")
        self.resize_transform = transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR)

        # Separate VAE copy whose decoder is customized to accept a key embedding.
        self.vae = AutoencoderKL.from_pretrained(
            'stabilityai/stable-diffusion-2', subfolder="vae"
        )
        self.vae = customize_vae_decoder(self.vae, 128, "qkv", "all", False, 1.0)
        # NOTE(review): the positional args (32, 0, 128, None) presumably are
        # (key dim, conditioning dim, embedding dim, num_ws) -- confirm in attribution.py.
        self.mapping_network = MappingNetwork(32, 0, 128, None, num_layers=2, w_avg_beta=None, normalization=False)

        from torchvision.models import resnet50
        # weights=None: every parameter and buffer is replaced by the strict
        # load_state_dict below, so downloading ImageNet weights first is wasted work.
        self.decoding_network = resnet50(weights=None)
        self.decoding_network.fc = torch.nn.Linear(2048, 32)

        # The modules are still on the CPU here. Loading onto the CPU also lets
        # checkpoints that were saved from GPU tensors load on CUDA-less hosts.
        self.vae.decoder.load_state_dict(self._load_checkpoint('vae_decoder.pth'))
        self.mapping_network.load_state_dict(self._load_checkpoint('mapping_network.pth'))
        self.decoding_network.load_state_dict(self._load_checkpoint('decoding_network.pth'))

        # Inference only. torchvision models are constructed in train mode, where
        # BatchNorm would use per-batch statistics instead of the learned running ones.
        self.vae.eval()
        self.mapping_network.eval()
        self.decoding_network.eval()

        if is_cuda:
            self.vae = self.vae.to("cuda")
            self.mapping_network = self.mapping_network.to("cuda")
            self.decoding_network = self.decoding_network.to("cuda")

        # ImageNet mean/std normalization applied before the ResNet-50 key decoder.
        self.test_norm = transforms.Compose(
            [
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

    def _load_checkpoint(self, filename):
        """Load a state dict from PRETRAINED_MODEL_NAME_OR_PATH onto the CPU."""
        path = os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, filename)
        return torch.load(path, map_location="cpu")

    def infer(self, prompt, negative, steps, guidance_scale):
        """Generate one image and decode its latents with and without attribution.

        Args:
            prompt: Text prompt.
            negative: Negative prompt.
            steps: Number of denoising steps. It is cast to int because UI
                widgets may deliver it as a float.
            guidance_scale: Classifier-free guidance scale.

        Returns:
            ``(original, attributed, difference)`` PIL images. ``difference`` is
            the per-pixel absolute difference between the two decodings.
        """
        with torch.no_grad():
            out_latents = self.pipe(
                [prompt],
                negative_prompt=[negative],
                output_type="latent",
                num_inference_steps=int(steps),
                guidance_scale=guidance_scale,
            ).images
        image_attr = self.inference_with_attribution(out_latents)
        image_org = self.inference_without_attribution(out_latents)
        # Both decodings lie in [0, 1]. A signed difference lies in [-1, 1], and its
        # negative entries do not survive an 8-bit image conversion, so show the
        # magnitude instead.
        image_diff = self._pixel_difference(image_attr[0], image_org[0])
        image_org_pil = self.pipe.numpy_to_pil(image_org[0])
        image_attr_pil = self.pipe.numpy_to_pil(image_attr[0])
        image_diff_pil = self.pipe.numpy_to_pil(image_diff)
        return image_org_pil[0], image_attr_pil[0], image_diff_pil[0]

    @staticmethod
    def _pixel_difference(image_a, image_b):
        """Return ``|image_a - image_b|`` element-wise; stays in [0, 1] for [0, 1] inputs."""
        return np.abs(image_a - image_b)

    @staticmethod
    def _to_numpy_image(image):
        """Map decoder output from [-1, 1] to [0, 1] and return a channels-last float32 array.

        Assumes ``image`` is 4-D with channels on axis 1 (NCHW); ``permute(0, 2, 3, 1)``
        moves the channels last.
        """
        image = image.clamp(-1, 1)
        image = (image / 2 + 0.5).clamp(0, 1)
        return image.cpu().permute(0, 2, 3, 1).float().numpy()

    def inference_without_attribution(self, latents):
        """Decode latents with the pipeline's stock VAE (no key).

        Returns a float32 NumPy array, channels last, with values in [0, 1].
        """
        # Undo the latent scaling (0.18215 -- presumably the SD VAE scaling factor).
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = self.pipe.vae.decode(latents).sample
        return self._to_numpy_image(image)

    def get_phis(self, phi_dimension, batch_size, eps=1e-8):
        """Sample random binary keys.

        Returns a CPU float tensor of shape ``(batch_size, phi_dimension)``.
        Each entry is 0 or 1 (each with marginal probability 1/2), plus ``eps``.
        """
        phi = torch.empty(batch_size, phi_dimension).uniform_(0, 1)
        return torch.bernoulli(phi) + eps

    def _prepare_key(self, key):
        """Return ``key`` (or a fresh random 1x32 key when None) on the mapping network's device."""
        if key is None:
            key = self.get_phis(32, 1)
        param = next(self.mapping_network.parameters(), None)
        # Follow the mapping network's device instead of assuming CUDA, so CPU-only hosts work.
        return key if param is None else key.to(param.device)

    def inference_with_attribution(self, latents, key=None):
        """Decode latents with the key-conditioned VAE.

        Args:
            latents: Latents as returned by ``self.pipe(..., output_type="latent")``.
            key: Optional key tensor. When None, a random key from
                ``get_phis(32, 1)`` is used. Presumably (batch, 32) -- confirm.

        Returns:
            A float32 NumPy array, channels last, with values in [0, 1].
        """
        key = self._prepare_key(key)
        latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = self.vae.decode(latents, self.mapping_network(key)).sample
        return self._to_numpy_image(image)

    def postprocess(self, image):
        """Resize ``image`` so its shorter side is 512 px (bilinear)."""
        image = self.resize_transform(image)
        return image

    def detect_key(self, image):
        """Estimate the embedded key from an image.

        Assumes ``image`` is a tensor in [-1, 1] on the decoding network's device.
        It is mapped to [0, 1], ImageNet-normalized and passed through the
        ResNet-50, which outputs 32 values per image (raw network outputs;
        presumably logits -- confirm).
        """
        reconstructed_keys = self.decoding_network(self.test_norm((image / 2 + 0.5).clamp(0, 1)))
        return reconstructed_keys
# Single shared model instance, built at import time. Construction loads every
# model and checkpoint (and moves them to the GPU when CUDA is available)
# before the UI is defined.
attribution_model = AttributionModel()
def get_images(prompt, negative, steps, guidence_scale):
    """Gradio callback that runs the shared model for one prompt.

    Args:
        prompt: Text prompt.
        negative: Negative prompt.
        steps: Number of denoising steps.
        guidence_scale: Classifier-free guidance scale.

    Returns:
        A list ``[original, attributed, difference]`` of PIL images.
    """
    original, attributed, difference = attribution_model.infer(
        prompt, negative, steps, guidence_scale
    )
    return [original, attributed, difference]
# Example rows for gr.Examples, in input order:
# [prompt, negative prompt, steps, guidance scale].
image_examples = [
    [
        "A pikachu fine dining with a view to the Eiffel Tower",
        "low quality",
        50,
        10,
    ],
    [
        "A mecha robot in a favela in expressionist style",
        "low quality, 3d, photorealistic",
        50,
        10,
    ],
]
# Gradio UI: prompt inputs, generation controls, the three output panels
# (without attribution / with attribution / difference), examples and notes.
with gr.Blocks() as demo:
    # NOTE(review): the "Project Page" href points to the Composable Diffusion
    # project page (and is split across two lines of the string), not a WOUAF
    # page -- confirm and replace with the intended URL.
    gr.Markdown(
        """<h1 style="text-align: center;"><b>WOUAF:
Weight Modulation for User Attribution and Fingerprinting in Text-to-Image Diffusion Models</b> <br> <a href="https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion
-Models/">Project Page</a></h1>""")
    # Continuation lines of the Markdown strings stay at column 0; indented lines
    # could be rendered as Markdown code blocks.
    gr.Markdown(
        """<h3>Demo: Text-to-Image (Stable diffusion 2) with random user attribution</h3>
WOUAF can be applied to other applications such as In-painting, Image-editing, Image Super-Resolution etc.
<br>More details at: <a href="https://arxiv.org/abs/2306.04744">Paper</a>
"""
    )

    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
        with gr.Column():
            text = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                elem_id="prompt-text-input",
            ).style(
                border=(True, False, True, True),
                rounded=(True, False, False, True),
                container=False,
            )
            negative = gr.Textbox(
                label="Enter your negative prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter a negative prompt",
                elem_id="negative-prompt-text-input",
            ).style(
                border=(True, False, True, True),
                rounded=(True, False, False, True),
                container=False,
            )

    with gr.Row():
        steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
        guidance_scale = gr.Slider(
            label="Guidance Scale", minimum=0, maximum=10, value=9, step=0.1
        )

    with gr.Row():
        btn = gr.Button(value="Generate Image", full_width=False)

    with gr.Row():
        im_2 = gr.Image(type="pil", label="without attribution")
        im_3 = gr.Image(type="pil", label="**with** attribution")
        im_4 = gr.Image(type="pil", label="pixel-wise difference")

    btn.click(get_images, inputs=[text, negative, steps, guidance_scale], outputs=[im_2, im_3, im_4])

    # cache_examples=True runs get_images on every example when the app is built.
    gr.Examples(
        examples=image_examples,
        inputs=[text, negative, steps, guidance_scale],
        outputs=[im_2, im_3, im_4],
        fn=get_images,
        cache_examples=True,
    )

    gr.HTML(
        """
<div class="footer">
<p>Pre-trained model by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">StabilityAI</a>
</p>
<p>
Fine-tuned by authors for research purpose.
</p>
</div>
"""
    )

    # The app loads 'stabilityai/stable-diffusion-2', so the bias note names V2
    # (not V2.1). Headings sit outside <p>, since <h4> is not allowed inside a paragraph.
    with gr.Accordion(label="Ethics & Privacy", open=False):
        gr.HTML(
            """<div class="acknowledgments">
<h4>Privacy</h4>
<p>We do not collect any images or key data. This demo is designed with sole purpose of fun and reducing misuse of AI.</p>
<h4>Biases and content acknowledgment</h4>
<p>This model will have the same biases as Stable Diffusion V2</p></div>
"""
        )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()