xi0v commited on
Commit
aeacaa2
1 Parent(s): 94d4af5

FIXED RuntimeError: CUDA has been initialized before importing the `spaces` package

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -1,8 +1,10 @@
 
1
  import gradio as gr
2
  import torch
3
  torch.jit.script = lambda f: f # Avoid script error in lambda
4
  from t2v_metrics import VQAScore, list_all_vqascore_models
5
 
 
6
  def update_model(model_name):
7
  return VQAScore(model=model_name, device="cuda")
8
 
@@ -11,13 +13,19 @@ global model_pipe, cur_model_name
11
  cur_model_name = "clip-flant5-xl"
12
  model_pipe = update_model(cur_model_name)
13
 
14
- # Ensure GPU context manager is imported correctly (assuming spaces is a module you have)
15
- try:
16
- from spaces import GPU
17
- except ImportError:
18
- GPU = lambda duration: (lambda f: f) # Dummy decorator if spaces.GPU is not available
19
 
20
- @GPU(duration=20)
 
 
 
 
 
 
 
 
 
 
 
21
  def generate(model_name, image, text):
22
  global model_pipe, cur_model_name
23
 
@@ -38,7 +46,7 @@ def generate(model_name, image, text):
38
 
39
  return result
40
 
41
- @GPU(duration=20)
42
  def rank_images(model_name, images, text):
43
  global model_pipe, cur_model_name
44
 
 
1
+ import spaces
2
  import gradio as gr
3
  import torch
4
  torch.jit.script = lambda f: f # Avoid script error in lambda
5
  from t2v_metrics import VQAScore, list_all_vqascore_models
6
 
7
+
8
  def update_model(model_name):
9
  return VQAScore(model=model_name, device="cuda")
10
 
 
13
  cur_model_name = "clip-flant5-xl"
14
  model_pipe = update_model(cur_model_name)
15
 
 
 
 
 
 
16
 
17
+ # Ensure GPU context manager is imported correctly (assuming spaces is a module you have)
18
+ #try:
19
+ #from spaces import GPU # i believe this is wrong, spaces package does not have "GPU"
20
+ #except ImportError:
21
+ # GPU = lambda duration: (lambda f: f) # Dummy decorator if spaces.GPU is not available
22
+
23
+ if torch.cuda.is_available():
24
+ model_pipe.device = "cuda"
25
+ else:
26
+ print("CUDA is not available")
27
+
28
+ @spaces.GPU # a duration lower than 60 does not work, leave as is.
29
  def generate(model_name, image, text):
30
  global model_pipe, cur_model_name
31
 
 
46
 
47
  return result
48
 
49
+
50
  def rank_images(model_name, images, text):
51
  global model_pipe, cur_model_name
52