import gradio as gr
import torch

torch.jit.script = lambda f: f  # Disable TorchScript compilation to avoid a scripting error on lambdas

from t2v_metrics import VQAScore, list_all_vqascore_models
def update_model(model_name):
    return VQAScore(model=model_name, device="cuda")

# Global state: the loaded model pipeline and the name of the currently loaded model
cur_model_name = "clip-flant5-xl"
model_pipe = update_model(cur_model_name)
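
# list_all_vqascore_models() (imported above but unused here) can be used to
# enumerate the supported model names; the dropdowns below expose only the two
# CLIP-FlanT5 variants.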
# Import the GPU decorator if available (the `spaces` package is provided on Hugging Face Spaces)
try:
    from spaces import GPU
except ImportError:
    GPU = lambda duration: (lambda f: f)  # No-op decorator when spaces.GPU is unavailable
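
# Note: on a ZeroGPU Space, the inference functions below would typically be
# wrapped with this decorator, e.g. @GPU(duration=60), so a GPU is attached for
# the duration of each call. The duration value here is illustrative, not taken
# from this app.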

def generate(model_name, image, text):
    global model_pipe, cur_model_name
    if model_name != cur_model_name:
        cur_model_name = model_name  # Remember which model is loaded
        model_pipe = update_model(model_name)
    print("Image:", image)  # Debug: image path
    print("Text:", text)  # Debug: text input
    print("Using model:", model_name)
    try:
        # Scores come back as an (images x texts) matrix; take the single entry
        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item()
        print("Result:", result)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    return result
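
# Example call (hypothetical file path), a minimal sketch of direct usage:
#   score = generate("clip-flant5-xl", "examples/dog.jpg", "a photo of a dog")
#   # `score` is a single float; higher means a better image-text match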

def rank_images(model_name, images, text):
    global model_pipe, cur_model_name
    if model_name != cur_model_name:
        cur_model_name = model_name  # Remember which model is loaded
        model_pipe = update_model(model_name)
    images = [image_tuple[0] for image_tuple in images]  # gr.Gallery yields (image, caption) tuples
    print("Images:", images)  # Debug: image paths
    print("Text:", text)  # Debug: text input
    print("Using model:", model_name)
    try:
        # Score every image against the single prompt; result shape is (num_images, 1)
        results = model_pipe(images=images, texts=[text]).cpu()[:, 0].tolist()
        print("Initial results (images x texts):", results)
        # Sort images by score, highest first
        ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True)
        # Pair each image with its rank and score for the output gallery
        ranked_images = [
            (img, f"Rank: {rank + 1} - Score: {score:.2f}")
            for rank, (img, score) in enumerate(ranked_results)
        ]
        print("Ranked results:", ranked_results)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    return ranked_images
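
# Example call (hypothetical paths), mirroring gr.Gallery's (image, caption) tuples:
#   ranked = rank_images(
#       "clip-flant5-xl",
#       [("imgs/a.png", None), ("imgs/b.png", None)],
#       "a red bicycle",
#   )
#   # `ranked` is a list of (image, "Rank: N - Score: S") pairs, best match first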

# First demo: score a single image against a text prompt
demo_vqascore = gr.Interface(
    fn=generate,
    inputs=[
        gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"),
        gr.Image(type="filepath"),
        gr.Textbox(label="Prompt"),
    ],
    outputs="number",
    title="VQAScore",
    description="This model evaluates the similarity between an image and a text prompt.",
)
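
# For reference: VQAScore is the model's probability of answering "Yes" to a
# question like "Does this figure show '<prompt>'?", so scores lie in [0, 1].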

# Second demo: rank a gallery of images by their score against a prompt
demo_vqascore_ranking = gr.Interface(
    fn=rank_images,
    inputs=[
        gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"),
        gr.Gallery(label="Generated Images"),
        gr.Textbox(label="Prompt"),
    ],
    outputs=gr.Gallery(label="Ranked Images"),
    title="VQAScore Ranking",
    description="This model ranks a gallery of images based on their similarity to a text prompt.",
    allow_flagging="never",
)

# Combine the demos into a tabbed interface
tabbed_interface = gr.TabbedInterface([demo_vqascore, demo_vqascore_ranking], ["VQAScore", "VQAScore Ranking"])

# Enable request queuing and launch the app
tabbed_interface.queue()
tabbed_interface.launch(share=False)
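
# Setting share=True instead would create a temporary public gradio.live link;
# on Hugging Face Spaces the app is already publicly reachable, so share=False suffices.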