removed perplexity stuff
Browse filesSigned-off-by: Balazs Horvath <acsipont@gmail.com>
joy
CHANGED
@@ -330,20 +330,14 @@ class JoyCaptionModel:
|
|
330 |
caption_type: str,
|
331 |
caption_tone: str,
|
332 |
caption_length: str | int,
|
333 |
-
custom_prompt: str | None = None
|
334 |
-
) -> Tuple[str, float
|
335 |
"""
|
336 |
-
Process
|
337 |
-
|
338 |
-
Args:
|
339 |
-
input_image (Image.Image): The input image to caption.
|
340 |
-
caption_type (str): The type of caption to generate.
|
341 |
-
caption_tone (str): The tone of the caption.
|
342 |
-
caption_length (str | int): The desired length of the caption.
|
343 |
-
custom_prompt (str | None): A custom prompt for caption generation.
|
344 |
|
345 |
Returns:
|
346 |
-
Tuple[str, float
|
347 |
"""
|
348 |
torch.cuda.empty_cache()
|
349 |
|
@@ -370,11 +364,7 @@ class JoyCaptionModel:
|
|
370 |
token_ids = generate_ids[0].tolist()
|
371 |
entropy = self._calculate_entropy(token_ids)
|
372 |
|
373 |
-
|
374 |
-
loss = self._calculate_perplexity(generate_ids, input_ids)
|
375 |
-
perplexity = math.exp(-loss)
|
376 |
-
|
377 |
-
return caption.strip(), entropy, perplexity
|
378 |
|
379 |
def generate_valid_caption(
|
380 |
self,
|
@@ -388,7 +378,6 @@ class JoyCaptionModel:
|
|
388 |
min_sentence_count: int = 3,
|
389 |
max_word_repetitions: int = 5,
|
390 |
min_entropy: float = 1.75,
|
391 |
-
max_perplexity: float = 100.0,
|
392 |
stop_words: set[str] = STOP_WORDS
|
393 |
) -> str:
|
394 |
"""
|
@@ -400,18 +389,24 @@ class JoyCaptionModel:
|
|
400 |
caption_tone (str): The tone of the caption.
|
401 |
caption_length (str | int): The desired length of the caption.
|
402 |
custom_prompt (str | None): A custom prompt for caption generation.
|
403 |
-
limited_words (Dict[str, int]): Dictionary of words with their maximum allowed occurrences. Default is {"fluffy":
|
404 |
min_sentence_count (int): Minimum required number of sentences. Default is 3.
|
405 |
-
max_word_repetitions (int): Maximum allowed repetitions for words longer than 4 characters. Default is
|
406 |
-
min_entropy (float): Minimum required entropy of the caption. Default is
|
407 |
-
max_perplexity (float): Maximum allowed perplexity of the caption. Default is 100.0.
|
408 |
-
stop_words (set[str]): Set of stop words to exclude from repetition checks. Default is STOP_WORDS.
|
409 |
|
410 |
Returns:
|
411 |
str: A valid caption meeting all specified criteria.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
"""
|
413 |
while True:
|
414 |
-
caption, entropy
|
415 |
input_image, caption_type, caption_tone,
|
416 |
caption_length, custom_prompt
|
417 |
)
|
@@ -435,8 +430,6 @@ class JoyCaptionModel:
|
|
435 |
print(f"Retrying: Only {sentence_count} sentences (min: {min_sentence_count}).\nCaption: {caption!r}")
|
436 |
elif entropy < min_entropy:
|
437 |
print(f"Retrying: Low entropy ({entropy:.2f} < {min_entropy}).\nCaption: {caption!r}")
|
438 |
-
elif perplexity > max_perplexity:
|
439 |
-
print(f"Retrying: High perplexity ({perplexity:.2f} > {max_perplexity}).\nCaption: {caption!r}")
|
440 |
else:
|
441 |
return caption
|
442 |
|
@@ -597,26 +590,6 @@ class JoyCaptionModel:
|
|
597 |
|
598 |
return entropy
|
599 |
|
600 |
-
def _calculate_perplexity(self, generate_ids, input_ids):
|
601 |
-
"""
|
602 |
-
Calculate the perplexity of the generated caption.
|
603 |
-
|
604 |
-
Args:
|
605 |
-
generate_ids (torch.Tensor): Generated token IDs.
|
606 |
-
input_ids (torch.Tensor): Input token IDs.
|
607 |
-
|
608 |
-
Returns:
|
609 |
-
float: Perplexity of the generated caption.
|
610 |
-
"""
|
611 |
-
with torch.no_grad():
|
612 |
-
outputs = self.text_model(
|
613 |
-
input_ids=input_ids,
|
614 |
-
labels=generate_ids,
|
615 |
-
output_hidden_states=True,
|
616 |
-
)
|
617 |
-
loss = outputs.loss
|
618 |
-
return loss.item()
|
619 |
-
|
620 |
|
621 |
def main():
|
622 |
"""
|
@@ -738,7 +711,7 @@ def main():
|
|
738 |
args, image_path, tagset_normalizer
|
739 |
)
|
740 |
|
741 |
-
print(f"\
|
742 |
|
743 |
caption = joy_caption_model.generate_valid_caption(
|
744 |
input_image,
|
|
|
330 |
caption_type: str,
|
331 |
caption_tone: str,
|
332 |
caption_length: str | int,
|
333 |
+
custom_prompt: str | None = None,
|
334 |
+
) -> Tuple[str, float]:
|
335 |
"""
|
336 |
+
Process an input image and generate a caption based on specified parameters.
|
337 |
+
Also calculates the entropy of the generated caption.
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
|
339 |
Returns:
|
340 |
+
Tuple[str, float]: The generated caption and its entropy.
|
341 |
"""
|
342 |
torch.cuda.empty_cache()
|
343 |
|
|
|
364 |
token_ids = generate_ids[0].tolist()
|
365 |
entropy = self._calculate_entropy(token_ids)
|
366 |
|
367 |
+
return caption.strip(), entropy
|
|
|
|
|
|
|
|
|
368 |
|
369 |
def generate_valid_caption(
|
370 |
self,
|
|
|
378 |
min_sentence_count: int = 3,
|
379 |
max_word_repetitions: int = 5,
|
380 |
min_entropy: float = 1.75,
|
|
|
381 |
stop_words: set[str] = STOP_WORDS
|
382 |
) -> str:
|
383 |
"""
|
|
|
389 |
caption_tone (str): The tone of the caption.
|
390 |
caption_length (str | int): The desired length of the caption.
|
391 |
custom_prompt (str | None): A custom prompt for caption generation.
|
392 |
+
limited_words (Dict[str, int]): Dictionary of words with their maximum allowed occurrences. Default is {"fluffy": 1}.
|
393 |
min_sentence_count (int): Minimum required number of sentences. Default is 3.
|
394 |
+
max_word_repetitions (int): Maximum allowed repetitions for words longer than 4 characters. Default is 15.
|
395 |
+
min_entropy (float): Minimum required entropy of the caption. Default is 2.3.
|
|
|
|
|
396 |
|
397 |
Returns:
|
398 |
str: A valid caption meeting all specified criteria.
|
399 |
+
|
400 |
+
The method retries caption generation if:
|
401 |
+
- The caption contains only special characters
|
402 |
+
- The caption does not end with a period, exclamation mark, or question mark
|
403 |
+
- Any word in limited_words appears more than its specified maximum times
|
404 |
+
- Any word longer than 4 characters is repeated more than max_word_repetitions times
|
405 |
+
- The caption contains fewer than min_sentence_count sentences
|
406 |
+
- The entropy of the caption is below min_entropy
|
407 |
"""
|
408 |
while True:
|
409 |
+
caption, entropy = self.process_image(
|
410 |
input_image, caption_type, caption_tone,
|
411 |
caption_length, custom_prompt
|
412 |
)
|
|
|
430 |
print(f"Retrying: Only {sentence_count} sentences (min: {min_sentence_count}).\nCaption: {caption!r}")
|
431 |
elif entropy < min_entropy:
|
432 |
print(f"Retrying: Low entropy ({entropy:.2f} < {min_entropy}).\nCaption: {caption!r}")
|
|
|
|
|
433 |
else:
|
434 |
return caption
|
435 |
|
|
|
590 |
|
591 |
return entropy
|
592 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
593 |
|
594 |
def main():
|
595 |
"""
|
|
|
711 |
args, image_path, tagset_normalizer
|
712 |
)
|
713 |
|
714 |
+
print(f"\nCustom prompt: {custom_prompt}")
|
715 |
|
716 |
caption = joy_caption_model.generate_valid_caption(
|
717 |
input_image,
|