zhiqiulin committed
Commit 656934c
1 Parent(s): 5741e23

Update app.py

Files changed (1)
  1. app.py +67 -16
app.py CHANGED
@@ -1,45 +1,96 @@
  import gradio as gr
- import spaces
  import torch
- torch.jit.script = lambda f: f # Avoid script error in lambda
-
  from t2v_metrics import VQAScore, list_all_vqascore_models

- def update_model(model_name):
-     return VQAScore(model=model_name, device="cuda")

- global model_pipe

- # Global model variable, but do not initialize or move to CUDA here
  cur_model_name = "clip-flant5-xl"
  model_pipe = update_model(cur_model_name)

- @spaces.GPU(duration = 20)
  def generate(model_name, image, text):
      if model_name != cur_model_name:
          model_pipe = update_model(model_name)

      print("Image:", image) # Debug: Print image path
      print("Text:", text) # Debug: Print text input
      print("Using model:", model_name)
-     # Wrap the model call in a try-except block to capture and debug CUDA errors
      try:
          result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item() # Perform the model inference
-         print("Result", result)
      except RuntimeError as e:
          print(f"RuntimeError during model inference: {e}")
          raise e

-     return result # Return the result

- demo = gr.Interface(
      fn=generate, # function to call
-     # ['clip-flant5-xxl', 'clip-flant5-xl', 'clip-flant5-xxl-no-system', 'clip-flant5-xxl-no-system-no-user', 'llava-v1.5-13b', 'llava-v1.5-7b', 'sharegpt4v-7b', 'sharegpt4v-13b', 'llava-v1.6-13b', 'instructblip-flant5-xxl', 'instructblip-flant5-xl']
-     inputs=[gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"), gr.Image(type="filepath"), gr.Textbox(label="Prompt")], # define the types of inputs
      outputs="number", # define the type of output
      title="VQAScore", # title of the app
      description="This model evaluates the similarity between an image and a text prompt."
  )

- demo.queue()
- demo.launch()
 
  import gradio as gr
  import torch
  from t2v_metrics import VQAScore, list_all_vqascore_models
+
+ torch.jit.script = lambda f: f # Avoid script error in lambda

+ def update_model(model_name):
+     return VQAScore(model=model_name, device="cuda")

+ # Use global variables for model pipe and current model name
+ global model_pipe, cur_model_name
  cur_model_name = "clip-flant5-xl"
  model_pipe = update_model(cur_model_name)

+ # Ensure GPU context manager is imported correctly (assuming spaces is a module you have)
+ try:
+     from spaces import GPU
+ except ImportError:
+     GPU = lambda duration: (lambda f: f) # Dummy decorator if spaces.GPU is not available
+
+ @GPU(duration=20)
  def generate(model_name, image, text):
+     global model_pipe, cur_model_name
+
      if model_name != cur_model_name:
+         cur_model_name = model_name # Update the current model name
          model_pipe = update_model(model_name)

      print("Image:", image) # Debug: Print image path
      print("Text:", text) # Debug: Print text input
      print("Using model:", model_name)
+
      try:
          result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item() # Perform the model inference
+         print("Result:", result)
+     except RuntimeError as e:
+         print(f"RuntimeError during model inference: {e}")
+         raise e
+
+     return result
+
+ @GPU(duration=20)
+ def rank_images(model_name, images, text):
+     global model_pipe, cur_model_name
+
+     if model_name != cur_model_name:
+         cur_model_name = model_name # Update the current model name
+         model_pipe = update_model(model_name)
+
+     print("Images:", images) # Debug: Print image paths
+     print("Text:", text) # Debug: Print text input
+     print("Using model:", model_name)
+
+     try:
+         results = model_pipe(images=images, texts=[text] * len(images)).cpu()[:, 0].tolist() # Perform the model inference on all images
+         ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True) # Rank results
+         ranked_images = [img for img, score in ranked_results]
+         print("Ranked Results:", ranked_results)
      except RuntimeError as e:
          print(f"RuntimeError during model inference: {e}")
          raise e

+     return ranked_images

+ # Create the first demo
+ demo_vqascore = gr.Interface(
      fn=generate, # function to call
+     inputs=[
+         gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"),
+         gr.Image(type="filepath"),
+         gr.Textbox(label="Prompt")
+     ], # define the types of inputs
      outputs="number", # define the type of output
      title="VQAScore", # title of the app
      description="This model evaluates the similarity between an image and a text prompt."
  )

+ # Create the second demo
+ demo_vqascore_ranking = gr.Interface(
+     fn=rank_images, # function to call
+     inputs=[
+         gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"),
+         gr.Gallery(label="Generated Images"),
+         gr.Textbox(label="Prompt")
+     ], # define the types of inputs
+     outputs=gr.Gallery(label="Ranked Images"), # define the type of output
+     title="VQAScore Ranking", # title of the app
+     description="This model ranks a gallery of images based on their similarity to a text prompt."
+ )
+
+ # Combine the demos into a tabbed interface
+ tabbed_interface = gr.TabbedInterface([demo_vqascore, demo_vqascore_ranking], ["VQAScore", "VQAScore Ranking"])
+
+ # Launch the tabbed interface
+ tabbed_interface.queue()
+ tabbed_interface.launch()
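
For reference, the scorer call that both generate() and rank_images() wrap is ordinary t2v_metrics usage. A minimal sketch (the image path and prompt below are placeholders; assumes t2v_metrics is installed and a CUDA device is available, as on the Space): the pipe returns one score per (image, text) pair, indexed as [image][text], which is why the single-pair path reads [0][0] and the ranking path reads [:, 0].

    from t2v_metrics import VQAScore

    # Same constructor as update_model() in app.py.
    pipe = VQAScore(model="clip-flant5-xl", device="cuda")

    # One image scored against one prompt -> tensor of shape (1, 1).
    scores = pipe(images=["image.png"], texts=["a photo of a dog"])
    print(scores.cpu()[0][0].item())  # the single alignment score, as in generate()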
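The ranking step added in rank_images() is plain Python and can be sanity-checked without loading the model. A hypothetical run, with made-up scores standing in for model_pipe(...).cpu()[:, 0].tolist():

    images = ["a.png", "b.png", "c.png"]   # placeholder gallery paths
    results = [0.42, 0.91, 0.17]           # hypothetical VQAScores, one per image

    # Mirrors the sorted(zip(...)) line: highest-scoring image first.
    ranked = sorted(zip(images, results), key=lambda x: x[1], reverse=True)
    print([img for img, score in ranked])  # -> ['b.png', 'a.png', 'c.png']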