Update app.py
Browse files
app.py
CHANGED
@@ -14,6 +14,9 @@ import os
|
|
14 |
import streamlit as st
|
15 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
16 |
|
|
|
|
|
|
|
17 |
|
18 |
|
19 |
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
@@ -83,74 +86,17 @@ def generate_image(image, product_name, target_name):
|
|
83 |
generator=generator,
|
84 |
).images
|
85 |
|
|
|
86 |
return im
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
def translate_sentence(article, source, target):
|
91 |
-
if target == 'eng_Latn':
|
92 |
-
return article
|
93 |
-
translator = pipeline('translation', model=trans_model, tokenizer=trans_tokenizer, src_lang=source, tgt_lang=target)
|
94 |
-
output = translator(article, max_length=400)
|
95 |
-
output = output[0]['translation_text']
|
96 |
-
return output
|
97 |
-
|
98 |
-
|
99 |
-
codes_as_string = '''Modern Standard Arabic arb_Arab
|
100 |
-
Danish dan_Latn
|
101 |
-
German deu_Latn
|
102 |
-
Greek ell_Grek
|
103 |
-
English eng_Latn
|
104 |
-
Estonian est_Latn
|
105 |
-
Finnish fin_Latn
|
106 |
-
French fra_Latn
|
107 |
-
Hebrew heb_Hebr
|
108 |
-
Hindi hin_Deva
|
109 |
-
Croatian hrv_Latn
|
110 |
-
Hungarian hun_Latn
|
111 |
-
Indonesian ind_Latn
|
112 |
-
Icelandic isl_Latn
|
113 |
-
Italian ita_Latn
|
114 |
-
Japanese jpn_Jpan
|
115 |
-
Korean kor_Hang
|
116 |
-
Luxembourgish ltz_Latn
|
117 |
-
Macedonian mkd_Cyrl
|
118 |
-
Maltese mlt_Latn
|
119 |
-
Dutch nld_Latn
|
120 |
-
Norwegian Bokmål nob_Latn
|
121 |
-
Polish pol_Latn
|
122 |
-
Portuguese por_Latn
|
123 |
-
Russian rus_Cyrl
|
124 |
-
Slovak slk_Latn
|
125 |
-
Slovenian slv_Latn
|
126 |
-
Spanish spa_Latn
|
127 |
-
Serbian srp_Cyrl
|
128 |
-
Swedish swe_Latn
|
129 |
-
Thai tha_Thai
|
130 |
-
Turkish tur_Latn
|
131 |
-
Ukrainian ukr_Cyrl
|
132 |
-
Vietnamese vie_Latn
|
133 |
-
Chinese (Simplified) zho_Hans'''
|
134 |
-
|
135 |
-
codes_as_string = codes_as_string.split('\n')
|
136 |
-
|
137 |
-
flores_codes = {}
|
138 |
-
for code in codes_as_string:
|
139 |
-
lang, lang_code = code.split('\t')
|
140 |
-
flores_codes[lang] = lang_code
|
141 |
-
|
142 |
-
|
143 |
|
144 |
-
import gradio as gr
|
145 |
-
import gc
|
146 |
gc.collect()
|
|
|
147 |
|
148 |
image_label = 'Please upload the image (optional)'
|
149 |
-
extract_label = 'Specify what
|
150 |
prompt_label = 'Specify the description of image to be generated'
|
151 |
button_label = "Proceed"
|
152 |
-
output_label = "
|
153 |
-
|
154 |
|
155 |
shot_services = ['close-up', 'extreme-closeup', 'POV','medium', 'long']
|
156 |
shot_label = 'Choose the shot type'
|
@@ -171,56 +117,40 @@ device_services = ['iphone', 'CCTV', 'Nikon ZFX','Canon', 'Gopro']
|
|
171 |
device_label = 'Choose the device type'
|
172 |
|
173 |
|
174 |
-
def change_lang(choice):
|
175 |
-
global lang_choice
|
176 |
-
lang_choice = choice
|
177 |
-
new_image_label = translate_sentence(image_label, "english", choice)
|
178 |
-
return [gr.update(visible=True, label=translate_sentence(image_label, flores_codes["English"],flores_codes[choice])),
|
179 |
-
gr.update(visible=True, label=translate_sentence(extract_label, flores_codes["English"],flores_codes[choice])),
|
180 |
-
gr.update(visible=True, label=translate_sentence(prompt_label, flores_codes["English"],flores_codes[choice])),
|
181 |
-
gr.update(visible=True, value=translate_sentence(button_label, flores_codes["English"],flores_codes[choice])),
|
182 |
-
gr.update(visible=True, label=translate_sentence(button_label, flores_codes["English"],flores_codes[choice])),
|
183 |
-
]
|
184 |
-
|
185 |
def add_to_prompt(prompt_text,shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio ):
|
186 |
if shot_radio != '':
|
187 |
-
prompt_text += ","+shot_radio
|
188 |
if style_radio != '':
|
189 |
-
prompt_text += ","+style_radio
|
190 |
if lighting_radio != '':
|
191 |
-
prompt_text += ","+lighting_radio
|
192 |
if context_radio != '':
|
193 |
-
prompt_text += ","+ context_radio
|
194 |
if lens_radio != '':
|
195 |
-
prompt_text += ","+ lens_radio
|
196 |
if device_radio != '':
|
197 |
-
prompt_text += ","+ device_radio
|
198 |
return prompt_text
|
199 |
|
200 |
def proceed_with_generation(input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
|
201 |
if extract_text == "" or input_file == "":
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
output = SDpipe(translated_prompt, height=512, width=512, num_images_per_prompt=4)
|
206 |
return output.images
|
207 |
elif extract_text != "" and input_file != "" and prompt_text !='':
|
208 |
-
|
209 |
-
|
210 |
-
print(
|
211 |
-
|
212 |
-
print(translated_extract)
|
213 |
-
output = generate_image(Image.fromarray(input_file), translated_extract, translated_prompt)
|
214 |
return output
|
215 |
else:
|
216 |
-
raise gr.Error("Please fill all details for guided image or atleast
|
217 |
|
218 |
|
219 |
|
220 |
with gr.Blocks() as demo:
|
221 |
|
222 |
-
lang_option = gr.Dropdown(list(flores_codes.keys()), default='English', label='Please Select your Language')
|
223 |
-
|
224 |
with gr.Row():
|
225 |
input_file = gr.Image(interactive = True, label=image_label, visible=False, shape=(512,512))
|
226 |
extract_text = gr.Textbox(label= extract_label, lines=1, interactive = True, visible = True)
|
@@ -234,15 +164,12 @@ with gr.Blocks() as demo:
|
|
234 |
lens_radio = gr.Radio(lens_services , label=lens_label)
|
235 |
device_radio = gr.Radio(device_services , label=device_label)
|
236 |
|
237 |
-
button = gr.Button(value = button_label , visible =
|
238 |
|
239 |
with gr.Row():
|
240 |
-
output_gallery = gr.Gallery(label = output_label, visible=
|
241 |
|
242 |
-
|
243 |
-
|
244 |
|
245 |
-
lang_option.change(fn=change_lang, inputs=lang_option, outputs=[input_file, extract_text, prompt_text, button, output_gallery])
|
246 |
button.click( proceed_with_generation, [input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio], [output_gallery])
|
247 |
|
248 |
|
|
|
14 |
import streamlit as st
|
15 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
16 |
|
17 |
+
import gradio as gr
|
18 |
+
import gc
|
19 |
+
|
20 |
|
21 |
|
22 |
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
|
|
86 |
generator=generator,
|
87 |
).images
|
88 |
|
89 |
+
im.enable_attention_slicing()
|
90 |
return im
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
|
|
|
|
92 |
gc.collect()
|
93 |
+
torch.cuda.empty_cache()
|
94 |
|
95 |
image_label = 'Please upload the image (optional)'
|
96 |
+
extract_label = 'Specify what needs to be extracted from the above image'
|
97 |
prompt_label = 'Specify the description of image to be generated'
|
98 |
button_label = "Proceed"
|
99 |
+
output_label = "Results"
|
|
|
100 |
|
101 |
shot_services = ['close-up', 'extreme-closeup', 'POV','medium', 'long']
|
102 |
shot_label = 'Choose the shot type'
|
|
|
117 |
device_label = 'Choose the device type'
|
118 |
|
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
def add_to_prompt(prompt_text,shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio ):
|
121 |
if shot_radio != '':
|
122 |
+
prompt_text += ", "+shot_radio
|
123 |
if style_radio != '':
|
124 |
+
prompt_text += ", "+style_radio
|
125 |
if lighting_radio != '':
|
126 |
+
prompt_text += ", "+lighting_radio
|
127 |
if context_radio != '':
|
128 |
+
prompt_text += ", "+ context_radio
|
129 |
if lens_radio != '':
|
130 |
+
prompt_text += ", "+ lens_radio
|
131 |
if device_radio != '':
|
132 |
+
prompt_text += ", "+ device_radio
|
133 |
return prompt_text
|
134 |
|
135 |
def proceed_with_generation(input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
|
136 |
if extract_text == "" or input_file == "":
|
137 |
+
prompt = add_to_prompt(prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
|
138 |
+
print(prompt)
|
139 |
+
output = SDpipe(prompt, height=512, width=512, num_images_per_prompt=4)
|
|
|
140 |
return output.images
|
141 |
elif extract_text != "" and input_file != "" and prompt_text !='':
|
142 |
+
prompt = add_to_prompt(prompt_text,shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
|
143 |
+
print(prompt)
|
144 |
+
print(extract_text)
|
145 |
+
output = generate_image(Image.fromarray(input_file), extract_text, prompt)
|
|
|
|
|
146 |
return output
|
147 |
else:
|
148 |
+
raise gr.Error("Please fill all details for guided image or atleast prompt for free image rendition !")
|
149 |
|
150 |
|
151 |
|
152 |
with gr.Blocks() as demo:
|
153 |
|
|
|
|
|
154 |
with gr.Row():
|
155 |
input_file = gr.Image(interactive = True, label=image_label, visible=False, shape=(512,512))
|
156 |
extract_text = gr.Textbox(label= extract_label, lines=1, interactive = True, visible = True)
|
|
|
164 |
lens_radio = gr.Radio(lens_services , label=lens_label)
|
165 |
device_radio = gr.Radio(device_services , label=device_label)
|
166 |
|
167 |
+
button = gr.Button(value = button_label , visible = True)
|
168 |
|
169 |
with gr.Row():
|
170 |
+
output_gallery = gr.Gallery(label = output_label, visible= True)
|
171 |
|
|
|
|
|
172 |
|
|
|
173 |
button.click( proceed_with_generation, [input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio], [output_gallery])
|
174 |
|
175 |
|