tokeron commited on
Commit
d415ad5
·
verified ·
1 Parent(s): 2a99ad2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -8
app.py CHANGED
@@ -1,20 +1,115 @@
1
  import gradio as gr
2
  from diffusion_lens import get_images
3
 
 
4
 
5
- def generate_images(prompt):
6
- print('calling diffusion lens')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  gr.Info('Generating images from intermediate layers..')
8
  all_images = [] # Initialize a list to store all images
9
- for skip_layers in range(11, -1, -1):
10
- images = get_images(prompt, skip_layers=skip_layers)
11
- all_images.append((images[0], f'layer_{12 - skip_layers}'))
 
 
12
  yield all_images
13
 
14
  with gr.Blocks() as demo:
15
- text_input = gr.Textbox(label="Enter prompt")
 
 
 
 
 
 
16
  gallery = gr.Gallery(label="Generated Images", columns=6, rows=2, object_fit="contain", height="auto")
17
- text_input.submit(fn=generate_images, inputs=text_input, outputs=gallery)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- demo.launch()
20
 
 
 
 
1
  import gradio as gr
2
  from diffusion_lens import get_images
3
 
4
+ MAX_SEED = np.iinfo(np.int32).max
5
 
6
+ # Description
7
+ title = r"""
8
+ <h1 align="center">Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines</h1>
9
+ """
10
+
11
+ description = r"""
12
+ <b>Based on the paper <a href='https://arxiv.org/abs/2403.05846' target='_blank'>InstantStyle: Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines</a>.<br>
13
+ """
14
+
15
+ article = r"""
16
+ ---
17
+ 📝 **Citation**
18
+ <br>
19
+ If our work is helpful for your research or applications, please cite us via:
20
+ ```bibtex
21
+ @article{toker2024diffusion,
22
+ title={Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines},
23
+ author={Toker, Michael and Orgad, Hadas and Ventura, Mor and Arad, Dana and Belinkov, Yonatan},
24
+ journal={arXiv preprint arXiv:2403.05846},
25
+ year={2024}
26
+ }
27
+ }
28
+ ```
29
+ 📧 **Contact**
30
+ <br>
31
+ If you have any questions, please feel free to open an issue or directly reach us out at <b>tok@cs.technuin.ac.il</b>.
32
+ """
33
+
34
+
35
+ model_num_of_layers = {
36
+ 'Stable Diffusion 1.4': 12,
37
+ 'Stable Diffusion 2.1': 22,
38
+ }
39
+
40
+ def generate_images(prompt, model, seed):
41
+ print('calling diffusion lens with model:', model, 'and seed:', seed)
42
  gr.Info('Generating images from intermediate layers..')
43
  all_images = [] # Initialize a list to store all images
44
+ max_num_of_layers = model_num_of_layers[model]
45
+ for skip_layers in range(max_num_of_layers, -1, -1):
46
+ # Pass the model and seed to the get_images function
47
+ images = get_images(prompt, skip_layers=skip_layers, model=model, seed=seed)
48
+ all_images.append((images[0], f'layer_{12 - skip_layers}'))
49
  yield all_images
50
 
51
  with gr.Blocks() as demo:
52
+
53
+ gr.Markdown(title)
54
+ gr.Markdown(description)
55
+
56
+ # text_input = gr.Textbox(label="Enter prompt")
57
+ model_select = gr.Dropdown(label="Select Model", choices=['sd1', 'sd2'])
58
+ seed_input = gr.Number(label="Enter Seed", value=0) # Default seed set to 0
59
  gallery = gr.Gallery(label="Generated Images", columns=6, rows=2, object_fit="contain", height="auto")
60
+ # Update the submit function to include the new inputs
61
+
62
+
63
+ # text_input.submit(fn=generate_images, inputs=[text_input, model_select, seed_input], outputs=gallery)
64
+
65
+ with gr.Column():
66
+ prompt = gr.Textbox(
67
+ label="Prompt",
68
+ value="a cat, masterpiece, best quality, high quality",
69
+ )
70
+
71
+ model = gr.Radio(
72
+ [
73
+ "Stable Diffusion 1.4",
74
+ "Stable Diffusion 2.1",
75
+ ],
76
+ value="Stable Diffusion 1.4",
77
+ label="Model",
78
+ )
79
+
80
+ seed = gr.Slider(
81
+ minimum=-1,
82
+ maximum=MAX_SEED,
83
+ value=-1,
84
+ step=1,
85
+ label="Seed Value",
86
+ )
87
+
88
+ inputs = [
89
+ prompt,
90
+ model,
91
+ seed,
92
+ ]
93
+ outputs = [gallery]
94
+
95
+ gr.on(
96
+ triggers=[
97
+ prompt.input,
98
+ generate_button.click,
99
+ guidance_scale.input,
100
+ scale.input,
101
+ control_scale.input,
102
+ seed.input,
103
+ ],
104
+ fn=generate_images,
105
+ inputs=inputs,
106
+ outputs=outputs,
107
+ show_progress="full",
108
+ show_api=False,
109
+ trigger_mode="always_last",
110
+ )
111
 
112
+ gr.Markdown(article)
113
 
114
+ block.queue(api_open=False)
115
+ block.launch(show_api=False)