WalidBouss commited on
Commit
601fb13
1 Parent(s): c249249

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -8,6 +8,7 @@ import torch.nn.functional as F
8
  import open_clip
9
 
10
  import gradio as gr
 
11
 
12
  from legrad import LeWrapper, LePreprocess
13
 
@@ -54,6 +55,7 @@ def logits_to_heatmaps(logits, image_cv):
54
 
55
 
56
  # ---------- Main visualization function ----------
 
57
  def viz_func(url, image, text_query):
58
  image_torch = preprocess(image).unsqueeze(0).to(device)
59
  text_emb = _get_text_embedding(model, tokenizer, classes=[text_query], device=device)
 
8
  import open_clip
9
 
10
  import gradio as gr
11
+ import spaces
12
 
13
  from legrad import LeWrapper, LePreprocess
14
 
 
55
 
56
 
57
  # ---------- Main visualization function ----------
58
+ @spaces.GPU
59
  def viz_func(url, image, text_query):
60
  image_torch = preprocess(image).unsqueeze(0).to(device)
61
  text_emb = _get_text_embedding(model, tokenizer, classes=[text_query], device=device)