Severian committed
Commit 250653b · verified · 1 Parent(s): f45fc1b

Update app.py

Files changed (1):
  1. app.py +91 -86
app.py CHANGED
@@ -267,92 +267,97 @@ image_adapter.to("cuda")
 def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int,
                 lens_type: str = "standard", film_stock: str = "digital",
                 composition: str = "rule of thirds", lighting: str = "natural") -> str:
-    torch.cuda.empty_cache()
-
-    # 'any' means no length specified
-    length = None if caption_length == "any" else caption_length
-
-    if isinstance(length, str):
-        try:
-            length = int(length)
-        except ValueError:
-            pass
-
-    # 'rng-tags' and 'training_prompt' don't have formal/informal tones
-    if caption_type == "rng-tags" or caption_type == "training_prompt":
-        caption_tone = "formal"
-
-    # Build prompt
-    prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
-    if prompt_key not in CAPTION_TYPE_MAP:
-        raise ValueError(f"Invalid caption type: {prompt_key}")
-
-    prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
-
-    # Add style prompt details if applicable
-    if caption_type == "style_prompt":
-        prompt_str += (f" The prompt should specifically include details about using a {lens_type} lens, "
-                       f"{film_stock} film stock, {composition} composition, and {lighting} lighting. "
-                       f"Format the output as a comma-separated list of descriptors and modifiers, "
-                       f"suitable for direct input into a Stable Diffusion interface.")
-
-    print(f"Prompt: {prompt_str}")
-
-    # Preprocess image
-    #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
-    image = input_image.resize((384, 384), Image.LANCZOS)
-    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
-    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-    pixel_values = pixel_values.to('cuda')
-
-    # Tokenize the prompt
-    prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
-
-    # Embed image
-    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
-        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
-        image_features = vision_outputs.hidden_states
-        embedded_images = image_adapter(image_features)
-        embedded_images = embedded_images.to('cuda')
-
-    # Embed prompt
-    prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
-    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
-    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
-    eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
-
-    # Construct prompts
-    inputs_embeds = torch.cat([
-        embedded_bos.expand(embedded_images.shape[0], -1, -1),
-        embedded_images.to(dtype=embedded_bos.dtype),
-        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
-        eot_embed.expand(embedded_images.shape[0], -1, -1),
-    ], dim=1)
-
-    input_ids = torch.cat([
-        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
-        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
-        prompt,
-        torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
-    ], dim=1).to('cuda')
-    attention_mask = torch.ones_like(input_ids)
-
-    #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
-    #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
-    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)  # Uses the default which is temp=0.6, top_p=0.9
-
-    # Trim off the prompt
-    generate_ids = generate_ids[:, input_ids.shape[1]:]
-    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
-        generate_ids = generate_ids[:, :-1]
-
-    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
-
-    # For style_prompt, format the output for easy copying into image generation platforms
-    if caption_type == "style_prompt":
-        caption = "Stable Diffusion Prompt: " + caption.replace("\n", ", ")
-
-    return caption.strip()
 
+    torch.cuda.empty_cache()
+
+    # 'any' means no length specified
+    length = None if caption_length == "any" else caption_length
+
+    if isinstance(length, str):
+        try:
+            length = int(length)
+        except ValueError:
+            pass
+
+    # 'rng-tags' and 'training_prompt' don't have formal/informal tones
+    if caption_type == "rng-tags" or caption_type == "training_prompt":
+        caption_tone = "formal"
+
+    # Build prompt
+    prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
+    if prompt_key not in CAPTION_TYPE_MAP:
+        raise ValueError(f"Invalid caption type: {prompt_key}")
+
+    prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
+
+    # Add style prompt details if applicable
+    if caption_type == "style_prompt":
+        prompt_str += (f" The prompt should specifically include details about using a {lens_type} lens, "
+                       f"{film_stock} film stock, {composition} composition, and {lighting} lighting. "
+                       f"Format the output as a comma-separated list of descriptors and modifiers, "
+                       f"suitable for direct input into a Stable Diffusion interface.")
+
+    print(f"Prompt: {prompt_str}")
+
+    # Preprocess image
+    image = input_image.resize((384, 384), Image.LANCZOS)
+    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
+    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
+    pixel_values = pixel_values.to('cuda')
+
+    # Tokenize the prompt
+    prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
+
+    # Embed image
+    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
+        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+        image_features = vision_outputs.hidden_states
+        embedded_images = image_adapter(image_features)
+        embedded_images = embedded_images.to('cuda')
+
+    # Embed prompt
+    prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
+    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
+
+    # Check if bos_token_id exists
+    if tokenizer.bos_token_id is None:
+        print("Warning: bos_token_id is None. Using default value of 1.")
+        bos_token_id = 1
+    else:
+        bos_token_id = tokenizer.bos_token_id
+
+    embedded_bos = text_model.model.embed_tokens(torch.tensor([[bos_token_id]], device=text_model.device, dtype=torch.int64))
+    eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
+
+    # Construct prompts
+    inputs_embeds = torch.cat([
+        embedded_bos.expand(embedded_images.shape[0], -1, -1),
+        embedded_images.to(dtype=embedded_bos.dtype),
+        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
+        eot_embed.expand(embedded_images.shape[0], -1, -1),
+    ], dim=1)
+
+    input_ids = torch.cat([
+        torch.tensor([[bos_token_id]], dtype=torch.long),
+        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
+        prompt,
+        torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
+    ], dim=1).to('cuda')
+    attention_mask = torch.ones_like(input_ids)
+
+    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)
+
+    # Trim off the prompt
+    generate_ids = generate_ids[:, input_ids.shape[1]:]
+    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+        generate_ids = generate_ids[:, :-1]
+
+    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
+
+    # For style_prompt, format the output for easy copying into image generation platforms
+    if caption_type == "style_prompt":
+        caption = "Stable Diffusion Prompt: " + caption.replace("\n", ", ")
+
+    return caption.strip()
 
 css = """
 h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, .centered-image {
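
The guard around tokenizer.bos_token_id introduced in stream_chat can be lifted out as a reusable helper. Below is a minimal sketch of the pattern, assuming any Hugging Face tokenizer in scope; the name resolve_bos_token_id is hypothetical, and the fallback of 1 simply mirrors the hard-coded default above:

    # Minimal sketch of the BOS-token fallback used in stream_chat above.
    # `resolve_bos_token_id` is a hypothetical helper name; the fallback of 1
    # mirrors this commit's hard-coded default and is only a guess at the true id.
    def resolve_bos_token_id(tokenizer, fallback: int = 1) -> int:
        if tokenizer.bos_token_id is None:
            print(f"Warning: bos_token_id is None. Using default value of {fallback}.")
            return fallback
        return int(tokenizer.bos_token_id)

    bos_token_id = resolve_bos_token_id(tokenizer)

Note that 1 is the BOS id of legacy Llama/Mistral-style tokenizers; for a tokenizer that genuinely lacks a BOS token, a checkpoint-specific value may be safer than this default.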