tow252 and santoshtyss committed
Commit
bb167dc
0 Parent(s):

Duplicate from santoshtyss/QuickAd

Co-authored-by: Santosh T.Y.S.S <santoshtyss@users.noreply.huggingface.co>

Files changed (4)
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +252 -0
  4. requirements.txt +9 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: QuickAd
+ emoji: 👀
+ colorFrom: red
+ colorTo: green
+ sdk: gradio
+ sdk_version: 3.10.1
+ app_file: app.py
+ pinned: false
+ license: bigscience-openrail-m
+ duplicated_from: santoshtyss/QuickAd
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,252 @@
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
+ from PIL import Image
+
+ import cv2
+ import torch
+ import matplotlib.pyplot as plt
+
+ import gc
+ import gradio as gr
+ import streamlit as st
+
+
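+ # Models: CLIPSeg segments the region named by the user, the two Stable Diffusion
+ # pipelines inpaint or generate from scratch, and NLLB-200 handles translation.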
+ processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+ model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ IPmodel_path = "runwayml/stable-diffusion-inpainting"
+
+ IPpipe = StableDiffusionInpaintPipeline.from_pretrained(
+     IPmodel_path,
+     revision="fp16",
+     torch_dtype=torch.float16,
+     use_auth_token=st.secrets["AUTH_TOKEN"]
+ ).to(device)
+
+ trans_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
+ trans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
+
+ SDpipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, use_auth_token=st.secrets["AUTH_TOKEN"]).to(device)
+
+
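+ # Turn a CLIPSeg heatmap into a binary inpainting mask: threshold the sigmoid
+ # output, then invert it so everything except the matched product gets repainted.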
+ def create_mask(image, prompt):
+     inputs = processor(text=[prompt], images=[image], padding="max_length", return_tensors="pt")
+     # predict the segmentation heatmap for the prompt
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     preds = outputs.logits
+
+     filename = "mask.png"
+     plt.imsave(filename, torch.sigmoid(preds))
+
+     gray_image = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY)
+
+     (thresh, bw_image) = cv2.threshold(gray_image, 100, 255, cv2.THRESH_BINARY)
+
+     # For debugging only:
+     # cv2.imwrite(filename, bw_image)
+
+     # invert: the matched product region becomes black (kept), the rest white (repainted)
+     mask = cv2.bitwise_not(bw_image)
+     cv2.imwrite(filename, mask)
+
+     return Image.open(filename)
+
+
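+ # Keep the masked product and repaint its surroundings from the target description.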
+ def generate_image(image, product_name, target_name):
+     mask = create_mask(image, product_name)
+     image = image.resize((512, 512))
+     mask = mask.resize((512, 512))
+     guidance_scale = 8
+     # guidance_scale = 16
+     num_samples = 4
+
+     prompt = target_name
+     generator = torch.Generator(device=device).manual_seed(22)  # change the seed to get different results
+
+     im = IPpipe(
+         prompt=prompt,
+         image=image,
+         mask_image=mask,
+         guidance_scale=guidance_scale,
+         generator=generator,
+         num_images_per_prompt=num_samples,
+     ).images
+
+     return im
+
+
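+ # Translate `article` between FLORES-200 language codes with the NLLB pipeline.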
+ def translate_sentence(article, source, target):
+     # no translation needed when source and target languages match
+     if source == target:
+         return article
+     translator = pipeline('translation', model=trans_model, tokenizer=trans_tokenizer, src_lang=source, tgt_lang=target)
+     output = translator(article, max_length=400)
+     output = output[0]['translation_text']
+     return output
+
+
+ codes_as_string = '''Modern Standard Arabic arb_Arab
+ Danish dan_Latn
+ German deu_Latn
+ Greek ell_Grek
+ English eng_Latn
+ Estonian est_Latn
+ Finnish fin_Latn
+ French fra_Latn
+ Hebrew heb_Hebr
+ Hindi hin_Deva
+ Croatian hrv_Latn
+ Hungarian hun_Latn
+ Indonesian ind_Latn
+ Icelandic isl_Latn
+ Italian ita_Latn
+ Japanese jpn_Jpan
+ Korean kor_Hang
+ Luxembourgish ltz_Latn
+ Macedonian mkd_Cyrl
+ Maltese mlt_Latn
+ Dutch nld_Latn
+ Norwegian Bokmål nob_Latn
+ Polish pol_Latn
+ Portuguese por_Latn
+ Russian rus_Cyrl
+ Slovak slk_Latn
+ Slovenian slv_Latn
+ Spanish spa_Latn
+ Serbian srp_Cyrl
+ Swedish swe_Latn
+ Thai tha_Thai
+ Turkish tur_Latn
+ Ukrainian ukr_Cyrl
+ Vietnamese vie_Latn
+ Chinese (Simplified) zho_Hans'''
+
+ codes_as_string = codes_as_string.split('\n')
+
+ # Map display names to FLORES-200 codes; rsplit keeps multi-word names intact.
+ flores_codes = {}
+ for code in codes_as_string:
+     lang, lang_code = code.rsplit(maxsplit=1)
+     flores_codes[lang] = lang_code
+
+
+
+ gc.collect()
+
+ lang_choice = "English"  # default UI language until the dropdown is changed
+
+ image_label = 'Please upload the image (optional)'
+ extract_label = 'Specify what needs to be extracted from the above image'
+ prompt_label = 'Specify the description of the image to be generated'
+ button_label = "Proceed"
+ output_label = "Generations"
+
+
+ shot_services = ['close-up', 'extreme-closeup', 'POV', 'medium', 'long']
+ shot_label = 'Choose the shot type'
+
+ style_services = ['polaroid', 'monochrome', 'long exposure', 'color splash', 'Tilt shift']
+ style_label = 'Choose the style type'
+
+ lighting_services = ['soft', 'ambivalent', 'ring', 'sun', 'cinematic']
+ lighting_label = 'Choose the lighting type'
+
+ context_services = ['indoor', 'outdoor', 'at night', 'in the park', 'on the beach', 'studio']
+ context_label = 'Choose the context'
+
+ lens_services = ['wide angle', 'telephoto', '24 mm', 'EF 70mm', 'Bokeh']
+ lens_label = 'Choose the lens type'
+
+ device_services = ['iPhone', 'CCTV', 'Nikon ZFX', 'Canon', 'GoPro']
+ device_label = 'Choose the device type'
+
+
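+ # Re-label the widgets in the newly selected language and make them visible.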
+ def change_lang(choice):
+     global lang_choice
+     lang_choice = choice
+     return [gr.update(visible=True, label=translate_sentence(image_label, flores_codes["English"], flores_codes[choice])),
+             gr.update(visible=True, label=translate_sentence(extract_label, flores_codes["English"], flores_codes[choice])),
+             gr.update(visible=True, label=translate_sentence(prompt_label, flores_codes["English"], flores_codes[choice])),
+             gr.update(visible=True, value=translate_sentence(button_label, flores_codes["English"], flores_codes[choice])),
+             gr.update(visible=True, label=translate_sentence(output_label, flores_codes["English"], flores_codes[choice])),
+             ]
+
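+ # Append any selected advanced options (shot, style, lighting, ...) to the prompt.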
+ def add_to_prompt(prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
+     # unselected radios come back as None, so test for truthiness rather than ''
+     if shot_radio:
+         prompt_text += "," + shot_radio
+     if style_radio:
+         prompt_text += "," + style_radio
+     if lighting_radio:
+         prompt_text += "," + lighting_radio
+     if context_radio:
+         prompt_text += "," + context_radio
+     if lens_radio:
+         prompt_text += "," + lens_radio
+     if device_radio:
+         prompt_text += "," + device_radio
+     return prompt_text
+
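+ # Two modes: free rendition from the prompt alone, or guided rendition that
+ # extracts the named product from the uploaded image and repaints around it.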
+ def proceed_with_generation(input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
+     if (extract_text == "" or input_file is None) and prompt_text != '':
+         translated_prompt = translate_sentence(prompt_text, flores_codes[lang_choice], flores_codes["English"])
+         translated_prompt = add_to_prompt(translated_prompt, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
+         print(translated_prompt)
+         output = SDpipe(translated_prompt, height=512, width=512, num_images_per_prompt=4)
+         return output.images
+     elif extract_text != "" and input_file is not None and prompt_text != '':
+         translated_prompt = translate_sentence(prompt_text, flores_codes[lang_choice], flores_codes["English"])
+         translated_prompt = add_to_prompt(translated_prompt, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
+         print(translated_prompt)
+         translated_extract = translate_sentence(extract_text, flores_codes[lang_choice], flores_codes["English"])
+         print(translated_extract)
+         output = generate_image(Image.fromarray(input_file), translated_extract, translated_prompt)
+         return output
+     else:
+         raise gr.Error("Please fill in all details for a guided image, or at least a prompt for a free image rendition!")
+
+
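+ # Gradio UI: language selector, image/prompt inputs, advanced options, gallery.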
+ with gr.Blocks() as demo:
+
+     lang_option = gr.Dropdown(list(flores_codes.keys()), value='English', label='Please Select your Language')
+
+     with gr.Row():
+         input_file = gr.Image(interactive=True, label=image_label, visible=False, shape=(512, 512))
+         extract_text = gr.Textbox(label=extract_label, lines=1, interactive=True, visible=True)
+         prompt_text = gr.Textbox(label=prompt_label, lines=1, interactive=True, visible=True)
+
+     with gr.Accordion("Advanced Options", open=False):
+         shot_radio = gr.Radio(shot_services, label=shot_label)
+         style_radio = gr.Radio(style_services, label=style_label)
+         lighting_radio = gr.Radio(lighting_services, label=lighting_label)
+         context_radio = gr.Radio(context_services, label=context_label)
+         lens_radio = gr.Radio(lens_services, label=lens_label)
+         device_radio = gr.Radio(device_services, label=device_label)
+
+     button = gr.Button(value=button_label, visible=False)
+
+     with gr.Row():
+         output_gallery = gr.Gallery(label=output_label, visible=False)
+
+     lang_option.change(fn=change_lang, inputs=lang_option, outputs=[input_file, extract_text, prompt_text, button, output_gallery])
+     button.click(proceed_with_generation, [input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio], [output_gallery])
+
+
+ demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ git+https://github.com/huggingface/transformers
+ git+https://github.com/huggingface/diffusers
+ accelerate
+ transformers
+ sentencepiece
+ Pillow
+ gradio
+ torch
+ opencv-python
+ streamlit
+ matplotlib