Tech-Meld commited on
Commit
b502a48
1 Parent(s): cb6dd7d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
3
+ import spaces
4
+ import torch
5
+ import re
6
+
7
+ # Load the model and processor
8
+ model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to("cpu").eval()
9
+ processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")
10
+
11
+ def modify_caption(caption: str) -> str:
12
+ """
13
+ Removes specific prefixes from captions.
14
+ Args:
15
+ caption (str): A string containing a caption.
16
+ Returns:
17
+ str: The caption with the prefix removed if it was present.
18
+ """
19
+ prefix_substrings = [
20
+ ('captured from ', ''),
21
+ ('captured at ', '')
22
+ ]
23
+
24
+ pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
25
+ replacers = {opening: replacer for opening, replacer in prefix_substrings}
26
+
27
+ def replace_fn(match):
28
+ return replacers[match.group(0)]
29
+
30
+ return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
31
+
32
+ def create_captions_rich(images):
33
+ """
34
+ Generates captions for input images.
35
+
36
+ Args:
37
+ images (list): List of images to generate captions for.
38
+
39
+ Returns:
40
+ list: List of captions, one for each input image.
41
+ """
42
+ captions = []
43
+ for image in images:
44
+ try:
45
+ prompt = "caption en"
46
+ model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
47
+ input_len = model_inputs["input_ids"].shape[-1]
48
+
49
+ with torch.inference_mode():
50
+ generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
51
+ generation = generation[0][input_len:]
52
+ decoded = processor.decode(generation, skip_special_tokens=True)
53
+
54
+ modified_caption = modify_caption(decoded)
55
+ captions.append(modified_caption)
56
+ except Exception as e:
57
+ captions.append(f"Error processing image: {e}")
58
+ return captions
59
+
60
+ css = """
61
+ #mkd {
62
+ height: 500px;
63
+ overflow: auto;
64
+ border: 8px solid #ccc;
65
+ }
66
+ """
67
+
68
+ with gr.Blocks(css=css) as demo:
69
+ gr.HTML("<h1><center>Image caption using finetuned PaliGemma on SD3 generation data.<center><h1>")
70
+ with gr.Tab(label="Img2Prompt for SD3"):
71
+ with gr.Row():
72
+ with gr.Column():
73
+ input_img = gr.Image(label="Input Image", tool="select", type="pil", interactive=True)
74
+ submit_btn = gr.Button(value="Start")
75
+ output = gr.Textbox(label="Prompt", lines=10, interactive=True)
76
+
77
+ submit_btn.click(create_captions_rich, [input_img], [output])
78
+
79
+ demo.launch(debug=True)