k4d3 commited on
Commit
b54f417
1 Parent(s): ee41534

removed perplexity stuff

Browse files

Signed-off-by: Balazs Horvath <acsipont@gmail.com>

Files changed (1) hide show
  1. joy +19 -46
joy CHANGED
@@ -330,20 +330,14 @@ class JoyCaptionModel:
330
  caption_type: str,
331
  caption_tone: str,
332
  caption_length: str | int,
333
- custom_prompt: str | None = None
334
- ) -> Tuple[str, float, float]:
335
  """
336
- Process the input image and generate a caption.
337
-
338
- Args:
339
- input_image (Image.Image): The input image to caption.
340
- caption_type (str): The type of caption to generate.
341
- caption_tone (str): The tone of the caption.
342
- caption_length (str | int): The desired length of the caption.
343
- custom_prompt (str | None): A custom prompt for caption generation.
344
 
345
  Returns:
346
- Tuple[str, float, float]: A tuple containing the generated caption, its entropy, and its perplexity.
347
  """
348
  torch.cuda.empty_cache()
349
 
@@ -370,11 +364,7 @@ class JoyCaptionModel:
370
  token_ids = generate_ids[0].tolist()
371
  entropy = self._calculate_entropy(token_ids)
372
 
373
- # Calculate perplexity
374
- loss = self._calculate_perplexity(generate_ids, input_ids)
375
- perplexity = math.exp(-loss)
376
-
377
- return caption.strip(), entropy, perplexity
378
 
379
  def generate_valid_caption(
380
  self,
@@ -388,7 +378,6 @@ class JoyCaptionModel:
388
  min_sentence_count: int = 3,
389
  max_word_repetitions: int = 5,
390
  min_entropy: float = 1.75,
391
- max_perplexity: float = 100.0,
392
  stop_words: set[str] = STOP_WORDS
393
  ) -> str:
394
  """
@@ -400,18 +389,24 @@ class JoyCaptionModel:
400
  caption_tone (str): The tone of the caption.
401
  caption_length (str | int): The desired length of the caption.
402
  custom_prompt (str | None): A custom prompt for caption generation.
403
- limited_words (Dict[str, int]): Dictionary of words with their maximum allowed occurrences. Default is {"fluffy": 2}.
404
  min_sentence_count (int): Minimum required number of sentences. Default is 3.
405
- max_word_repetitions (int): Maximum allowed repetitions for words longer than 4 characters. Default is 5.
406
- min_entropy (float): Minimum required entropy of the caption. Default is 1.75.
407
- max_perplexity (float): Maximum allowed perplexity of the caption. Default is 100.0.
408
- stop_words (set[str]): Set of stop words to exclude from repetition checks. Default is STOP_WORDS.
409
 
410
  Returns:
411
  str: A valid caption meeting all specified criteria.
 
 
 
 
 
 
 
 
412
  """
413
  while True:
414
- caption, entropy, perplexity = self.process_image(
415
  input_image, caption_type, caption_tone,
416
  caption_length, custom_prompt
417
  )
@@ -435,8 +430,6 @@ class JoyCaptionModel:
435
  print(f"Retrying: Only {sentence_count} sentences (min: {min_sentence_count}).\nCaption: {caption!r}")
436
  elif entropy < min_entropy:
437
  print(f"Retrying: Low entropy ({entropy:.2f} < {min_entropy}).\nCaption: {caption!r}")
438
- elif perplexity > max_perplexity:
439
- print(f"Retrying: High perplexity ({perplexity:.2f} > {max_perplexity}).\nCaption: {caption!r}")
440
  else:
441
  return caption
442
 
@@ -597,26 +590,6 @@ class JoyCaptionModel:
597
 
598
  return entropy
599
 
600
- def _calculate_perplexity(self, generate_ids, input_ids):
601
- """
602
- Calculate the perplexity of the generated caption.
603
-
604
- Args:
605
- generate_ids (torch.Tensor): Generated token IDs.
606
- input_ids (torch.Tensor): Input token IDs.
607
-
608
- Returns:
609
- float: Perplexity of the generated caption.
610
- """
611
- with torch.no_grad():
612
- outputs = self.text_model(
613
- input_ids=input_ids,
614
- labels=generate_ids,
615
- output_hidden_states=True,
616
- )
617
- loss = outputs.loss
618
- return loss.item()
619
-
620
 
621
  def main():
622
  """
@@ -738,7 +711,7 @@ def main():
738
  args, image_path, tagset_normalizer
739
  )
740
 
741
- print(f"\nCaptioning {image_path}...\nCustom prompt: {custom_prompt}")
742
 
743
  caption = joy_caption_model.generate_valid_caption(
744
  input_image,
 
330
  caption_type: str,
331
  caption_tone: str,
332
  caption_length: str | int,
333
+ custom_prompt: str | None = None,
334
+ ) -> Tuple[str, float]:
335
  """
336
+ Process an input image and generate a caption based on specified parameters.
337
+ Also calculates the entropy of the generated caption.
 
 
 
 
 
 
338
 
339
  Returns:
340
+ Tuple[str, float]: The generated caption and its entropy.
341
  """
342
  torch.cuda.empty_cache()
343
 
 
364
  token_ids = generate_ids[0].tolist()
365
  entropy = self._calculate_entropy(token_ids)
366
 
367
+ return caption.strip(), entropy
 
 
 
 
368
 
369
  def generate_valid_caption(
370
  self,
 
378
  min_sentence_count: int = 3,
379
  max_word_repetitions: int = 5,
380
  min_entropy: float = 1.75,
 
381
  stop_words: set[str] = STOP_WORDS
382
  ) -> str:
383
  """
 
389
  caption_tone (str): The tone of the caption.
390
  caption_length (str | int): The desired length of the caption.
391
  custom_prompt (str | None): A custom prompt for caption generation.
392
+ limited_words (Dict[str, int]): Dictionary of words with their maximum allowed occurrences. Default is {"fluffy": 1}.
393
  min_sentence_count (int): Minimum required number of sentences. Default is 3.
394
+ max_word_repetitions (int): Maximum allowed repetitions for words longer than 4 characters. Default is 15.
395
+ min_entropy (float): Minimum required entropy of the caption. Default is 2.3.
 
 
396
 
397
  Returns:
398
  str: A valid caption meeting all specified criteria.
399
+
400
+ The method retries caption generation if:
401
+ - The caption contains only special characters
402
+ - The caption does not end with a period, exclamation mark, or question mark
403
+ - Any word in limited_words appears more than its specified maximum times
404
+ - Any word longer than 4 characters is repeated more than max_word_repetitions times
405
+ - The caption contains fewer than min_sentence_count sentences
406
+ - The entropy of the caption is below min_entropy
407
  """
408
  while True:
409
+ caption, entropy = self.process_image(
410
  input_image, caption_type, caption_tone,
411
  caption_length, custom_prompt
412
  )
 
430
  print(f"Retrying: Only {sentence_count} sentences (min: {min_sentence_count}).\nCaption: {caption!r}")
431
  elif entropy < min_entropy:
432
  print(f"Retrying: Low entropy ({entropy:.2f} < {min_entropy}).\nCaption: {caption!r}")
 
 
433
  else:
434
  return caption
435
 
 
590
 
591
  return entropy
592
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593
 
594
  def main():
595
  """
 
711
  args, image_path, tagset_normalizer
712
  )
713
 
714
+ print(f"\nCustom prompt: {custom_prompt}")
715
 
716
  caption = joy_caption_model.generate_valid_caption(
717
  input_image,