import gradio as gr
import torch
torch.jit.script = lambda f: f  # Work around TorchScript errors on lambda-based modules by no-op'ing jit.script
from t2v_metrics import VQAScore

def update_model(model_name):
    return VQAScore(model=model_name, device="cuda")
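# Note (per the t2v_metrics README): the returned pipeline is called as
# pipe(images=[...], texts=[...]) and yields an (n_images x n_texts) score
# tensor, which is why the handlers below index [0][0] and [:, 0].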

# Module-level state: the current model name and its loaded pipeline.
# The handlers below declare these as global before reassigning them.
cur_model_name = "clip-flant5-xl"
model_pipe = update_model(cur_model_name)

# On HuggingFace ZeroGPU Spaces, spaces.GPU allocates a GPU for each decorated
# call; fall back to a no-op decorator when running outside Spaces.
try:
    from spaces import GPU
except ImportError:
    GPU = lambda duration: (lambda f: f)  # No-op decorator when spaces.GPU is unavailable
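# Assumption: duration=20 below hints the expected GPU seconds per call to the
# Spaces scheduler; the xxl model may need a larger value if loading times out.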

@GPU(duration=20)
def generate(model_name, image, text):
    """Score a single image against a single text prompt with VQAScore."""
    global model_pipe, cur_model_name

    # Reload the pipeline only when the requested model differs from the loaded one
    if model_name != cur_model_name:
        model_pipe = update_model(model_name)
        cur_model_name = model_name

    print("Image:", image)  # Debug: image path
    print("Text:", text)  # Debug: text prompt
    print("Using model:", model_name)

    try:
        # Scores come back as an (images x texts) tensor; take the single entry
        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item()
        print("Result:", result)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise

    return result

@GPU(duration=20)
def rank_images(model_name, images, text):
    """Score a gallery of images against one prompt and return them ranked by score."""
    global model_pipe, cur_model_name

    # Reload the pipeline only when the requested model differs from the loaded one
    if model_name != cur_model_name:
        model_pipe = update_model(model_name)
        cur_model_name = model_name

    # gr.Gallery passes (filepath, caption) tuples; keep only the filepaths
    images = [image_tuple[0] for image_tuple in images]
    print("Images:", images)  # Debug: image paths
    print("Text:", text)  # Debug: text prompt
    print("Using model:", model_name)

    try:
        # Scores come back as an (images x texts) tensor; take the single text column
        results = model_pipe(images=images, texts=[text]).cpu()[:, 0].tolist()
        print("Initial results (images x texts):", results)
        # Sort highest-scoring images first and caption each with its rank and score
        ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True)
        ranked_images = [(img, f"Rank: {rank + 1} - Score: {score:.2f}")
                         for rank, (img, score) in enumerate(ranked_results)]
        print("Ranked Results:", ranked_results)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise

    return ranked_images
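# Minimal local smoke test (illustrative only; assumes a CUDA GPU and that the
# model weights can be downloaded; the scores shown are made up):
#
#   ranked = rank_images("clip-flant5-xl",
#                        [("0_imgs/DALLE3.png", ""), ("0_imgs/SDXL.jpg", "")],
#                        "Two dogs of different breeds playfully chasing around a tree")
#   # e.g. [("0_imgs/SDXL.jpg", "Rank: 1 - Score: 0.72"), ...]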


### EXAMPLES ###
example_imgs = ["0_imgs/DALLE3.png", 
                "0_imgs/DeepFloyd.jpg",
                "0_imgs/Midjourney.jpg",
                "0_imgs/SDXL.jpg"]
example_prompt0 = "Two dogs of different breeds playfully chasing around a tree"
example_prompt1 = "Two dogs of the same breed playing on the grass"
###

# # Create the first demo
# demo_vqascore = gr.Interface(
#     fn=generate,  # function to call
#     inputs=[
#         gr.Dropdown(["clip-flant5-xxl", "clip-flant5-xl", ], label="Model Name"), 
#         gr.Image(type="filepath"), 
#         gr.Textbox(label="Prompt")
#     ],  # define the types of inputs
#     examples=[
#         ["clip-flant5-xl", example_imgs[0], example_prompt0],
#         ["clip-flant5-xl", example_imgs[0], example_prompt1],
#         ],
#     outputs="number",  # define the type of output
#     title="VQAScore",  # title of the app
#     description="This model evaluates the similarity between an image and a text prompt."
# )

# # Create the second demo
# demo_vqascore_ranking = gr.Interface(
#     fn=rank_images,  # function to call
#     inputs=[
#         gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"), 
#         gr.Gallery(label="Generated Images"), 
#         gr.Textbox(label="Prompt")
#     ],  # define the types of inputs
#     outputs=gr.Gallery(label="Ranked Images"),  # define the type of output
#     examples=[
#         ["clip-flant5-xl", [[img, ""] for img in example_imgs], example_prompt0],
#         ["clip-flant5-xl", [[img, ""] for img in example_imgs], example_prompt1]
#     ],
#     title="VQAScore Ranking",  # title of the app
#     description="This model ranks a gallery of images based on their similarity to a text prompt.",
#     allow_flagging='never'  
# )

# Helper for the example buttons below: returns values that Gradio routes into
# the model dropdown, gallery, and prompt textbox.
def load_example(model_name, images, prompt):
    return model_name, images, prompt


# Create the second demo: VQAScore Ranking
with gr.Blocks() as demo_vqascore_ranking:
    gr.Markdown("""
    # VQAScore Ranking
    This demo ranks a gallery of images by their VQAScore against an input text prompt. Try Examples 1 and 2, or use your own images and prompt.
    If you encounter an error, the model may not have loaded onto the GPU properly; retrying usually resolves it.
    """)
    
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(["clip-flant5-xxl", "clip-flant5-xl"], value="clip-flant5-xxl", label="Model Name")
            prompt = gr.Textbox(label="Prompt")
            gallery = gr.Gallery(label="Input Image(s)", elem_id="input-gallery", columns=4, allow_preview=True)
            rank_button = gr.Button("Submit")
            
        with gr.Column():
            ranked_gallery = gr.Gallery(label="Output: Ranked Images with Scores", elem_id="ranked-gallery", columns=4, allow_preview=True)
            rank_button.click(fn=rank_images, inputs=[model_dropdown, gallery, prompt], outputs=ranked_gallery)

            # Buttons that preload the example gallery and prompts defined above
            example1_button = gr.Button("Load Example 1")
            example2_button = gr.Button("Load Example 2")
            example1_button.click(fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt0), inputs=[], outputs=[model_dropdown, gallery, prompt])
            example2_button.click(fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt1), inputs=[], outputs=[model_dropdown, gallery, prompt])
    
    
# # Create the second demo
# with gr.Blocks() as demo_vqascore_ranking:
#     gr.Markdown("# VQAScore Ranking\nThis model ranks a gallery of images based on their similarity to a text prompt.")
#     model_dropdown = gr.Dropdown(["clip-flant5-xxl", "clip-flant5-xl"], value="clip-flant5-xxl",  label="Model Name")
#     gallery = gr.Gallery(label="Generated Images", elem_id="input-gallery", columns=4, allow_preview=True)
#     prompt = gr.Textbox(label="Prompt")
#     rank_button = gr.Button("Rank Images")
#     ranked_gallery = gr.Gallery(label="Ranked Images with Scores", elem_id="ranked-gallery", columns=4, allow_preview=True)
    
#     rank_button.click(fn=rank_images, inputs=[model_dropdown, gallery, prompt], outputs=ranked_gallery)
    
#     # Custom example buttons
#     example1_button = gr.Button("Load Example 1")
#     example2_button = gr.Button("Load Example 2")
    
#     example1_button.click(fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt0), inputs=[], outputs=[model_dropdown, gallery, prompt])
#     example2_button.click(fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt1), inputs=[], outputs=[model_dropdown, gallery, prompt])

#     # Layout to allow user to input their own data
#     with gr.Row():
#         gr.Column([model_dropdown, gallery, prompt, rank_button])
#         gr.Column([example1_button, example2_button])

# Launch the interface
demo_vqascore_ranking.queue()
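# queue() routes events through Gradio's request queue, which limits
# concurrency by default so simultaneous users don't contend for the GPU.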
demo_vqascore_ranking.launch(share=False)

# # Combine the demos into a tabbed interface
# tabbed_interface = gr.TabbedInterface([demo_vqascore, demo_vqascore_ranking], ["VQAScore", "VQAScore Ranking"])

# # Launch the tabbed interface
# tabbed_interface.queue()
# tabbed_interface.launch(share=False)