File size: 3,502 Bytes
577d9ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import torch
from PIL import Image
from processor import MultiModalProcessor
from inference import test_inference
from load_model import load_hf_model

# --- One-time model/processor setup (module import side effect) ---
# Load model and processor
MODEL_PATH = "merve/paligemma_vqav2"  # or your local model path
TOKENIZER_PATH = "./tokenizer"  # path to your local tokenizer
# Prefer GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model, tokenizer = load_hf_model(MODEL_PATH, TOKENIZER_PATH, device)
model = model.eval()  # inference mode (disables dropout etc.)

# Vision-config values presumably control how many image placeholder tokens
# the processor inserts and the resolution images are resized to — confirm
# against MultiModalProcessor's implementation.
num_image_tokens = model.config.vision_config.num_image_tokens
image_size = model.config.vision_config.image_size
max_length = 512  # NOTE(review): cap on processed sequence length, in tokens — verify
processor = MultiModalProcessor(tokenizer, num_image_tokens, image_size, max_length)

def generate_caption(image, prompt, max_tokens=300, temperature=0.8, top_p=0.9, do_sample=False):
    """Generate a caption for *image* guided by *prompt*.

    Saves the uploaded numpy image to a temporary file (``test_inference``
    expects a file path), runs inference, and returns whatever the
    inference routine printed to stdout.

    Args:
        image: numpy array from the Gradio image widget.
        prompt: text prompt guiding the caption.
        max_tokens: maximum number of tokens to generate.
        temperature: sampling temperature.
        top_p: nucleus-sampling threshold.
        do_sample: whether to sample instead of greedy decoding.

    Returns:
        The captured stdout of ``test_inference`` as a single string.
    """
    import io
    import os
    import tempfile
    from contextlib import redirect_stdout

    # Unique temp file instead of a fixed "temp_image.jpg": avoids clobbering
    # between concurrent requests and guarantees cleanup below.
    fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        Image.fromarray(image).save(temp_image_path)

        # Capture stdout with redirect_stdout rather than monkey-patching
        # builtins.print. The old approach broke multi-argument print calls
        # (the one-arg capture function raised TypeError), dropped newlines,
        # and left print permanently patched if test_inference raised.
        buffer = io.StringIO()
        with redirect_stdout(buffer):
            test_inference(
                model,
                processor,
                device,
                prompt,
                temp_image_path,
                max_tokens,
                temperature,
                top_p,
                do_sample,
            )
        return buffer.getvalue()
    finally:
        # Always remove the temp image, even if inference fails.
        os.remove(temp_image_path)

# Define Gradio demo
# Declarative UI: a two-tab layout (captioning + about) wired to
# generate_caption via the button's click event.
with gr.Blocks(title="Image Captioning with PaliGemma", theme=gr.themes.Monochrome()) as demo:
    gr.Markdown(
        """
        # Image Captioning with PaliGemma
        This demo uses the PaliGemma model to generate captions for images.
        """
    )
    
    with gr.Tabs():
        with gr.TabItem("Generate Caption"):
            with gr.Row():
                # Left column: inputs (image + prompt).
                with gr.Column(scale=1):
                    image_input = gr.Image(type="numpy", label="Upload Image")
                    prompt_input = gr.Textbox(label="Prompt", placeholder="What is happening in the photo?")
                
                # Right column: generation hyperparameters. Slider defaults
                # mirror generate_caption's keyword defaults.
                with gr.Column(scale=1):
                    with gr.Group():
                        max_tokens_input = gr.Slider(1, 500, value=300, step=1, label="Max Tokens")
                        temperature_input = gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature")
                        top_p_input = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top P")
                        do_sample_input = gr.Checkbox(label="Do Sample")
                    
                    generate_button = gr.Button("Generate Caption")
            
            output = gr.Textbox(label="Generated Caption", lines=5)
        
        with gr.TabItem("About"):
            gr.Markdown(
                """
                ## How to use:
                1. Upload an image in the 'Generate Caption' tab.
                2. Enter a prompt to guide the caption generation.
                3. Adjust the generation parameters if desired.
                4. Click 'Generate Caption' to see the results.

                ## Model Details:
                - Model: PaliGemma
                - Type: Multimodal (Text + Image)
                - Task: Image Captioning
                """
            )

    # Wire the button: input widget order must match generate_caption's
    # positional parameter order.
    generate_button.click(
        generate_caption,
        inputs=[image_input, prompt_input, max_tokens_input, temperature_input, top_p_input, do_sample_input],
        outputs=output
    )

# Launch the demo
# Only start the server when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()