gokaygokay commited on
Commit
81a88cf
1 Parent(s): c30f461

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -1
app.py CHANGED
@@ -6,6 +6,36 @@ import torch
6
  model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to("cuda").eval()
7
  processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  @spaces.GPU
10
  def create_captions_rich(image):
11
  prompt = "caption en"
@@ -16,7 +46,10 @@ def create_captions_rich(image):
16
  generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
17
  generation = generation[0][input_len:]
18
  decoded = processor.decode(generation, skip_special_tokens=True)
19
- return decoded
 
 
 
20
 
21
  css = """
22
  #mkd {
 
6
  model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to("cuda").eval()
7
  processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")
8
 
9
+ import re
10
+
11
+ def modify_caption(caption: str) -> str:
12
+ """
13
+ Removes specific prefixes from captions.
14
+
15
+ Args:
16
+ caption (str): A string containing a caption.
17
+
18
+ Returns:
19
+ str: The caption with the prefix removed if it was present.
20
+ """
21
+ # Define the prefixes to remove
22
+ prefix_substrings = [
23
+ ('captured from ', ''),
24
+ ('captured at ', '')
25
+ ]
26
+
27
+ # Create a regex pattern to match any of the prefixes
28
+ pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
29
+ replacers = {opening: replacer for opening, replacer in prefix_substrings}
30
+
31
+ # Function to replace matched prefix with its corresponding replacement
32
+ def replace_fn(match):
33
+ return replacers[match.group(0)]
34
+
35
+ # Apply the regex to the caption
36
+ return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE).capitalize()
37
+
38
+ # Example usage in your existing function
39
  @spaces.GPU
40
  def create_captions_rich(image):
41
  prompt = "caption en"
 
46
  generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
47
  generation = generation[0][input_len:]
48
  decoded = processor.decode(generation, skip_special_tokens=True)
49
+
50
+ # Modify the caption to remove specific prefixes
51
+ modified_caption = modify_caption(decoded)
52
+ return modified_caption
53
 
54
  css = """
55
  #mkd {