piyushgrover commited on
Commit
8583dc9
Β·
1 Parent(s): 012daaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -99
app.py CHANGED
@@ -1,58 +1,69 @@
1
  import gradio as gr
2
- #from utils import *
3
  import random
4
 
5
-
6
  is_clicked = False
7
  out_img_list = ['', '', '', '', '']
8
- out_state_list = [False, False, False, False, False]
 
 
9
 
10
  def fn_query_on_load():
11
  return "Cats at sunset"
12
 
13
  def fn_refresh():
14
-
15
  return out_img_list
16
 
17
-
18
  with gr.Blocks() as app:
19
  with gr.Row():
20
  gr.Markdown(
21
  """
22
  # Stable Diffusion Image Generation
23
- ### Enter query to generate images in various styles
24
  """)
25
 
26
  with gr.Row(visible=True):
27
- with gr.Column():
28
- with gr.Row():
29
- search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)
30
-
31
-
32
- with gr.Row(visible=True):
33
- #with gr.Column():
34
- out1 = gr.Image(value="out1.png", interactive=False, width=128, label='Oil Painting')
35
- #submit1 = gr.Button("Submit", variant='primary')
36
-
37
- #with gr.Column():
38
- out2 = gr.Image(value="out2.png", interactive=False, width=128, label='Low Poly HD Style')
39
- #submit2 = gr.Button("Submit", variant='primary')
40
-
41
- #with gr.Column():
42
- out3 = gr.Image(value="out3.png", interactive=False, width=128, label='Matrix style')
43
- #submit3 = gr.Button("Submit", variant='primary')
44
-
45
-
46
- #with gr.Column():
47
- out4 = gr.Image(value="out4.png", interactive=False, width=128, label='Dreamy Painting')
48
- #submit4 = gr.Button("Submit", variant='primary')
49
-
50
- #with gr.Column():
51
- out5 = gr.Image(value="out5.png", interactive=False, width=128, label='Depth Map Style')
52
- #submit5 = gr.Button("Submit", variant='primary')
53
-
54
- with gr.Row(visible=True):
55
- clear_btn = gr.ClearButton()
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def clear_data():
58
  return {
@@ -64,11 +75,21 @@ with gr.Blocks() as app:
64
  search_text: None
65
  }
66
 
 
 
 
 
 
 
 
 
 
67
 
68
  clear_btn.click(clear_data, None, [out1, out2, out3, out4, out5, search_text])
 
69
 
70
 
71
- '''def func_generate(query, concept_idx, seed):
72
  prompt = query + ' in the style of bulb'
73
  text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
74
  return_tensors="pt")
@@ -78,7 +99,7 @@ with gr.Blocks() as app:
78
  position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
79
  position_embeddings = pos_emb_layer(position_ids)
80
 
81
- s = seed
82
 
83
  token_embeddings = token_emb_layer(input_ids)
84
  # The new embedding - our special birb word
@@ -95,72 +116,141 @@ with gr.Blocks() as app:
95
 
96
  # And generate an image with this:
97
 
98
- s = random.randint(s + 1, s + 30)
 
 
 
 
 
99
  g = torch.manual_seed(s)
100
- return generate_with_embs(text_input, modified_output_embeddings, generator=g)
101
-
102
-
103
- def generate_oil_painting(query):
104
- return {
105
- out1: func_generate(query, 0, 0)
106
- }
107
-
108
- def generate_low_poly_hd(query):
109
- return {
110
- out2: func_generate(query, 1, 30)
111
- }
112
-
113
- def generate_matrix_style(query):
114
- return {
115
- out3: func_generate(query, 2, 60)
116
- }
117
-
118
- def generate_dreamy_painting(query):
119
- return {
120
- out4: func_generate(query, 3, 90)
121
- }
122
-
123
- def generate_depth_map_style(query):
124
- return {
125
- out5: func_generate(query, 4, 120)
126
- }
127
-
128
- submit1.click(
129
- generate_oil_painting,
130
- search_text,
131
- out1
132
- )
133
-
134
- submit2.click(
135
- generate_low_poly_hd,
136
- search_text,
137
- out2
138
- )
139
-
140
- submit3.click(
141
- generate_matrix_style,
142
- search_text,
143
- out3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  )
145
-
146
- submit4.click(
147
- generate_dreamy_painting,
148
- search_text,
149
- out4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  )
151
 
152
- submit5.click(
153
- generate_depth_map_style,
154
- search_text,
155
- out5
156
- )
157
- '''
158
-
159
  '''
160
  Launch the app
161
  '''
162
- app.launch()
163
-
164
-
165
-
166
-
 
1
  import gradio as gr
2
+ from utils import *
3
  import random
4
 
 
5
  is_clicked = False
6
  out_img_list = ['', '', '', '', '']
7
+ out_state_list = ['', '', '', '', '']
8
+ out_state_list2 = ['', '', '', '', '']
9
+ seed_values = [0, 0, 0, 0, 0]
10
 
