tow252 committed on
Commit
a8f966d
1 Parent(s): 003ec88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -96
app.py CHANGED
@@ -14,6 +14,9 @@ import os
14
  import streamlit as st
15
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
16
 
 
 
 
17
 
18
 
19
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
@@ -83,74 +86,17 @@ def generate_image(image, product_name, target_name):
83
  generator=generator,
84
  ).images
85
 
 
86
  return im
87
-
88
-
89
-
90
- def translate_sentence(article, source, target):
91
- if target == 'eng_Latn':
92
- return article
93
- translator = pipeline('translation', model=trans_model, tokenizer=trans_tokenizer, src_lang=source, tgt_lang=target)
94
- output = translator(article, max_length=400)
95
- output = output[0]['translation_text']
96
- return output
97
-
98
-
99
- codes_as_string = '''Modern Standard Arabic arb_Arab
100
- Danish dan_Latn
101
- German deu_Latn
102
- Greek ell_Grek
103
- English eng_Latn
104
- Estonian est_Latn
105
- Finnish fin_Latn
106
- French fra_Latn
107
- Hebrew heb_Hebr
108
- Hindi hin_Deva
109
- Croatian hrv_Latn
110
- Hungarian hun_Latn
111
- Indonesian ind_Latn
112
- Icelandic isl_Latn
113
- Italian ita_Latn
114
- Japanese jpn_Jpan
115
- Korean kor_Hang
116
- Luxembourgish ltz_Latn
117
- Macedonian mkd_Cyrl
118
- Maltese mlt_Latn
119
- Dutch nld_Latn
120
- Norwegian Bokmål nob_Latn
121
- Polish pol_Latn
122
- Portuguese por_Latn
123
- Russian rus_Cyrl
124
- Slovak slk_Latn
125
- Slovenian slv_Latn
126
- Spanish spa_Latn
127
- Serbian srp_Cyrl
128
- Swedish swe_Latn
129
- Thai tha_Thai
130
- Turkish tur_Latn
131
- Ukrainian ukr_Cyrl
132
- Vietnamese vie_Latn
133
- Chinese (Simplified) zho_Hans'''
134
-
135
- codes_as_string = codes_as_string.split('\n')
136
-
137
- flores_codes = {}
138
- for code in codes_as_string:
139
- lang, lang_code = code.split('\t')
140
- flores_codes[lang] = lang_code
141
-
142
-
143
 
144
- import gradio as gr
145
- import gc
146
  gc.collect()
 
147
 
148
  image_label = 'Please upload the image (optional)'
149
- extract_label = 'Specify what need to be extracted from the above image'
150
  prompt_label = 'Specify the description of image to be generated'
151
  button_label = "Proceed"
152
- output_label = "Generations"
153
-
154
 
155
  shot_services = ['close-up', 'extreme-closeup', 'POV','medium', 'long']
156
  shot_label = 'Choose the shot type'
@@ -171,56 +117,40 @@ device_services = ['iphone', 'CCTV', 'Nikon ZFX','Canon', 'Gopro']
171
  device_label = 'Choose the device type'
172
 
173
 
174
- def change_lang(choice):
175
- global lang_choice
176
- lang_choice = choice
177
- new_image_label = translate_sentence(image_label, "english", choice)
178
- return [gr.update(visible=True, label=translate_sentence(image_label, flores_codes["English"],flores_codes[choice])),
179
- gr.update(visible=True, label=translate_sentence(extract_label, flores_codes["English"],flores_codes[choice])),
180
- gr.update(visible=True, label=translate_sentence(prompt_label, flores_codes["English"],flores_codes[choice])),
181
- gr.update(visible=True, value=translate_sentence(button_label, flores_codes["English"],flores_codes[choice])),
182
- gr.update(visible=True, label=translate_sentence(button_label, flores_codes["English"],flores_codes[choice])),
183
- ]
184
-
185
  def add_to_prompt(prompt_text,shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio ):
186
  if shot_radio != '':
187
- prompt_text += ","+shot_radio
188
  if style_radio != '':
189
- prompt_text += ","+style_radio
190
  if lighting_radio != '':
191
- prompt_text += ","+lighting_radio
192
  if context_radio != '':
193
- prompt_text += ","+ context_radio
194
  if lens_radio != '':
195
- prompt_text += ","+ lens_radio
196
  if device_radio != '':
197
- prompt_text += ","+ device_radio
198
  return prompt_text
199
 
