Maitreyapatel committed
Commit de248ee
1 Parent(s): 084a6a2

updated demo announcement

Files changed (2)
  1. app.py +13 -189
  2. app_demo.py +214 -0
app.py CHANGED
@@ -17,198 +17,22 @@ from PIL import Image, ImageChops
 import numpy as np
 import torch
 
-
-PRETRAINED_MODEL_NAME_OR_PATH = "./checkpoints/"
-
-
-def get_image_grid(images: List[Image.Image]) -> Image:
-    num_images = len(images)
-    cols = 3#int(math.ceil(math.sqrt(num_images)))
-    rows = 1#int(math.ceil(num_images / cols))
-    width, height = images[0].size
-    grid_image = Image.new('RGB', (cols * width, rows * height))
-    for i, img in enumerate(images):
-        x = i % cols
-        y = i // cols
-        grid_image.paste(img, (x * width, y * height))
-    return grid_image
-
-
-class AttributionModel:
-    def __init__(self):
-        is_cuda = False
-        if torch.cuda.is_available():
-            is_cuda = True
-
-        scheduler = EulerDiscreteScheduler.from_pretrained('stabilityai/stable-diffusion-2', subfolder="scheduler")
-        self.pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2', scheduler=scheduler)#, safety_checker=None, torch_dtype=torch.float16)
-        if is_cuda:
-            self.pipe = self.pipe.to("cuda")
-        self.resize_transform = transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR)
-        self.vae = AutoencoderKL.from_pretrained(
-            'stabilityai/stable-diffusion-2', subfolder="vae"
-        )
-        self.vae = customize_vae_decoder(self.vae, 128, "deqkv", "all", False, 1.0)
-
-        self.mapping_network = MappingNetwork(32, 0, 128, None, num_layers=2, w_avg_beta=None, normalization = False)
-
-        from torchvision.models import resnet50, ResNet50_Weights
-        self.decoding_network = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
-        self.decoding_network.fc = torch.nn.Linear(2048,32)
-
-        self.vae.decoder.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'vae_decoder.pth')))
-        self.mapping_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'mapping_network.pth')))
-        self.decoding_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'decoding_network.pth')))
-
-        if is_cuda:
-            self.vae = self.vae.to("cuda")
-            self.mapping_network = self.mapping_network.to("cuda")
-            self.decoding_network = self.decoding_network.to("cuda")
-
-        self.test_norm = transforms.Compose(
-            [
-                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
-            ]
-        )
-
-    def infer(self, prompt, negative, steps, guidance_scale):
-        with torch.no_grad():
-            out_latents = self.pipe([prompt], negative_prompt=[negative], output_type="latent", num_inference_steps=steps, guidance_scale=guidance_scale).images
-            image_attr = self.inference_with_attribution(out_latents)
-            image_attr_pil = self.pipe.numpy_to_pil(image_attr[0])
-
-            image_org = self.inference_without_attribution(out_latents)
-            image_org_pil = self.pipe.numpy_to_pil(image_org[0])
-
-            # image_diff_pil = self.pipe.numpy_to_pil(image_attr[0] - image_org[0])
-            diff_factor = 5
-            image_diff_pil = ImageChops.difference(image_org_pil[0], image_attr_pil[0]).convert("RGB", (diff_factor,0,0,0,0,diff_factor,0,0,0,0,diff_factor,0))
-
-            return image_org_pil[0], image_attr_pil[0], image_diff_pil
-
-    def inference_without_attribution(self, latents):
-        latents = 1 / 0.18215 * latents
-        with torch.no_grad():
-            image = self.pipe.vae.decode(latents).sample
-        image = image.clamp(-1,1)
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
-    def get_phis(self, phi_dimension, batch_size ,eps = 1e-8):
-        phi_length = phi_dimension
-        b = batch_size
-        phi = torch.empty(b,phi_length).uniform_(0,1)
-        return torch.bernoulli(phi) + eps
-
-
-    def inference_with_attribution(self, latents, key=None):
-        if key==None:
-            key = self.get_phis(32, 1)
-
-        latents = 1 / 0.18215 * latents
-        with torch.no_grad():
-            image = self.vae.decode(latents, self.mapping_network(key.cuda())).sample
-        image = image.clamp(-1,1)
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
-    def postprocess(self, image):
-        image = self.resize_transform(image)
-        return image
-
-    def detect_key(self, image):
-        reconstructed_keys = self.decoding_network(self.test_norm((image / 2 + 0.5).clamp(0, 1)))
-        return reconstructed_keys
-
-
-attribution_model = AttributionModel()
-def get_images(prompt, negative, steps, guidence_scale):
-    x1, x2, x3 = attribution_model.infer(prompt, negative, steps, guidence_scale)
-    return [x1, x2, x3]
-
-
-image_examples = [
-    ["A pikachu fine dining with a view to the Eiffel Tower", "low quality", 50, 10],
-    ["A mecha robot in a favela in expressionist style", "low quality, 3d, photorealistic", 50, 10]
-]
-
 with gr.Blocks() as demo:
     gr.Markdown(
-        """<h1 style="text-align: center;"><b>WOUAF:
-        Weight Modulation for User Attribution and Fingerprinting in Text-to-Image Diffusion Models</b> <br> <a href="https://wouaf.vercel.app">Project Page</a></h1>""")
-
-    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
-        with gr.Column():
-            text = gr.Textbox(
-                label="Enter your prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                elem_id="prompt-text-input",
-            ).style(
-                border=(True, False, True, True),
-                rounded=(True, False, False, True),
-                container=False,
-            )
-            negative = gr.Textbox(
-                label="Enter your negative prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                elem_id="negative-prompt-text-input",
-            ).style(
-                border=(True, False, True, True),
-                rounded=(True, False, False, True),
-                container=False,
-            )
-
-    with gr.Row():
-        steps = gr.Slider(label="Steps", minimum=45, maximum=55, value=50, step=1)
-        guidance_scale = gr.Slider(
-            label="Guidance Scale", minimum=0, maximum=10, value=7.5, step=0.1
-        )
-
-    with gr.Row():
-        btn = gr.Button(value="Generate Image", full_width=False)
-
-    with gr.Row():
-        im_2 = gr.Image(type="pil", label="without attribution")
-        im_3 = gr.Image(type="pil", label="**with** attribution")
-        im_4 = gr.Image(type="pil", label="pixel-wise difference multiplied by 5")
-
-
-    btn.click(get_images, inputs=[text, negative, steps, guidance_scale], outputs=[im_2, im_3, im_4])
-
-    gr.Examples(
-        examples=image_examples,
-        inputs=[text, negative, steps, guidance_scale],
-        outputs=[im_2, im_3, im_4],
-        fn=get_images,
-        cache_examples=True,
-    )
+        """<div style="transform: translate(0, 50%);">
+        <h1 style="text-align: center;"><b>WOUAF:
+        Weight Modulation for User Attribution and Fingerprinting in Text-to-Image Diffusion Models</b> <br> <a href="https://wouaf.vercel.app">Project Page</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; <a href="https://huggingface.co/spaces/wouaf/WOUAF-Text-to-Image">New Demo</a></h1>
+        <br>
+        <br>
+        <br>
+        <br>
+        <br>
+        <br>
+        <h1 style="text-align: center;"> With generous support from Intel, we have <a href="https://huggingface.co/spaces/wouaf/WOUAF-Text-to-Image">transferred the demo</a> to a better and faster GPU. </h1>
+        </div>
+        """
+    )
 
-    gr.HTML(
-        """
-        <div class="footer">
-            <p>Pre-trained model by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">StabilityAI</a>
-            </p>
-            <p>
-            Fine-tuned by authors for research purpose.
-            </p>
-        </div>
-        """
-    )
-    with gr.Accordion(label="Ethics & Privacy", open=False):
-        gr.HTML(
-            """<div class="acknowledgments">
-            <p><h4>Privacy</h4>
-            We do not collect any images or key data. This demo is designed with sole purpose of fun and reducing misuse of AI.
-            <p><h4>Biases and content acknowledgment</h4>
-            This model will have the same biases as Stable Diffusion V2.1 </div>
-            """
-        )
 
 if __name__ == "__main__":
     demo.launch()
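The demo code removed above (and re-added in app_demo.py below) visualizes the fingerprint by amplifying the pixel-wise difference between the plain decode and the key-modulated decode: ImageChops.difference takes the per-pixel absolute difference, and the 12-tuple passed to convert("RGB", ...) is a per-channel linear map that scales R, G, and B by diff_factor = 5. A minimal standalone sketch of that step, with hypothetical file names standing in for two aligned renders of the same latents:

from PIL import Image, ImageChops

# Hypothetical inputs: the same generation decoded without and with the fingerprint.
original = Image.open("without_attribution.png").convert("RGB")
fingerprinted = Image.open("with_attribution.png").convert("RGB")

# Per-pixel absolute difference; nearly black on its own, since the fingerprint is imperceptible.
diff = ImageChops.difference(original, fingerprinted)

# The 12-tuple is a per-channel affine map: (R', G', B') = (5R, 5G, 5B), i.e. a x5 gain.
diff_factor = 5
amplified = diff.convert(
    "RGB",
    (diff_factor, 0, 0, 0,
     0, diff_factor, 0, 0,
     0, 0, diff_factor, 0),
)
amplified.save("difference_x5.png")

This gain is why the third output panel is labelled "pixel-wise difference multiplied by 5".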
app_demo.py ADDED
@@ -0,0 +1,214 @@
+import gradio as gr
+from PIL import Image
+
+import torch
+import re
+import os
+import requests
+
+from customization import customize_vae_decoder
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DDIMScheduler, EulerDiscreteScheduler
+from torchvision import transforms
+from attribution import MappingNetwork
+
+import math
+from typing import List
+from PIL import Image, ImageChops
+import numpy as np
+import torch
+
+
+PRETRAINED_MODEL_NAME_OR_PATH = "./checkpoints/"
+
+
+def get_image_grid(images: List[Image.Image]) -> Image:
+    num_images = len(images)
+    cols = 3#int(math.ceil(math.sqrt(num_images)))
+    rows = 1#int(math.ceil(num_images / cols))
+    width, height = images[0].size
+    grid_image = Image.new('RGB', (cols * width, rows * height))
+    for i, img in enumerate(images):
+        x = i % cols
+        y = i // cols
+        grid_image.paste(img, (x * width, y * height))
+    return grid_image
+
+
+class AttributionModel:
+    def __init__(self):
+        is_cuda = False
+        if torch.cuda.is_available():
+            is_cuda = True
+
+        scheduler = EulerDiscreteScheduler.from_pretrained('stabilityai/stable-diffusion-2', subfolder="scheduler")
+        self.pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2', scheduler=scheduler)#, safety_checker=None, torch_dtype=torch.float16)
+        if is_cuda:
+            self.pipe = self.pipe.to("cuda")
+        self.resize_transform = transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR)
+        self.vae = AutoencoderKL.from_pretrained(
+            'stabilityai/stable-diffusion-2', subfolder="vae"
+        )
+        self.vae = customize_vae_decoder(self.vae, 128, "deqkv", "all", False, 1.0)
+
+        self.mapping_network = MappingNetwork(32, 0, 128, None, num_layers=2, w_avg_beta=None, normalization = False)
+
+        from torchvision.models import resnet50, ResNet50_Weights
+        self.decoding_network = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
+        self.decoding_network.fc = torch.nn.Linear(2048,32)
+
+        self.vae.decoder.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'vae_decoder.pth')))
+        self.mapping_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'mapping_network.pth')))
+        self.decoding_network.load_state_dict(torch.load(os.path.join(PRETRAINED_MODEL_NAME_OR_PATH, 'decoding_network.pth')))
+
+        if is_cuda:
+            self.vae = self.vae.to("cuda")
+            self.mapping_network = self.mapping_network.to("cuda")
+            self.decoding_network = self.decoding_network.to("cuda")
+
+        self.test_norm = transforms.Compose(
+            [
+                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+            ]
+        )
+
+    def infer(self, prompt, negative, steps, guidance_scale):
+        with torch.no_grad():
+            out_latents = self.pipe([prompt], negative_prompt=[negative], output_type="latent", num_inference_steps=steps, guidance_scale=guidance_scale).images
+            image_attr = self.inference_with_attribution(out_latents)
+            image_attr_pil = self.pipe.numpy_to_pil(image_attr[0])
+
+            image_org = self.inference_without_attribution(out_latents)
+            image_org_pil = self.pipe.numpy_to_pil(image_org[0])
+
+            # image_diff_pil = self.pipe.numpy_to_pil(image_attr[0] - image_org[0])
+            diff_factor = 5
+            image_diff_pil = ImageChops.difference(image_org_pil[0], image_attr_pil[0]).convert("RGB", (diff_factor,0,0,0,0,diff_factor,0,0,0,0,diff_factor,0))
+
+            return image_org_pil[0], image_attr_pil[0], image_diff_pil
+
+    def inference_without_attribution(self, latents):
+        latents = 1 / 0.18215 * latents
+        with torch.no_grad():
+            image = self.pipe.vae.decode(latents).sample
+        image = image.clamp(-1,1)
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    def get_phis(self, phi_dimension, batch_size ,eps = 1e-8):
+        phi_length = phi_dimension
+        b = batch_size
+        phi = torch.empty(b,phi_length).uniform_(0,1)
+        return torch.bernoulli(phi) + eps
+
+
+    def inference_with_attribution(self, latents, key=None):
+        if key==None:
+            key = self.get_phis(32, 1)
+
+        latents = 1 / 0.18215 * latents
+        with torch.no_grad():
+            image = self.vae.decode(latents, self.mapping_network(key.cuda())).sample
+        image = image.clamp(-1,1)
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    def postprocess(self, image):
+        image = self.resize_transform(image)
+        return image
+
+    def detect_key(self, image):
+        reconstructed_keys = self.decoding_network(self.test_norm((image / 2 + 0.5).clamp(0, 1)))
+        return reconstructed_keys
+
+
+attribution_model = AttributionModel()
+def get_images(prompt, negative, steps, guidence_scale):
+    x1, x2, x3 = attribution_model.infer(prompt, negative, steps, guidence_scale)
+    return [x1, x2, x3]
+
+
+image_examples = [
+    ["A pikachu fine dining with a view to the Eiffel Tower", "low quality", 50, 10],
+    ["A mecha robot in a favela in expressionist style", "low quality, 3d, photorealistic", 50, 10]
+]
+
+with gr.Blocks() as demo:
+    gr.Markdown(
+        """<h1 style="text-align: center;"><b>WOUAF:
+        Weight Modulation for User Attribution and Fingerprinting in Text-to-Image Diffusion Models</b> <br> <a href="https://wouaf.vercel.app">Project Page</a></h1>""")
+
+    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
+        with gr.Column():
+            text = gr.Textbox(
+                label="Enter your prompt",
+                show_label=False,
+                max_lines=1,
+                placeholder="Enter your prompt",
+                elem_id="prompt-text-input",
+            ).style(
+                border=(True, False, True, True),
+                rounded=(True, False, False, True),
+                container=False,
+            )
+            negative = gr.Textbox(
+                label="Enter your negative prompt",
+                show_label=False,
+                max_lines=1,
+                placeholder="Enter a negative prompt",
+                elem_id="negative-prompt-text-input",
+            ).style(
+                border=(True, False, True, True),
+                rounded=(True, False, False, True),
+                container=False,
+            )
+
+    with gr.Row():
+        steps = gr.Slider(label="Steps", minimum=45, maximum=55, value=50, step=1)
+        guidance_scale = gr.Slider(
+            label="Guidance Scale", minimum=0, maximum=10, value=7.5, step=0.1
+        )
+
+    with gr.Row():
+        btn = gr.Button(value="Generate Image", full_width=False)
+
+    with gr.Row():
+        im_2 = gr.Image(type="pil", label="without attribution")
+        im_3 = gr.Image(type="pil", label="**with** attribution")
+        im_4 = gr.Image(type="pil", label="pixel-wise difference multiplied by 5")
+
+
+    btn.click(get_images, inputs=[text, negative, steps, guidance_scale], outputs=[im_2, im_3, im_4])
+
+    gr.Examples(
+        examples=image_examples,
+        inputs=[text, negative, steps, guidance_scale],
+        outputs=[im_2, im_3, im_4],
+        fn=get_images,
+        cache_examples=True,
+    )
+
+    gr.HTML(
+        """
+        <div class="footer">
+            <p>Pre-trained model by <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">StabilityAI</a>
+            </p>
+            <p>
+            Fine-tuned by authors for research purpose.
+            </p>
+        </div>
+        """
+    )
+    with gr.Accordion(label="Ethics & Privacy", open=False):
+        gr.HTML(
+            """<div class="acknowledgments">
+            <p><h4>Privacy</h4>
+            We do not collect any images or key data. This demo is designed with sole purpose of fun and reducing misuse of AI.
+            <p><h4>Biases and content acknowledgment</h4>
+            This model will have the same biases as Stable Diffusion V2.1 </div>
+            """
+        )
+
+if __name__ == "__main__":
+    demo.launch()
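The Gradio UI in app_demo.py only exposes the generation path; the key-recovery path (get_phis, inference_with_attribution, detect_key) is defined but never wired into the interface. A minimal sketch of that round-trip, assuming the ./checkpoints/ weights and a CUDA device are available; the helper name, the way latents are obtained, and the zero-threshold on the decoder's logits are illustrative assumptions, not something app_demo.py itself does:

import torch

from app_demo import attribution_model  # importing app_demo builds the model and loads the checkpoints


def attribution_round_trip(latents: torch.Tensor) -> float:
    """Embed a fresh 32-bit key while decoding `latents`, then try to read it back.

    `latents` are Stable Diffusion latents on the GPU, e.g. from
    attribution_model.pipe(prompt, output_type="latent").images.
    Returns the fraction of key bits recovered by the decoding network.
    """
    key = attribution_model.get_phis(32, 1)                              # Bernoulli(0.5) bits, shape (1, 32)
    image = attribution_model.inference_with_attribution(latents, key)   # numpy array in [0, 1], NHWC
    # detect_key() expects an NCHW tensor in [-1, 1] on the same device as the decoding network.
    image_t = torch.from_numpy(image).permute(0, 3, 1, 2) * 2 - 1
    with torch.no_grad():
        logits = attribution_model.detect_key(image_t.cuda())            # (1, 32) logits
    recovered = (logits > 0).float().cpu()                               # assumed bit-decision rule
    return (recovered == key.round()).float().mean().item()

With 32 Bernoulli-sampled bits the key space has 2^32 entries; note that the demo's infer() samples a fresh key for every generation rather than assigning a fixed key per user.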