hav4ik committed
Commit 2338872
1 Parent(s): 4d8d3ff
Files changed (3):
  1. app.py +577 -117
  2. requirements.txt +7 -1
  3. style.css +16 -0
app.py CHANGED
@@ -1,142 +1,602 @@
 import gradio as gr
 import numpy as np
-import random
-#import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
 import torch

-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo" #Replace to the model you would like to use

 if torch.cuda.is_available():
-    torch_dtype = torch.float16
 else:
-    torch_dtype = torch.float32

-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)

-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024

-#@spaces.GPU #[uncomment to use ZeroGPU]
-def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):

     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator().manual_seed(seed)
-
-    image = pipe(
-        prompt = prompt,
-        negative_prompt = negative_prompt,
-        guidance_scale = guidance_scale,
-        num_inference_steps = num_inference_steps,
-        width = width,
-        height = height,
-        generator = generator
-    ).images[0]
-
-    return image, seed
-
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]

-css="""
-#col-container {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""
-        # Text-to-Image Gradio Template
-        """)
-
-        with gr.Row():
-
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-
-            run_button = gr.Button("Run", scale=0)
-
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
-            )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024, #Replace with defaults that work for your model
                 )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024, #Replace with defaults that work for your model
                 )
-
-            with gr.Row():
-
                 guidance_scale = gr.Slider(
                     label="Guidance scale",
-                    minimum=0.0,
                     maximum=10.0,
                     step=0.1,
-                    value=0.0, #Replace with defaults that work for your model
                 )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
                     step=1,
-                    value=2, #Replace with defaults that work for your model
                 )
-
-    gr.Examples(
-        examples = examples,
-        inputs = [prompt]
-    )
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn = infer,
-        inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-        outputs = [result, seed]
     )

-demo.queue().launch()

+import os
+import random
+
 import gradio as gr
 import numpy as np
+import PIL.Image
+from PIL import ImageOps
 import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import BitsAndBytesConfig
+import torchvision.transforms.functional as TF
+from diffusers import (
+    AutoencoderKL,
+    EulerAncestralDiscreteScheduler,
+    StableDiffusionXLAdapterPipeline,
+    T2IAdapter,
+)
+
+import urllib.parse
+import requests
+from io import BytesIO
+import json
+
+from pathlib import Path
+import uuid
+from azure.identity import DefaultAzureCredential
+from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient
+
+from datetime import datetime
+
+
+class DEFAULTS:
+    NEGATIVE_PROMPT = " extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured"
+    REWRITING_PROMPT = (
+        "Rewrite the image caption by making it shorter (but retain all information about relative position), "
+        "remove information about style of objects or colors of background and foreground, and, most importantly, remove all details "
+        "that suggest it is a sketch. Write it as a Google image search query:"
+    )
+    MOONDREAM_PROMPT = "Describe this image."
+    NUM_STEPS = 25
+    GUIDANCE_SCALE = 5
+    ADAPTER_CONDITIONING_SCALE = 0.8
+    ADAPTER_CONDITIONING_FACTOR = 0.8
+    SEED = 1231245
+    RANDOMIZE_SEED = True
+
+
+DESCRIPTION = '''# Sketch to Image/Caption to Bing Search :)
+'''
+
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+style_list = [
+    {
+        "name": "(No style)",
+        "prompt": "{prompt}",
+        "negative_prompt": "",
+    },
+    {
+        "name": "Cinematic",
+        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
+        "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
+    },
+    {
+        "name": "3D Model",
+        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
+        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
+    },
+    {
+        "name": "Anime",
+        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
+        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
+    },
+    {
+        "name": "Digital Art",
+        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
+        "negative_prompt": "photo, photorealistic, realism, ugly",
+    },
+    {
+        "name": "Photographic",
+        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
+        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
+    },
+    {
+        "name": "Pixel art",
+        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
+        "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
+    },
+    {
+        "name": "Fantasy art",
+        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
+        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
+    },
+    {
+        "name": "Neonpunk",
+        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
+        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
+    },
+    {
+        "name": "Manga",
+        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
+        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
+    },
+]
+
+styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
+STYLE_NAMES = list(styles.keys())
+DEFAULT_STYLE_NAME = "Photographic"  # "(No style)"
+
+
+def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
+    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
+    return p.replace("{prompt}", positive), n + negative
+
+
+# with open("azure_connection_string.txt", "r") as f:
+#     CONNECTION_STRING = f.read().strip()
+CONNECTION_STRING = os.getenv("AZURE_CONNECTION_STRING")
+
+
+def upload_pil_image_to_azure(image, connection_string=CONNECTION_STRING):
+    # Serialize the PIL image to PNG bytes under a fresh UUID file name
+    image_name = f"{uuid.uuid4()}.png"
+    image_bytes = BytesIO()
+    image.save(image_bytes, format="PNG")
+    image_bytes.seek(0)
+
+    try:
+        # Create the BlobServiceClient object
+        blob_service_client = BlobServiceClient.from_connection_string(connection_string)
+        # Create a blob client, using the generated UUID as the blob name
+        blob_client = blob_service_client.get_blob_client(container="blob-image-hosting", blob=image_name)
+        # Upload the PNG bytes and retrieve the blob URL
+        blob_client.upload_blob(image_bytes)
+        file_url = blob_client.url
+    except Exception as ex:
+        print('Exception:')
+        print(ex)
+        file_url = None
+    # file_url is None if the upload failed
+    return file_url

 if torch.cuda.is_available():
+    if torch.cuda.device_count() > 1:
+        device_0, device_1 = torch.device("cuda:0"), torch.device("cuda:1")
+    else:
+        device_0, device_1 = torch.device("cuda:0"), torch.device("cuda:0")
 else:
+    device_0, device_1 = torch.device("cpu"), torch.device("cpu")
+
+if torch.cuda.is_available():
+    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+    adapter = T2IAdapter.from_pretrained(
+        "TencentARC/t2i-adapter-sketch-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
+    )
+    scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
+    pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
+        model_id,
+        vae=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16),
+        adapter=adapter,
+        scheduler=scheduler,
+        torch_dtype=torch.float16,
+        variant="fp16",
+    )
+    pipe.to(device_0)
+else:
+    pipe = None


+MAX_SEED = np.iinfo(np.int32).max

+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
+    return seed
+
+nf4_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_compute_dtype=torch.bfloat16
+)
+
+vlmodel_id = "vikhyatk/moondream2"
+vlmodel_revision = "2024-07-23"
+vlmodel = AutoModelForCausalLM.from_pretrained(
+    vlmodel_id, trust_remote_code=True, revision=vlmodel_revision, device_map={"": device_1},
+    torch_dtype=torch.float16, attn_implementation="flash_attention_2")
+vltokenizer = AutoTokenizer.from_pretrained(vlmodel_id, revision=vlmodel_revision)
+
+rewrite_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+rewrite_model = AutoModelForCausalLM.from_pretrained(
+    rewrite_model_name,
+    device_map={"": device_1},
+    quantization_config=nf4_config,
+    # load_in_8bit=True,
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2")
+rewrite_tokenizer = AutoTokenizer.from_pretrained(rewrite_model_name)
+
+
+def caption_image_with_recaption(pil_image, moondream_prompt, rewriting_prompt, user_prompt=""):
+    # Caption the image with moondream2, then rewrite the caption with the LLM
+    enc_image = vlmodel.encode_image(pil_image)
+    img_caption = vlmodel.answer_question(enc_image, moondream_prompt, vltokenizer)
+    rewritten_caption = rewrite_prompt(img_caption, rewriting_prompt, user_prompt=user_prompt)
+    rewritten_caption = rewritten_caption.strip('"').replace("\n", " ")
+    return img_caption, rewritten_caption
+
+
+def rewrite_prompt(image_cap: str, guide: str, user_prompt: str = "") -> str:
+    prompt = f"{guide}\n{image_cap}"
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": prompt}
+    ]
+    text = rewrite_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    model_inputs = rewrite_tokenizer([text], return_tensors="pt").to(device_1)
+    generated_ids = rewrite_model.generate(model_inputs.input_ids, max_new_tokens=128)
+    # Strip the prompt tokens so only the generated continuation is decoded
+    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
+    response = rewrite_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    return response
+
+
+def run_full(
+    image,
+    user_prompt: str,
+    negative_prompt: str,
+    rewriting_prompt: str,
+    moondream_prompt: str,
+    style_name: str = DEFAULT_STYLE_NAME,
+    num_steps: int = 25,
+    guidance_scale: float = 5,
+    adapter_conditioning_scale: float = 0.8,
+    adapter_conditioning_factor: float = 0.8,
+    seed: int = 0,
+    progress=None,
+) -> dict:
+    # image is a white background with a black sketch; invert it
+    image = ImageOps.invert(image)
+    # Threshold the image to get a binary sketch
+    image = TF.to_tensor(image) > 0.5
+    image = TF.to_pil_image(image.to(torch.float32))
+
+    full_log = []
+    if user_prompt == "":
+        pre_caption = True
+        start_time = datetime.now()
+        img_caption, rewritten_caption = caption_image_with_recaption(
+            pil_image=image, rewriting_prompt=rewriting_prompt, moondream_prompt=moondream_prompt)
+        full_log.append(f"Combined captioning time: {datetime.now() - start_time}")
+        full_log.append(f"img_caption (pre): {img_caption}")
+        full_log.append(f"rewritten_caption (pre): {rewritten_caption}")
+        drawing_prompt = rewritten_caption
+    else:
+        pre_caption = False
+        drawing_prompt = user_prompt
+    full_log.append(f"Pre-caption: {pre_caption}")
+
+    # Generate image
+    start_time = datetime.now()
+    drawing_prompt, negative_prompt = apply_style(style_name, drawing_prompt, negative_prompt)
+    generator = torch.Generator(device=device_0).manual_seed(seed)
+    out_img = pipe(
+        prompt=drawing_prompt,
+        negative_prompt=negative_prompt,
+        image=image,
+        num_inference_steps=num_steps,
+        generator=generator,
+        guidance_scale=guidance_scale,
+        adapter_conditioning_scale=adapter_conditioning_scale,
+        adapter_conditioning_factor=adapter_conditioning_factor,
+    ).images[0]
+    full_log.append(f"Image generation time: {datetime.now() - start_time}")
+
+    if not pre_caption:
+        start_time = datetime.now()
+        img_caption, rewritten_caption = caption_image_with_recaption(
+            pil_image=out_img,
+            rewriting_prompt=rewriting_prompt,
+            moondream_prompt=moondream_prompt,
+            user_prompt=user_prompt)
+        full_log.append(f"Combined captioning time: {datetime.now() - start_time}")
+        full_log.append(f"img_caption (post): {img_caption}")
+        full_log.append(f"rewritten_caption (post): {rewritten_caption}")

+    # SERP query
+    bing_serp_query = f"https://www.bing.com/images/search?q={urllib.parse.quote(rewritten_caption)}"
+    md_text = f"### Bing search query\n[{bing_serp_query}]({bing_serp_query})\n"
+
+    # Visual Search query
+    out_img_azure_url = upload_pil_image_to_azure(out_img)
+    if out_img_azure_url is None:
+        md_text += "### Bing Visual Search\n**Error:** Failed to upload image to Azure Blob Storage\n"
+        bing_image_search_url = "https://www.bing.com/images"
+    else:
+        azure_url_quote = urllib.parse.quote(out_img_azure_url)
+        bing_image_search_url = f"https://www.bing.com/images/search?view=detailv2&iss=SBI&form=SBIIRP&q=imgurl:{azure_url_quote}"
+        md_text += f"### Bing Visual Search\n[{bing_image_search_url}]({bing_image_search_url})\n"
+
+    # Debug info
+    md_text += f"### Debug: sketch caption\n{img_caption}\n\n### Debug: rewritten caption\n{rewritten_caption}\n"
+
+    # Full log dump
+    md_text += f"### Debug: full log\n{'<br>'.join(full_log)}"
+
+    return {
+        "image": out_img,
+        "text_search_url": bing_serp_query,
+        "visual_search_url": bing_image_search_url,
+        "logs": md_text,
+    }
+
+
+def run_full_gradio(
+    image,
+    user_prompt: str,
+    negative_prompt: str,
+    rewriting_prompt: str,
+    moondream_prompt: str,
+    style_name: str = DEFAULT_STYLE_NAME,
+    num_steps: int = 25,
+    guidance_scale: float = 5,
+    adapter_conditioning_scale: float = 0.8,
+    adapter_conditioning_factor: float = 0.8,
+    seed: int = 0,
+    progress=gr.Progress(track_tqdm=True),
+) -> tuple[PIL.Image.Image, str]:
+    # Flatten the RGBA sketchpad composite onto a white background
+    image = image['composite']
+    background = PIL.Image.new('RGBA', image.size, (255, 255, 255))
+    alpha_composite = PIL.Image.alpha_composite(background, image)
+    image = alpha_composite.convert("RGB")
+
+    results = run_full(
+        image=image,
+        user_prompt=user_prompt,
+        negative_prompt=negative_prompt,
+        rewriting_prompt=rewriting_prompt,
+        moondream_prompt=moondream_prompt,
+        style_name=style_name,
+        num_steps=num_steps,
+        guidance_scale=guidance_scale,
+        adapter_conditioning_scale=adapter_conditioning_scale,
+        adapter_conditioning_factor=adapter_conditioning_factor,
+        seed=seed,
+        progress=progress,
+    )
+
+    return results["image"], results["logs"]
+
+
+def run_full_api(
+    image_url: str,
+    user_prompt: str,
+    progress=gr.Progress(track_tqdm=True),
+) -> tuple[str, str, str]:
+    seed = randomize_seed_fn(0, True)
+    image = PIL.Image.open(BytesIO(requests.get(image_url).content))
+    results = run_full(
+        image=image, user_prompt=user_prompt,
+        negative_prompt=DEFAULTS.NEGATIVE_PROMPT,
+        rewriting_prompt=DEFAULTS.REWRITING_PROMPT,
+        moondream_prompt=DEFAULTS.MOONDREAM_PROMPT,
+        style_name=DEFAULT_STYLE_NAME,
+        num_steps=DEFAULTS.NUM_STEPS,
+        guidance_scale=DEFAULTS.GUIDANCE_SCALE,
+        adapter_conditioning_scale=DEFAULTS.ADAPTER_CONDITIONING_SCALE,
+        adapter_conditioning_factor=DEFAULTS.ADAPTER_CONDITIONING_FACTOR,
+        seed=seed)
+    return results["text_search_url"], results["visual_search_url"], results["logs"]
+
+
+def run_caponly(
+    image,
+    rewriting_prompt: str,
+    moondream_prompt: str,
+    seed: int = 0,
+    progress=None,
+) -> dict:
+    # image is a white background with a black sketch; invert it
+    image = ImageOps.invert(image)
+    # Threshold the image to get a binary sketch
+    image = TF.to_tensor(image) > 0.5
+    image = TF.to_pil_image(image.to(torch.float32))
+
+    full_log = []
+    start_time = datetime.now()
+    img_caption, rewritten_caption = caption_image_with_recaption(
+        pil_image=image, rewriting_prompt=rewriting_prompt, moondream_prompt=moondream_prompt)
+    full_log.append(f"Combined captioning time: {datetime.now() - start_time}")
+    full_log.append(f"img_caption (pre): {img_caption}")
+    full_log.append(f"rewritten_caption (pre): {rewritten_caption}")
+    final_prompt = rewritten_caption
+
+    # SERP query
+    bing_serp_query = f"https://www.bing.com/images/search?q={urllib.parse.quote(rewritten_caption)}"
+    md_text = f"### Bing search query\n[{bing_serp_query}]({bing_serp_query})\n"
+
+    # Debug info
+    md_text += f"### Debug: sketch caption\n{img_caption}\n\n### Debug: rewritten caption\n{rewritten_caption}\n"
+
+    # Full log dump
+    md_text += f"### Debug: full log\n{'<br>'.join(full_log)}"
+
+    return {
+        "text_search_url": bing_serp_query,
+        "logs": md_text,
+    }
+
+
418
+ def run_caponly_api(
419
+ image_url: str,
420
+ progress=gr.Progress(track_tqdm=True),
421
+ ) -> str:
422
+ seed = randomize_seed_fn(0, True)
423
+ image = PIL.Image.open(BytesIO(requests.get(image_url).content))
424
+ results = run_caponly(
425
+ image=image,
426
+ rewriting_prompt=DEFAULTS.REWRITING_PROMPT,
427
+ moondream_prompt=DEFAULTS.MOONDREAM_PROMPT,
428
+ seed=seed)
429
+ return results["text_search_url"], results["logs"]
430
+
431
+
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION, elem_id="description")
+    gr.DuplicateButton(
+        value="Duplicate Space for private use",
+        elem_id="duplicate-button",
+        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+    )
+
+    with gr.Row():
+        with gr.Column():
+            with gr.Group():
+                image = gr.Sketchpad(
+                    # sources=["canvas"],
+                    # tool="sketch",
+                    type="pil",
+                    image_mode="RGBA",
+                    # invert_colors=True,
+                    layers=False,
+                    canvas_size=(1024, 1024),
+                    brush=gr.Brush(
+                        default_color="black",
+                        colors=None,
+                        default_size=4,
+                        color_mode="fixed",
+                    ),
+                    eraser=gr.Eraser(),
+                    height=440,
+                )
+                prompt = gr.Textbox(label="Prompt")
+                style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
+                run_button = gr.Button("Run")
+            with gr.Accordion("Advanced options", open=False):
+                negative_prompt = gr.Textbox(
+                    label="Negative prompt",
+                    value=DEFAULTS.NEGATIVE_PROMPT,
                 )
+                rewriting_prompt = gr.Textbox(
+                    label="Rewriting prompt",
+                    value=DEFAULTS.REWRITING_PROMPT,
+                )
+                moondream_prompt = gr.Textbox(
+                    label="Moondream prompt",
+                    value=DEFAULTS.MOONDREAM_PROMPT,
+                )
+                num_steps = gr.Slider(
+                    label="Number of steps",
+                    minimum=1,
+                    maximum=50,
+                    step=1,
+                    value=DEFAULTS.NUM_STEPS,
                 )
                 guidance_scale = gr.Slider(
                     label="Guidance scale",
+                    minimum=0.1,
                     maximum=10.0,
                     step=0.1,
+                    value=DEFAULTS.GUIDANCE_SCALE,
                 )
+                adapter_conditioning_scale = gr.Slider(
+                    label="Adapter conditioning scale",
+                    minimum=0.5,
+                    maximum=1,
+                    step=0.1,
+                    value=DEFAULTS.ADAPTER_CONDITIONING_SCALE,
+                )
+                adapter_conditioning_factor = gr.Slider(
+                    label="Adapter conditioning factor",
+                    info="Fraction of timesteps for which the adapter should be applied",
+                    minimum=0.5,
+                    maximum=1,
+                    step=0.1,
+                    value=DEFAULTS.ADAPTER_CONDITIONING_FACTOR,
+                )
+                seed = gr.Slider(
+                    label="Seed",
+                    minimum=0,
+                    maximum=MAX_SEED,
                     step=1,
+                    value=0,
                 )
+                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+        with gr.Column():
+            result_img = gr.Image(label="Result", height=400, interactive=False)
+            result_caption = gr.Markdown(label="Image caption")
+            result = [result_img, result_caption]
+
+    with gr.Row():
+        gr.Markdown("# API endpoints", elem_id="description")
+    with gr.Row():
+        with gr.Column():
+            with gr.Accordion("Full Experience API", open=False):
+                api_fullexp_image_url = gr.Textbox(label="Image URL")
+                api_fullexp_user_prompt = gr.Textbox(label="User prompt")
+                api_fullexp_run_button = gr.Button("Run API")
+                api_fullexp_text_search_url = gr.Textbox(label="Text search URL")
+                api_fullexp_visual_search_url = gr.Textbox(label="Visual search URL")
+                api_fullexp_logs = gr.Markdown(label="Logs")
+        with gr.Column():
+            with gr.Accordion("Caption Only API", open=False):
+                api_caponly_image_url = gr.Textbox(label="Image URL")
+                api_caponly_run_button = gr.Button("Run API")
+                api_caponly_text_search_url = gr.Textbox(label="Text search URL")
+                api_caponly_logs = gr.Markdown(label="Logs")
+
+    # Gradio component interconnections
+    inputs = [
+        image,
+        prompt,
+        negative_prompt,
+        rewriting_prompt,
+        moondream_prompt,
+        style,
+        num_steps,
+        guidance_scale,
+        adapter_conditioning_scale,
+        adapter_conditioning_factor,
+        seed,
+    ]
+    prompt.submit(
+        fn=randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+    ).then(
+        fn=run_full_gradio,
+        inputs=inputs,
+        outputs=result,
+        api_name=False,
+    )
+    negative_prompt.submit(
+        fn=randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+    ).then(
+        fn=run_full_gradio,
+        inputs=inputs,
+        outputs=result,
+        api_name=False,
+    )
+    run_button.click(
+        fn=randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+    ).then(
+        fn=run_full_gradio,
+        inputs=inputs,
+        outputs=result,
+        api_name=False,
+    )
+
+    # API interconnections
+    api_fullexp_run_button.click(
+        fn=run_full_api,
+        inputs=[api_fullexp_image_url, api_fullexp_user_prompt],
+        outputs=[api_fullexp_text_search_url, api_fullexp_visual_search_url, api_fullexp_logs],
+        api_name="full_experience",
+    )
+    api_caponly_run_button.click(
+        fn=run_caponly_api,
+        inputs=[api_caponly_image_url],
+        outputs=[api_caponly_text_search_url, api_caponly_logs],
+        api_name="caption_only",
     )

+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
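
The commit registers two named Gradio endpoints above (api_name="full_experience" and api_name="caption_only"). Below is a minimal sketch of how they could be called remotely with the gradio_client package; the Space id and image URL are placeholders, not something this commit defines:

    # Hypothetical client-side usage of the endpoints added in app.py
    from gradio_client import Client

    client = Client("hav4ik/your-space-name")  # placeholder Space id (assumption)

    # /full_experience: (image URL, user prompt) ->
    #   (text search URL, visual search URL, markdown logs)
    text_url, visual_url, logs = client.predict(
        "https://example.com/sketch.png",  # image_url (placeholder)
        "",                                # user_prompt; "" triggers auto-captioning
        api_name="/full_experience",
    )

    # /caption_only: (image URL) -> (text search URL, markdown logs)
    text_url, logs = client.predict(
        "https://example.com/sketch.png",
        api_name="/caption_only",
    )
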
requirements.txt CHANGED
@@ -3,4 +3,10 @@ diffusers
 invisible_watermark
 torch
 transformers
-xformers
 invisible_watermark
 torch
 transformers
+xformers
+flash-attn
+bitsandbytes
+azure-core
+azure-storage-blob
+azure-identity
+einops
style.css ADDED
@@ -0,0 +1,16 @@
+#component-0 {
+    max-width: 900px;
+    margin: 0 auto;
+}
+
+#description, h1 {
+    text-align: center;
+}
+
+#duplicate-button {
+    margin: auto;
+    color: #fff;
+    background: #1565c0;
+    border-radius: 100vh;
+}