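"""Gradio demo for VQAScore.

Two tabs: score a single image against a text prompt, or rank a gallery of
images by how well they match a prompt.
"""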
import gradio as gr
import torch
torch.jit.script = lambda f: f  # Monkey-patch: skip TorchScript compilation, which errors on lambdas in some models
from t2v_metrics import VQAScore

def update_model(model_name):
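    # Load a fresh VQAScore pipeline for the requested checkpoint (a CUDA device is assumed)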
    return VQAScore(model=model_name, device="cuda")

# Module-level state: the currently loaded model pipeline and its name
cur_model_name = "clip-flant5-xl"
model_pipe = update_model(cur_model_name)

# On Hugging Face Spaces, `spaces.GPU` schedules GPU time for decorated functions;
# fall back to a no-op decorator when running elsewhere.
try:
    from spaces import GPU
except ImportError:
    def GPU(duration):
        # No-op stand-in when the `spaces` package is unavailable
        return lambda f: f

@GPU(duration=20)  # Request up to ~20 seconds of GPU time per call on ZeroGPU Spaces
def generate(model_name, image, text):
    global model_pipe, cur_model_name
    
    if model_name != cur_model_name:
        cur_model_name = model_name  # Update the current model name
        model_pipe = update_model(model_name)
    
    print("Image:", image)  # Debug: Print image path
    print("Text:", text)  # Debug: Print text input
    print("Using model:", model_name)
    
    try:
        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item()  # Score matrix is (images x texts); take the single entry
        print("Result:", result)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    
    return result

@GPU(duration=20)
def rank_images(model_name, images, text):
    global model_pipe, cur_model_name
    
    if model_name != cur_model_name:
        cur_model_name = model_name  # Update the current model name
        model_pipe = update_model(model_name)

    images = [image_tuple[0] for image_tuple in images]  # Gallery items are (image, caption) tuples; keep only the image
    print("Images:", images)  # Debug: Print image paths
    print("Text:", text)  # Debug: Print text input
    print("Using model:", model_name)
    
    try:
        results = model_pipe(images=images, texts=[text]).cpu()[:, 0].tolist()  # Perform the model inference on all images
        print("Score matrix (images x texts):", results)
        ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True)  # Rank results
        ranked_images = [(img, f"Rank: {rank + 1} - Score: {score:.2f}") for rank, (img, score) in enumerate(ranked_results)]  # Pair images with their scores and rank
        print("Ranked Results:", ranked_results)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    
    return ranked_images
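
# The scoring functions can also be called directly, e.g. (hypothetical paths):
#   generate("clip-flant5-xl", "dog.jpg", "a photo of a dog")                        -> float score
#   rank_images("clip-flant5-xl", [("a.jpg", None), ("b.jpg", None)], "a red car")   -> ranked (image, caption) pairs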

# Create the first demo
demo_vqascore = gr.Interface(
    fn=generate,  # function to call
    inputs=[
        gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"), 
        gr.Image(type="filepath"), 
        gr.Textbox(label="Prompt")
    ],  # define the types of inputs
    outputs="number",  # define the type of output
    title="VQAScore",  # title of the app
    description="Scores how well an image matches a text prompt using VQAScore."
)

# Create the second demo
demo_vqascore_ranking = gr.Interface(
    fn=rank_images,  # function to call
    inputs=[
        gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"), 
        gr.Gallery(label="Generated Images"), 
        gr.Textbox(label="Prompt")
    ],  # define the types of inputs
    outputs=gr.Gallery(label="Ranked Images"),  # define the type of output
    title="VQAScore Ranking",  # title of the app
    description="Ranks a gallery of images by how well each matches a text prompt, using VQAScore.",
    allow_flagging='never'
)

# Combine the demos into a tabbed interface
tabbed_interface = gr.TabbedInterface([demo_vqascore, demo_vqascore_ranking], ["VQAScore", "VQAScore Ranking"])

# Launch the tabbed interface; queue() enables request queuing so calls are processed in order
tabbed_interface.queue()
tabbed_interface.launch(share=False)