File size: 2,616 Bytes
b502a48
 
 
 
 
24d193b
b502a48
a66b74b
b502a48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bd959a
b502a48
 
24d193b
b502a48
8bd959a
 
24d193b
 
 
 
 
 
 
a66b74b
8bd959a
 
24d193b
 
 
 
 
 
 
 
 
8bd959a
24d193b
b502a48
 
 
 
 
8bd959a
b502a48
 
 
 
8bd959a
 
ac510fc
8bd959a
 
24d193b
8bd959a
ac510fc
b502a48
24d193b
b502a48
8bd959a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import spaces
import torch
import re
from PIL import Image

# Hugging Face Hub repo id (or local path) shared by the weights and the processor.
MODEL_ID = "gokaygokay/sd3-long-captioner"

# Load the fine-tuned PaliGemma once, at import time; it is kept on the CPU
# and switched to eval mode for inference.
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID)
model = model.to("cpu").eval()
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)

def modify_caption(caption: str) -> str:
    """
    Removes specific prefixes from captions.

    Only a prefix at the very start of the caption is removed; the same words
    appearing later in the text (e.g. "a skyline captured at night") are left
    untouched. Prefix matching is case-insensitive, and at most one prefix is
    removed.

    Args:
        caption (str): A string containing a caption.
    Returns:
        str: The caption with the prefix removed if it was present.
    """
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', ''),
    ]

    replacements = [replacer for _, replacer in prefix_substrings]
    # One capture group per prefix, so the matching alternative can be
    # identified by group index rather than by the (arbitrarily cased) text.
    # `^` without re.MULTILINE anchors the match to the start of the string.
    pattern = '^(?:' + '|'.join(
        f'({re.escape(opening)})' for opening, _ in prefix_substrings
    ) + ')'

    def replace_fn(match):
        # Exactly one group participates in a match; lastindex is its 1-based index.
        # (Looking the matched text up in a dict broke under IGNORECASE, e.g. for
        # "Captured from ", which is not a literal key.)
        return replacements[match.lastindex - 1]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)

def create_captions_rich(files):
    """
    Generates a caption for every uploaded image.

    Args:
        files: Value of the ``gr.Files`` input: a list of uploaded files, each
            either an object exposing its path as ``.name`` or a plain path
            string (depending on the Gradio version), or ``None`` when nothing
            was uploaded.
    Returns:
        str: One line per input file, in upload order. Each line holds either
        the cleaned caption or an "Error ..." message if that file could not
        be processed. Returns an empty string when no files were given.
    """
    if not files:
        return ""

    captions = []
    prompt = "caption en"

    for file_obj in files:
        # Accept both file wrappers (path in `.name`) and bare path strings.
        path = getattr(file_obj, "name", file_obj)

        try:
            # Decode into a new in-memory image inside `with` so the file
            # handle is closed immediately. Converting to RGB gives the
            # processor 3-channel input; presumably RGBA/palette uploads would
            # otherwise be rejected by the image processor (TODO confirm).
            with Image.open(path) as img:
                image = img.convert("RGB")
        except Exception as e:
            captions.append(f"Error opening image: {e}")
            continue

        try:
            # Preprocessing is inside the try so one bad file is reported on
            # its own line instead of aborting the whole batch.
            model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
            # generate() output begins with the input tokens; keep only the new ones.
            input_len = model_inputs["input_ids"].shape[-1]
            with torch.no_grad():
                generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
            decoded = processor.decode(generation[0][input_len:], skip_special_tokens=True)
            captions.append(modify_caption(decoded))
        except Exception as e:
            captions.append(f"Error generating caption: {e}")

    return "\n".join(captions)

# Custom CSS passed to gr.Blocks.
# NOTE(review): no component in this section sets elem_id="mkd", so this rule
# appears to have no target here — confirm before relying on it.
css = """
  #mkd {
    height: 500px; 
    overflow: auto; 
    border: 16px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    # Page title; the closing tags must be `</center></h1>` for valid markup.
    gr.HTML("<h1><center>Fine-tuned PaliGemma for SD3 Image Guided Prompt Generation.</center></h1>")
    
    with gr.Tab(label="Image to Prompt for SD3"):
        with gr.Row():
            # First column: uploads and the trigger button; next to it, the
            # read-only textbox that receives the generated prompts.
            with gr.Column():
                input_files = gr.Files(label="Input Images")
                submit_btn = gr.Button(value="Start")
            outputs = gr.Textbox(label="Prompts", lines=10, interactive=False)

        # Each click captions every uploaded file; the result is one line per image.
        submit_btn.click(create_captions_rich, inputs=[input_files], outputs=[outputs])

demo.launch(debug=True)