Update app.py
app.py
CHANGED
@@ -267,92 +267,97 @@ image_adapter.to("cuda")
 def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int,
                 lens_type: str = "standard", film_stock: str = "digital",
                 composition: str = "rule of thirds", lighting: str = "natural") -> str:
-    [previous stream_chat body (old lines 270-355) removed; its contents are not recoverable from this rendering]
+    torch.cuda.empty_cache()
+
+    # 'any' means no length specified
+    length = None if caption_length == "any" else caption_length
+
+    if isinstance(length, str):
+        try:
+            length = int(length)
+        except ValueError:
+            pass
+
+    # 'rng-tags' and 'training_prompt' don't have formal/informal tones
+    if caption_type == "rng-tags" or caption_type == "training_prompt":
+        caption_tone = "formal"
+
+    # Build prompt
+    prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
+    if prompt_key not in CAPTION_TYPE_MAP:
+        raise ValueError(f"Invalid caption type: {prompt_key}")
+
+    prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
+
+    # Add style prompt details if applicable
+    if caption_type == "style_prompt":
+        prompt_str += (f" The prompt should specifically include details about using a {lens_type} lens, "
+                       f"{film_stock} film stock, {composition} composition, and {lighting} lighting. "
+                       f"Format the output as a comma-separated list of descriptors and modifiers, "
+                       f"suitable for direct input into a Stable Diffusion interface.")
+
+    print(f"Prompt: {prompt_str}")
+
+    # Preprocess image
+    image = input_image.resize((384, 384), Image.LANCZOS)
+    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
+    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
+    pixel_values = pixel_values.to('cuda')
+
+    # Tokenize the prompt
+    prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
+
+    # Embed image
+    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
+        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+        image_features = vision_outputs.hidden_states
+        embedded_images = image_adapter(image_features)
+        embedded_images = embedded_images.to('cuda')
+
+    # Embed prompt
+    prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
+    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
+
+    # Check if bos_token_id exists
+    if tokenizer.bos_token_id is None:
+        print("Warning: bos_token_id is None. Using default value of 1.")
+        bos_token_id = 1
+    else:
+        bos_token_id = tokenizer.bos_token_id
+
+    embedded_bos = text_model.model.embed_tokens(torch.tensor([[bos_token_id]], device=text_model.device, dtype=torch.int64))
+    eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
+
+    # Construct prompts
+    inputs_embeds = torch.cat([
+        embedded_bos.expand(embedded_images.shape[0], -1, -1),
+        embedded_images.to(dtype=embedded_bos.dtype),
+        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
+        eot_embed.expand(embedded_images.shape[0], -1, -1),
+    ], dim=1)
+
+    input_ids = torch.cat([
+        torch.tensor([[bos_token_id]], dtype=torch.long),
+        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
+        prompt,
+        torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
+    ], dim=1).to('cuda')
+    attention_mask = torch.ones_like(input_ids)
+
+    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)
+
+    # Trim off the prompt
+    generate_ids = generate_ids[:, input_ids.shape[1]:]
+    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+        generate_ids = generate_ids[:, :-1]
+
+    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
+
+    # For style_prompt, format the output for easy copying into image generation platforms
+    if caption_type == "style_prompt":
+        caption = "Stable Diffusion Prompt: " + caption.replace("\n", ", ")
+
+    return caption.strip()
 
 css = """
 h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, .centered-image {
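A minimal sketch of how the new stream_chat entry point might be exercised outside the Gradio UI, for anyone testing this hunk. The image path and parameter values below are hypothetical examples; tokenizer, clip_model, text_model, image_adapter, and CAPTION_TYPE_MAP are assumed to be loaded earlier in app.py (as in the lines above this hunk), and the key ("style_prompt", "formal", False, False) is assumed to be present in CAPTION_TYPE_MAP:

from PIL import Image

img = Image.open("example.jpg").convert("RGB")  # hypothetical input file
caption = stream_chat(
    img,
    caption_type="style_prompt",
    caption_tone="formal",
    caption_length="any",              # "any" means no length constraint
    lens_type="85mm portrait",         # hypothetical style parameters
    film_stock="Kodak Portra 400",
    composition="rule of thirds",
    lighting="golden hour",
)
print(caption)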