200
  def proceed_with_generation(input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
201
  if extract_text == "" or input_file == "":
202
- translated_prompt = translate_sentence(prompt_text, flores_codes[lang_choice], flores_codes["English"])
203
- translated_prompt = add_to_prompt(translated_prompt,shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
204
- print(translated_prompt)
205
- output = SDpipe(translated_prompt, height=512, width=512, num_images_per_prompt=4)
206
  return output.images
207
  elif extract_text != "" and input_file != "" and prompt_text !='':
208
- translated_prompt = translate_sentence(prompt_text, flores_codes[lang_choice], flores_codes["English"])
209
- translated_prompt = add_to_prompt(translated_prompt,shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
210
- print(translated_prompt)
211
- translated_extract = translate_sentence(extract_text, flores_codes[lang_choice], flores_codes["English"])
212
- print(translated_extract)
213
- output = generate_image(Image.fromarray(input_file), translated_extract, translated_prompt)
214
  return output
215
  else:
216
- raise gr.Error("Please fill all details for guided image or atleast promt for free image rendition !")
217
 
218
 
219
 
220
  with gr.Blocks() as demo:
221
 
222
- lang_option = gr.Dropdown(list(flores_codes.keys()), default='English', label='Please Select your Language')
223
-
224
  with gr.Row():
225
  input_file = gr.Image(interactive = True, label=image_label, visible=False, shape=(512,512))
226
  extract_text = gr.Textbox(label= extract_label, lines=1, interactive = True, visible = True)
@@ -234,15 +164,12 @@ with gr.Blocks() as demo:
234
  lens_radio = gr.Radio(lens_services , label=lens_label)
235
  device_radio = gr.Radio(device_services , label=device_label)
236
 
237
- button = gr.Button(value = button_label , visible = False)
238
 
239
  with gr.Row():
240
- output_gallery = gr.Gallery(label = output_label, visible= False)
241
 
242
-
243
-
244
 
245
- lang_option.change(fn=change_lang, inputs=lang_option, outputs=[input_file, extract_text, prompt_text, button, output_gallery])
246
  button.click( proceed_with_generation, [input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio], [output_gallery])
247
 
248
 
 
14
  import streamlit as st
15
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
16
 
17
+ import gradio as gr
18
+ import gc
19
+
20
 
21
 
22
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
 
86
  generator=generator,
87
  ).images
88
 
89
+ im.enable_attention_slicing()
90
  return im
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
 
 
92
  gc.collect()
93
+ torch.cuda.empty_cache()
94
 
95
  image_label = 'Please upload the image (optional)'
96
+ extract_label = 'Specify what needs to be extracted from the above image'
97
  prompt_label = 'Specify the description of image to be generated'
98
  button_label = "Proceed"
99
+ output_label = "Results"
 
100
 
101
  shot_services = ['close-up', 'extreme-closeup', 'POV','medium', 'long']
102
  shot_label = 'Choose the shot type'
 
117
  device_label = 'Choose the device type'
118
 
119
 
 
 
 
 
 
 
 
 
 
 
 
120
  def add_to_prompt(prompt_text,shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio ):
121
  if shot_radio != '':
122
+ prompt_text += ", "+shot_radio
123
  if style_radio != '':
124
+ prompt_text += ", "+style_radio
125
  if lighting_radio != '':
126
+ prompt_text += ", "+lighting_radio
127
  if context_radio != '':
128
+ prompt_text += ", "+ context_radio
129
  if lens_radio != '':
130
+ prompt_text += ", "+ lens_radio
131
  if device_radio != '':
132
+ prompt_text += ", "+ device_radio
133
  return prompt_text
134
 
135
def proceed_with_generation(input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
    """Route a generation request to free-form or image-guided rendering.

    Free-form path: when no image or no extract text is supplied, render
    `prompt_text` (plus radio modifiers) with the Stable Diffusion pipeline.
    Guided path: when an image, extract text, and prompt are all present,
    delegate to generate_image() on the uploaded image.

    Args:
        input_file: Uploaded image as a numpy array, or None when no file was
            uploaded (gr.Image yields None, never "" — TODO confirm against
            the Gradio version in use).
        extract_text: Description of what to extract from the image.
        prompt_text: Target image description.
        shot_radio..device_radio: Optional style modifiers, forwarded to
            add_to_prompt().

    Returns:
        A list of generated PIL images.

    Raises:
        gr.Error: When the inputs match neither path.
    """
    # Original compared input_file to "" — an ndarray/None is never "",
    # so a missing upload fell through to Image.fromarray(None) and crashed.
    has_image = input_file is not None and not (isinstance(input_file, str) and input_file == "")
    prompt = add_to_prompt(prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
    if extract_text == "" or not has_image:
        print(prompt)
        output = SDpipe(prompt, height=512, width=512, num_images_per_prompt=4)
        return output.images
    elif has_image and extract_text != "" and prompt_text != '':
        print(prompt)
        print(extract_text)
        return generate_image(Image.fromarray(input_file), extract_text, prompt)
    else:
        raise gr.Error("Please fill all details for guided image or at least prompt for free image rendition !")
149
 
150
 
151
 
152
  with gr.Blocks() as demo:
153
 
 
 
154
  with gr.Row():
155
  input_file = gr.Image(interactive = True, label=image_label, visible=False, shape=(512,512))
156
  extract_text = gr.Textbox(label= extract_label, lines=1, interactive = True, visible = True)
 
164
  lens_radio = gr.Radio(lens_services , label=lens_label)
165
  device_radio = gr.Radio(device_services , label=device_label)
166
 
167
+ button = gr.Button(value = button_label , visible = True)
168
 
169
  with gr.Row():
170
+ output_gallery = gr.Gallery(label = output_label, visible= True)
171
 
 
 
172
 
 
173
  button.click( proceed_with_generation, [input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio], [output_gallery])
174
 
175