Linoy Tsaban committed on
Commit
c633c03
1 Parent(s): 9cd1904

Update app.py


caption image with BLIP
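The change registers a BLIP captioning model alongside the existing Stable Diffusion pipelines and wires a new "Caption Image" button that writes the generated caption into the target-prompt textbox. For reference, a minimal standalone sketch of the same captioning flow, assuming the checkpoint used in the diff; the image path "example.jpg" is only a placeholder:

import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# the processor is a plain preprocessing object; only the model is a torch module
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

image = Image.open("example.jpg")  # placeholder input image
inputs = blip_processor(images=image, return_tensors="pt").to(device)
generated_ids = blip_model.generate(pixel_values=inputs.pixel_values, max_length=50)
caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(caption)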

Files changed (1)
  1. app.py +91 -78
app.py CHANGED
@@ -4,16 +4,37 @@ import numpy as np
 import requests
 import random
 from io import BytesIO
-from diffusers import StableDiffusionPipeline
-from diffusers import DDIMScheduler
 from utils import *
 from inversion_utils import *
 from modified_pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
 from torch import autocast, inference_mode
-import re
+from diffusers import StableDiffusionPipeline
+from diffusers import DDIMScheduler
+from transformers import AutoProcessor, BlipForConditionalGeneration
+
+# load pipelines
+sd_model_id = "stabilityai/stable-diffusion-2-base"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
+sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
+sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
+# the BLIP processor is not a torch module, so only the model is moved to the device
+blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
 
 
+## IMAGE CAPTIONING ##
+def caption_image(input_image):
+    # preprocess the input image and move the pixel values to the model's device
+    inputs = blip_processor(images=input_image, return_tensors="pt").to(device)
+    pixel_values = inputs.pixel_values
+
+    generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
+    generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    return generated_caption
+
+
+## DDPM INVERSION AND SAMPLING ##
 def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
 
     # inverts a real image according to Algorithm 1 in https://arxiv.org/pdf/2304.06140.pdf,
@@ -35,7 +56,6 @@ def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta
     return zs, wts
 
 
-
 def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
 
     # reverse process (via Zs and wT)
@@ -49,85 +69,13 @@ def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
     img = image_grid(x0_dec)
     return img
 
-# load pipelines
-sd_model_id = "stabilityai/stable-diffusion-2-base"
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
-sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
-sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
-
-
-def get_example():
-    case = [
-        [
-            'examples/source_a_cat_sitting_next_to_a_mirror.jpeg',
-            'a cat sitting next to a mirror',
-            'watercolor painting of a cat sitting next to a mirror',
-            100,
-            36,
-            15,
-            'Schnauzer dog', 'cat',
-            5.5,
-            1,
-            'examples/ddpm_sega_watercolor_painting_a_cat_sitting_next_to_a_mirror_plus_dog_minus_cat.png'
-        ],
-        [
-            'examples/source_a_man_wearing_a_brown_hoodie_in_a_crowded_street.jpeg',
-            'a man wearing a brown hoodie in a crowded street',
-            'a robot wearing a brown hoodie in a crowded street',
-            100,
-            36,
-            15,
-            'painting','',
-            10,
-            1,
-            'examples/ddpm_sega_painting_of_a_robot_wearing_a_brown_hoodie_in_a_crowded_street.png'
-        ],
-        [
-            'examples/source_wall_with_framed_photos.jpeg',
-            '',
-            '',
-            100,
-            36,
-            15,
-            'pink drawings of muffins','',
-            10,
-            1,
-            'examples/ddpm_sega_plus_pink_drawings_of_muffins.png'
-        ],
-        [
-            'examples/source_an_empty_room_with_concrete_walls.jpg',
-            'an empty room with concrete walls',
-            'glass walls',
-            100,
-            36,
-            17,
-            'giant elephant','',
-            10,
-            1,
-            'examples/ddpm_sega_glass_walls_gian_elephant.png'
-        ]]
-    return case
-
-def randomize_seed_fn(seed, randomize_seed):
-    if randomize_seed:
-        seed = random.randint(0, np.iinfo(np.int32).max)
-    torch.manual_seed(seed)
-    return seed
-
-
-
 
 def reconstruct(tar_prompt,
                 tar_cfg_scale,
                 skip,
                 wts, zs,
-                # do_reconstruction,
-                # reconstruction
                 ):
 
-
-    # if do_reconstruction:
     reconstruction = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
     return reconstruction
 
@@ -158,6 +106,7 @@ def load_and_invert(
 
     return wts, zs, do_inversion
 
+## SEGA ##
 
 def edit(input_image,
          wts, zs,
@@ -197,6 +146,66 @@ def edit(input_image,
 
 
 
+def randomize_seed_fn(seed, randomize_seed):
+    if randomize_seed:
+        seed = random.randint(0, np.iinfo(np.int32).max)
+    torch.manual_seed(seed)
+    return seed
+
+
+def get_example():
+    case = [
+        [
+            'examples/source_a_cat_sitting_next_to_a_mirror.jpeg',
+            'a cat sitting next to a mirror',
+            'watercolor painting of a cat sitting next to a mirror',
+            100,
+            36,
+            15,
+            'Schnauzer dog', 'cat',
+            5.5,
+            1,
+            'examples/ddpm_sega_watercolor_painting_a_cat_sitting_next_to_a_mirror_plus_dog_minus_cat.png'
+        ],
+        [
+            'examples/source_a_man_wearing_a_brown_hoodie_in_a_crowded_street.jpeg',
+            'a man wearing a brown hoodie in a crowded street',
+            'a robot wearing a brown hoodie in a crowded street',
+            100,
+            36,
+            15,
+            'painting','',
+            10,
+            1,
+            'examples/ddpm_sega_painting_of_a_robot_wearing_a_brown_hoodie_in_a_crowded_street.png'
+        ],
+        [
+            'examples/source_wall_with_framed_photos.jpeg',
+            '',
+            '',
+            100,
+            36,
+            15,
+            'pink drawings of muffins','',
+            10,
+            1,
+            'examples/ddpm_sega_plus_pink_drawings_of_muffins.png'
+        ],
+        [
+            'examples/source_an_empty_room_with_concrete_walls.jpg',
+            'an empty room with concrete walls',
+            'glass walls',
+            100,
+            36,
+            17,
+            'giant elephant','',
+            10,
+            1,
+            'examples/ddpm_sega_glass_walls_gian_elephant.png'
+        ]]
+    return case
+
+
 
 ########
 # demo #
@@ -346,6 +355,7 @@ with gr.Blocks(css='style.css') as demo:
 
 
     with gr.Row():
+        caption_button = gr.Button("Caption Image")
        run_button = gr.Button("Run")
        reconstruct_button = gr.Button("Show Reconstruction", visible=False)
 
@@ -366,11 +376,14 @@
     with gr.Accordion("Help", open=False):
        gr.Markdown(help_text)
 
-
+    caption_button.click(
+        fn = caption_image,
+        inputs = [input_image],
+        outputs = [tar_prompt]
+    )
 
     add_concept_button.click(fn = add_concept, inputs=sega_concepts_counter,
                              outputs= [row2, row3, add_concept_button, sega_concepts_counter], queue = False)
-
 
     run_button.click(
        fn = randomize_seed_fn,