11
  def fn_query_on_load():
12
  return "Cats at sunset"
13
 
14
  def fn_refresh():
 
15
  return out_img_list
16
 
 
17
  with gr.Blocks() as app:
18
  with gr.Row():
19
  gr.Markdown(
20
  """
21
  # Stable Diffusion Image Generation
22
+ ### Enter prompt to generate images in various styles
23
  """)
24
 
25
  with gr.Row(visible=True):
26
+ search_text = gr.Textbox(value=fn_query_on_load, placeholder='Enter image prompt..', label='Enter Image Prompt')
27
+
28
+ with gr.Tab('Generate Image in various styles'):
29
+
30
+ with gr.Row():
31
+ concept_index = gr.Dropdown(label='Select image style', value='Oil Painting', type="index", choices=['Oil Painting', 'Low Poly HD Style', 'Matrix Style', 'Dreamy Painting', 'Depth Map Style'] )
32
+
33
+ with gr.Row(visible=True):
34
+ out1 = gr.Image(value="out1.png", interactive=False, width=128, height=128, label='Oil Painting')
35
+ out2 = gr.Image(value="out2.png", interactive=False, width=128, height=128, label='Low Poly HD Style')
36
+ out3 = gr.Image(value="out3.png", interactive=False, width=128, height=128, label='Matrix Style')
37
+ out4 = gr.Image(value="out4.png", interactive=False, width=128, height=128, label='Dreamy Painting')
38
+ out5 = gr.Image(value="out5.png", interactive=False, width=128, height=128, label='Depth Map Style')
39
+
40
+ with gr.Row(visible=True):
41
+ submit_btn = gr.Button("Submit", variant='primary')
42
+ clear_btn = gr.ClearButton()
43
+
44
+ with gr.Tab("Additional Guidance with Contrast Adjustment"):
45
+
46
+ with gr.Row():
47
+ gr.Markdown(
48
+ """
49
+ ### Experiment with contrast based additional guidance to view how it affects the output
50
+ """)
51
+
52
+ with gr.Row():
53
+ concept_index2 = gr.Dropdown(label='Select image style', value='Oil Painting', type="index", choices=['Oil Painting', 'Low Poly HD Style', 'Matrix Style', 'Dreamy Painting', 'Depth Map Style'] )
54
+ contrast_perc = gr.Slider(value=90, minimum=-100, maximum=100, label='Contrast Adjustment')
55
+
56
+
57
+ with gr.Row(visible=True):
58
+ out11 = gr.Image(value="out11.png", interactive=False, width=128, height=128, label='Oil Painting')
59
+ out12 = gr.Image(value="out12.png", interactive=False, width=128, height=128, label='Low Poly HD Style')
60
+ out13 = gr.Image(value="out13.png", interactive=False, width=128, height=128, label='Matrix Style')
61
+ out14 = gr.Image(value="out14.png", interactive=False, width=128, height=128, label='Dreamy Painting')
62
+ out15 = gr.Image(value="out15.png", interactive=False, width=128, height=128, label='Depth Map Style')
63
+
64
+ with gr.Row(visible=True):
65
+ submit_btn2 = gr.Button("Submit", variant='primary')
66
+ clear_btn2 = gr.ClearButton()
67
 
68
  def clear_data():
69
  return {
 
75
  search_text: None
76
  }
77
 
78
+ def clear_data2():
79
+ return {
80
+ out11: None,
81
+ out12: None,
82
+ out13: None,
83
+ out14: None,
84
+ out15: None
85
+ }
86
+
87
 
88
  clear_btn.click(clear_data, None, [out1, out2, out3, out4, out5, search_text])
89
+ clear_btn2.click(clear_data2, None, [out11, out12, out13, out14, out15])
90
 
91
 
92
+ def func_generate(query, concept_idx, seed_start, contrast_loss=False, contrast_perc=None):
93
  prompt = query + ' in the style of bulb'
94
  text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
95
  return_tensors="pt")
 
99
  position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
100
  position_embeddings = pos_emb_layer(position_ids)
101
 
102
+ s = seed_start
103
 
104
  token_embeddings = token_emb_layer(input_ids)
105
  # The new embedding - our special birb word
 
116
 
117
  # And generate an image with this:
118
 
119
+ if contrast_loss and seed_values[concept_idx] > 0:
120
+ s = seed_values[concept_idx]
121
+ else:
122
+ s = random.randint(s + 1, s + 30)
123
+ seed_values[concept_idx] = s
124
+
125
  g = torch.manual_seed(s)
