Tech-Meld's picture
Update app.py
24d193b verified
raw
history blame contribute delete
No virus
2.62 kB
import gradio as gr
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import spaces
import torch
import re
from PIL import Image
# Hugging Face Hub checkpoint used for both the model and its processor; keeping
# a single id guarantees the two are always loaded from the same repository.
MODEL_ID = "gokaygokay/sd3-long-captioner"

# Loaded once at import time and shared by every request. The model is placed on
# CPU and switched to evaluation mode for inference.
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID).to("cpu").eval()
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)
def modify_caption(caption: str) -> str:
    """
    Removes specific boilerplate phrases from a caption.

    The first case-insensitive occurrence of any phrase listed in
    ``prefix_substrings`` is replaced by its paired replacement (currently
    always the empty string). The search is not anchored: the phrase is
    removed wherever it first appears, not only at the start of the caption.

    Args:
        caption (str): A string containing a caption.
    Returns:
        str: The caption with the first matching phrase replaced, or the
        caption unchanged if no phrase is present.
    """
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', ''),
    ]
    # Wrap each phrase in its own capturing group so match.lastindex tells us
    # which phrase matched. Looking up match.group(0) in a dict keyed by the
    # lowercase phrases would fail for text like "Captured from ", which
    # IGNORECASE allows to match.
    pattern = '|'.join(f'({re.escape(opening)})' for opening, _ in prefix_substrings)

    def replace_fn(match):
        # Groups are 1-based and in the same order as prefix_substrings.
        return prefix_substrings[match.lastindex - 1][1]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
def _open_rgb_image(file):
    """Load one uploaded file as a fully decoded RGB PIL image.

    Accepts either a plain path string or an object exposing the path as
    ``.name`` (e.g. a tempfile wrapper). The underlying file handle is closed
    before returning.
    """
    path = getattr(file, "name", file)
    with Image.open(path) as img:
        # convert() decodes the pixel data (returning a copy even when the
        # mode is already RGB), so the result stays usable after the file is
        # closed. It also normalizes RGBA/palette/grayscale inputs to
        # 3-channel RGB before they reach the processor.
        return img.convert("RGB")


def _caption_image(image, prompt):
    """Run greedy PaliGemma generation on one image and return the cleaned caption."""
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
    # generate() returns the prompt tokens followed by the continuation;
    # remember the prompt length so only the new tokens are decoded.
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.no_grad():
        generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
    decoded = processor.decode(generation[0][input_len:], skip_special_tokens=True)
    return modify_caption(decoded)


def create_captions_rich(files):
    """Generate one SD3-style prompt per uploaded image.

    Args:
        files: Value of the ``gr.Files`` input. This is an iterable of uploaded
            files (path strings or objects with a ``.name`` path), or ``None``
            / empty when nothing was uploaded.

    Returns:
        str: One line per input file, in input order, joined with newlines.
        A file that cannot be opened yields an ``"Error opening image: ..."``
        line, and a file that fails during processing or generation yields an
        ``"Error generating caption: ..."`` line, so one bad upload does not
        abort the batch. Returns ``""`` when no files were given.
    """
    if not files:
        return ""
    prompt = "caption en"
    captions = []
    for file in files:
        try:
            image = _open_rgb_image(file)
        except Exception as e:  # UI boundary: report per-file, keep going
            captions.append(f"Error opening image: {e}")
            continue
        try:
            captions.append(_caption_image(image, prompt))
        except Exception as e:  # UI boundary: report per-file, keep going
            captions.append(f"Error generating caption: {e}")
    return "\n".join(captions)
# Styles for an element with elem_id="mkd".
# NOTE(review): no component below sets elem_id="mkd", so this rule currently
# has no visible effect.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 16px solid #ccc;
}
"""

# UI: upload images on the left, get newline-separated prompts on the right.
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>Fine-tuned PaliGemma for SD3 Image Guided Prompt Generation.</center></h1>")
    with gr.Tab(label="Image to Prompt for SD3"):
        with gr.Row():
            with gr.Column():
                input_files = gr.Files(label="Input Images")
                submit_btn = gr.Button(value="Start")
            outputs = gr.Textbox(label="Prompts", lines=10, interactive=False)
        submit_btn.click(create_captions_rich, inputs=[input_files], outputs=[outputs])

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch(debug=True)