# VQAScore / app.py
import gradio as gr
import torch
torch.jit.script = lambda f: f  # Monkey-patch torch.jit.script to a no-op to avoid scripting errors on lambdas
from t2v_metrics import VQAScore, list_all_vqascore_models

def update_model(model_name):
    """Instantiate a VQAScore pipeline for the given model on the GPU."""
    return VQAScore(model=model_name, device="cuda")
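
# A minimal usage sketch (hypothetical image path; the indexing in generate()
# below assumes scores come back with shape [n_images, n_texts]):
#   pipe = update_model("clip-flant5-xl")
#   score = pipe(images=["0_imgs/DALLE3.png"], texts=["two dogs playing"])[0][0].item()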

# Module-level state: the active scoring pipeline and the name of the model it wraps
cur_model_name = "clip-flant5-xl"
model_pipe = update_model(cur_model_name)

# Import the Hugging Face Spaces GPU decorator; fall back to a no-op
# decorator so the app also runs where the spaces package is unavailable.
try:
    from spaces import GPU
except ImportError:
    GPU = lambda duration: (lambda f: f)  # Dummy decorator when spaces.GPU is not available
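
# Quick sanity sketch of the fallback: without spaces installed, GPU(duration=20)
# is the identity decorator, so decorated functions are returned unchanged:
#   assert GPU(duration=20)(update_model) is update_model  # holds only for the dummy fallback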

@GPU(duration=20)
def generate(model_name, image, text):
    """Score a single image against a text prompt with VQAScore."""
    global model_pipe, cur_model_name
    if model_name != cur_model_name:
        cur_model_name = model_name  # Update the current model name
        model_pipe = update_model(model_name)
    print("Image:", image)  # Debug: print image path
    print("Text:", text)  # Debug: print text input
    print("Using model:", model_name)
    try:
        # Scores have shape [n_images, n_texts]; extract the single entry.
        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item()
        print("Result:", result)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    return result
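
# Example call (hypothetical inputs, mirroring the Gradio wiring below):
#   generate("clip-flant5-xl", "0_imgs/DALLE3.png", "Two dogs playing on the grass")
#   -> a single float; higher means the image matches the prompt better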

@GPU(duration=20)
def rank_images(model_name, images, text):
    """Score each gallery image against the prompt and return the images ranked best-first."""
    global model_pipe, cur_model_name
    if model_name != cur_model_name:
        cur_model_name = model_name  # Update the current model name
        model_pipe = update_model(model_name)
    # Gallery entries arrive as (image, caption) pairs; keep the images only.
    images = [image_tuple[0] for image_tuple in images]
    print("Images:", images)  # Debug: print image paths
    print("Text:", text)  # Debug: print text input
    print("Using model:", model_name)
    try:
        # Scores have shape [n_images, n_texts]; take the single text column.
        results = model_pipe(images=images, texts=[text]).cpu()[:, 0].tolist()
        print("Initial results (n_images x n_texts):", results)
        # Sort images by score, highest first.
        ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True)
        # Pair each image with its rank and score for the output gallery captions.
        ranked_images = [
            (img, f"Rank: {rank + 1} - Score: {score:.2f}")
            for rank, (img, score) in enumerate(ranked_results)
        ]
        print("Ranked Results:", ranked_results)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    return ranked_images
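
# Example call (gr.Gallery passes (filepath, caption) tuples, so a manual call
# would look like this; the score shown is illustrative):
#   rank_images("clip-flant5-xl", [(img, "") for img in example_imgs], example_prompt0)
#   -> [("0_imgs/DALLE3.png", "Rank: 1 - Score: 0.92"), ...] sorted best-first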

### EXAMPLES ###
example_imgs = [
    "0_imgs/DALLE3.png",
    "0_imgs/DeepFloyd.jpg",
    "0_imgs/Midjourney.jpg",
    "0_imgs/SDXL.jpg",
]
example_prompt0 = "Two dogs of different breeds playfully chasing around a tree"
example_prompt1 = "Two dogs of the same breed playing on the grass"
###
# # Create the first demo
# demo_vqascore = gr.Interface(
#     fn=generate,  # function to call
#     inputs=[
#         gr.Dropdown(["clip-flant5-xxl", "clip-flant5-xl"], label="Model Name"),
#         gr.Image(type="filepath"),
#         gr.Textbox(label="Prompt"),
#     ],  # define the types of inputs
#     examples=[
#         ["clip-flant5-xl", example_imgs[0], example_prompt0],
#         ["clip-flant5-xl", example_imgs[0], example_prompt1],
#     ],
#     outputs="number",  # define the type of output
#     title="VQAScore",  # title of the app
#     description="This model evaluates the similarity between an image and a text prompt."
# )

# # Create the second demo
# demo_vqascore_ranking = gr.Interface(
#     fn=rank_images,  # function to call
#     inputs=[
#         gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"),
#         gr.Gallery(label="Generated Images"),
#         gr.Textbox(label="Prompt"),
#     ],  # define the types of inputs
#     outputs=gr.Gallery(label="Ranked Images"),  # define the type of output
#     examples=[
#         ["clip-flant5-xl", [[img, ""] for img in example_imgs], example_prompt0],
#         ["clip-flant5-xl", [[img, ""] for img in example_imgs], example_prompt1],
#     ],
#     title="VQAScore Ranking",  # title of the app
#     description="This model ranks a gallery of images based on their similarity to a text prompt.",
#     allow_flagging="never",
# )

# Helper for the example buttons: returns values in the order
# (model_name, images, prompt) so one click can populate all three inputs.
def load_example(model_name, images, prompt):
    return model_name, images, prompt

# Build the ranking demo with gr.Blocks
with gr.Blocks() as demo_vqascore_ranking:
    gr.Markdown("# VQAScore Ranking\nThis demo ranks a gallery of images by their similarity to a text prompt.")
    model_dropdown = gr.Dropdown(["clip-flant5-xxl", "clip-flant5-xl"], label="Model Name")
    gallery = gr.Gallery(label="Generated Images", elem_id="input-gallery", columns=4, allow_preview=True)
    prompt = gr.Textbox(label="Prompt")
    rank_button = gr.Button("Rank Images")
    ranked_gallery = gr.Gallery(label="Ranked Images with Scores", elem_id="ranked-gallery", columns=4, allow_preview=True)
    rank_button.click(fn=rank_images, inputs=[model_dropdown, gallery, prompt], outputs=ranked_gallery)
    # Example buttons that populate the model, gallery, and prompt inputs above.
    # (Components are laid out in creation order inside gr.Blocks, so no extra
    # Row/Column re-parenting is needed here.)
    with gr.Row():
        example1_button = gr.Button("Load Example 1")
        example2_button = gr.Button("Load Example 2")
    example1_button.click(fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt0), inputs=[], outputs=[model_dropdown, gallery, prompt])
    example2_button.click(fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt1), inputs=[], outputs=[model_dropdown, gallery, prompt])
# Launch the interface
demo_vqascore_ranking.queue()
demo_vqascore_ranking.launch(share=False)
# # Combine the demos into a tabbed interface
# tabbed_interface = gr.TabbedInterface([demo_vqascore, demo_vqascore_ranking], ["VQAScore", "VQAScore Ranking"])
# # Launch the tabbed interface
# tabbed_interface.queue()
# tabbed_interface.launch(share=False)