126
+ return generate_with_embs(text_input, modified_output_embeddings, generator=g, contrast_loss=contrast_loss, contrast_perc=contrast_perc)
127
+
128
+
129
+ def generate_image(query, con_idx, o1, o2, o3, o4, o5, contrast):
130
+ if not query:
131
+ raise gr.Error("No prompt provided")
132
+ return {
133
+ out1: o1,
134
+ out2: o2,
135
+ out3: o3,
136
+ out4: o4,
137
+ out5: o5
138
+ }
139
+ else:
140
+ out = func_generate(query, con_idx, con_idx*30)
141
+ out_state_list[con_idx] = query
142
+
143
+ if con_idx == 0:
144
+ return {
145
+ out1: out,
146
+ out2: None if out_state_list[1] != query else o2,
147
+ out3: None if out_state_list[2] != query else o3,
148
+ out4: None if out_state_list[3] != query else o4,
149
+ out5: None if out_state_list[4] != query else o5
150
+ }
151
+ elif con_idx == 1:
152
+ return {
153
+ out1: None if out_state_list[0] != query else o1,
154
+ out2: out,
155
+ out3: None if out_state_list[2] != query else o2,
156
+ out4: None if out_state_list[3] != query else o3,
157
+ out5: None if out_state_list[4] != query else o4
158
+ }
159
+ elif con_idx == 2:
160
+ return {
161
+ out1: None if out_state_list[0] != query else o1,
162
+ out2: None if out_state_list[1] != query else o2,
163
+ out3: out,
164
+ out4: None if out_state_list[3] != query else o3,
165
+ out5: None if out_state_list[4] != query else o4
166
+ }
167
+ elif con_idx == 3:
168
+ return {
169
+ out1: None if out_state_list[0] != query else o1,
170
+ out2: None if out_state_list[1] != query else o2,
171
+ out3: None if out_state_list[2] != query else o3,
172
+ out4: out,
173
+ out5: None if out_state_list[4] != query else o4
174
+ }
175
+ elif con_idx == 4:
176
+ return {
177
+ out1: None if out_state_list[0] != query else o1,
178
+ out2: None if out_state_list[1] != query else o2,
179
+ out3: None if out_state_list[2] != query else o3,
180
+ out4: None if out_state_list[3] != query else o4,
181
+ out5: out
182
+ }
183
+
184
+
185
+
186
+ submit_btn.click(
187
+ generate_image,
188
+ [search_text, concept_index, out1, out2, out3, out4, out5],
189
+ [out1, out2, out3, out4, out5]
190
  )
191
+
192
+ def generate_image_with_contrast_loss_guidance(query, con_idx, o1, o2, o3, o4, o5, contrast):
193
+ if not query:
194
+ raise gr.Error("No prompt provided")
195
+ return {
196
+ out11: o1,
197
+ out12: o2,
198
+ out13: o3,
199
+ out14: o4,
200
+ out15: o5
201
+ }
202
+ else:
203
+ out = func_generate(query, con_idx, con_idx*30, contrast_loss=True, contrast_perc=contrast)
204
+ out_state_list[con_idx] = query
205
+
206
+ if con_idx == 0:
207
+ return {
208
+ out11: out,
209
+ out12: None if out_state_list[1] != query else o2,
210
+ out13: None if out_state_list[2] != query else o3,
211
+ out14: None if out_state_list[3] != query else o4,
212
+ out15: None if out_state_list[4] != query else o5
213
+ }
214
+ elif con_idx == 1:
215
+ return {
216
+ out11: None if out_state_list[0] != query else o1,
217
+ out12: out,
218
+ out13: None if out_state_list[2] != query else o2,
219
+ out14: None if out_state_list[3] != query else o3,
220
+ out15: None if out_state_list[4] != query else o4
221
+ }
222
+ elif con_idx == 2:
223
+ return {
224
+ out11: None if out_state_list[0] != query else o1,
225
+ out12: None if out_state_list[1] != query else o2,
226
+ out13: out,
227
+ out14: None if out_state_list[3] != query else o3,
228
+ out15: None if out_state_list[4] != query else o4
229
+ }
230
+ elif con_idx == 3:
231
+ return {
232
+ out11: None if out_state_list[0] != query else o1,
233
+ out12: None if out_state_list[1] != query else o2,
234
+ out13: None if out_state_list[2] != query else o3,
235
+ out14: out,
236
+ out15: None if out_state_list[4] != query else o4
237
+ }
238
+ elif con_idx == 4:
239
+ return {
240
+ out11: None if out_state_list[0] != query else o1,
241
+ out12: None if out_state_list[1] != query else o2,
242
+ out13: None if out_state_list[2] != query else o3,
243
+ out14: None if out_state_list[3] != query else o4,
244
+ out15: out
245
+ }
246
+
247
+ submit_btn2.click(
248
+ generate_image_with_contrast_loss_guidance,
249
+ [search_text, concept_index2, out11, out12, out13, out14, out15, contrast_perc],
250
+ [out11, out12, out13, out14, out15]
251
  )
252
 
 
 
 
 
 
 
 
253
  '''
254
  Launch the app
255
  '''
256
+ app.launch()