import gradio as gr
import torch

# Monkey-patch torch.jit.script to a no-op so that scripting lambdas does not fail
torch.jit.script = lambda f: f

from t2v_metrics import VQAScore, list_all_vqascore_models


def update_model(model_name):
    """Load a fresh VQAScore pipeline for the given model name."""
    return VQAScore(model=model_name, device="cuda")


# Global variables for the model pipeline and the currently loaded model name
cur_model_name = "clip-flant5-xl"
model_pipe = update_model(cur_model_name)

# Import the GPU decorator from the `spaces` package if available;
# otherwise fall back to a no-op decorator so the app still runs locally
try:
    from spaces import GPU
except ImportError:
    GPU = lambda duration: (lambda f: f)  # Dummy decorator if spaces.GPU is unavailable


@GPU(duration=20)
def generate(model_name, image, text):
    """Score a single image against a text prompt."""
    global model_pipe, cur_model_name
    if model_name != cur_model_name:
        cur_model_name = model_name  # Update the current model name
        model_pipe = update_model(model_name)
    print("Image:", image)  # Debug: print image path
    print("Text:", text)  # Debug: print text input
    print("Using model:", model_name)
    try:
        # The pipeline returns an (images x texts) score matrix
        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item()
        print("Result:", result)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    return result


@GPU(duration=20)
def rank_images(model_name, images, text):
    """Score a gallery of images against a prompt and return them ranked by score."""
    global model_pipe, cur_model_name
    if model_name != cur_model_name:
        cur_model_name = model_name  # Update the current model name
        model_pipe = update_model(model_name)
    # gr.Gallery yields (filepath, caption) tuples; keep only the file paths
    images = [image_tuple[0] for image_tuple in images]
    print("Images:", images)  # Debug: print image paths
    print("Text:", text)  # Debug: print text input
    print("Using model:", model_name)
    try:
        # Run inference on all images at once; take the column for the single prompt
        results = model_pipe(images=images, texts=[text]).cpu()[:, 0].tolist()
        print("Initial results (images x texts):", results)
        # Rank images by score, highest first
        ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True)
        # Pair each image with its rank and score for display in the gallery
        ranked_images = [
            (img, f"Rank: {rank + 1} - Score: {score:.2f}")
            for rank, (img, score) in enumerate(ranked_results)
        ]
        print("Ranked Results:", ranked_results)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise
    return ranked_images


### EXAMPLES ###
example_imgs = [
    "0_imgs/DALLE3.png",
    "0_imgs/DeepFloyd.jpg",
    "0_imgs/Midjourney.jpg",
    "0_imgs/SDXL.jpg",
]
example_prompt0 = "Two dogs of different breeds playfully chasing around a tree"
example_prompt1 = "Two dogs of the same breed playing on the grass"
###

# # Create the first demo
# demo_vqascore = gr.Interface(
#     fn=generate,  # function to call
#     inputs=[
#         gr.Dropdown(["clip-flant5-xxl", "clip-flant5-xl"], label="Model Name"),
#         gr.Image(type="filepath"),
#         gr.Textbox(label="Prompt"),
#     ],  # define the types of inputs
#     examples=[
#         ["clip-flant5-xl", example_imgs[0], example_prompt0],
#         ["clip-flant5-xl", example_imgs[0], example_prompt1],
#     ],
#     outputs="number",  # define the type of output
#     title="VQAScore",  # title of the app
#     description="This model evaluates the similarity between an image and a text prompt.",
# )

# # Create the second demo
# demo_vqascore_ranking = gr.Interface(
#     fn=rank_images,  # function to call
#     inputs=[
#         gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"),
#         gr.Gallery(label="Generated Images"),
#         gr.Textbox(label="Prompt"),
#     ],  # define the types of inputs
#     outputs=gr.Gallery(label="Ranked Images"),  # define the type of output
#     examples=[
#         ["clip-flant5-xl", [[img, ""] for img in example_imgs], example_prompt0],
#         ["clip-flant5-xl", [[img, ""] for img in example_imgs], example_prompt1],
#     ],
#     title="VQAScore Ranking",  # title of the app
#     description="This model ranks a gallery of images based on their similarity to a text prompt.",
#     allow_flagging="never",
# )


# Custom component for loading examples
def load_example(model_name, images, prompt):
    return model_name, images, prompt


# Create the second demo
with gr.Blocks() as demo_vqascore_ranking:
    gr.Markdown(
        "# VQAScore Ranking\n"
        "This model ranks a gallery of images based on their similarity to a text prompt."
    )
    # Layout: user inputs on the left, example-loading buttons on the right.
    # Components must be created inside the Row/Column contexts to be laid out
    # there (gr.Column does not accept a list of already-rendered components).
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(["clip-flant5-xxl", "clip-flant5-xl"], label="Model Name")
            gallery = gr.Gallery(label="Generated Images", elem_id="input-gallery", columns=4, allow_preview=True)
            prompt = gr.Textbox(label="Prompt")
            rank_button = gr.Button("Rank Images")
        with gr.Column():
            # Custom example buttons
            example1_button = gr.Button("Load Example 1")
            example2_button = gr.Button("Load Example 2")
    ranked_gallery = gr.Gallery(label="Ranked Images with Scores", elem_id="ranked-gallery", columns=4, allow_preview=True)

    rank_button.click(fn=rank_images, inputs=[model_dropdown, gallery, prompt], outputs=ranked_gallery)
    example1_button.click(
        fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt0),
        inputs=[],
        outputs=[model_dropdown, gallery, prompt],
    )
    example2_button.click(
        fn=lambda: load_example("clip-flant5-xxl", example_imgs, example_prompt1),
        inputs=[],
        outputs=[model_dropdown, gallery, prompt],
    )

# Launch the interface
demo_vqascore_ranking.queue()
demo_vqascore_ranking.launch(share=False)

# # Combine the demos into a tabbed interface
# tabbed_interface = gr.TabbedInterface([demo_vqascore, demo_vqascore_ranking], ["VQAScore", "VQAScore Ranking"])

# # Launch the tabbed interface
# tabbed_interface.queue()
# tabbed_interface.launch(share=False)
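
# A minimal sketch of calling VQAScore directly, without the Gradio UI. It
# reuses `model_pipe` and the example images/prompt defined above, and mirrors
# how `rank_images` indexes the (images x texts) score tensor. Uncomment to
# try it; it assumes the example image files exist and a CUDA device is
# available, as `update_model` requires.
#
# scores = model_pipe(images=example_imgs, texts=[example_prompt0]).cpu()  # shape: (4, 1)
# for img, score in zip(example_imgs, scores[:, 0].tolist()):
#     print(f"{img}: {score:.4f}")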