juliensimon HF staff commited on
Commit
82abe24
1 Parent(s): f939b49

Move natten install to app

Browse files
Files changed (2) hide show
  1. app.py +18 -5
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,6 +1,16 @@
 
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
 
 
 
 
 
 
 
 
4
  model_names = [
5
  "facebook/deit-base-patch16-224",
6
  "facebook/convnext-base-224",
@@ -10,7 +20,7 @@ model_names = [
10
  "microsoft/beit-base-patch16-224",
11
  "nvidia/mit-b0",
12
  "shi-labs/nat-base-in1k-224",
13
- "shi-labs/dinat-base-in1k-224"
14
  ]
15
 
16
 
@@ -19,10 +29,13 @@ def process(image_file, top_k, model_name):
19
  pred = p(image_file)
20
  return {x["label"]: x["score"] for x in pred[:top_k]}
21
 
 
22
  # Inputs
23
  image = gr.Image(type="filepath", label="Upload an image")
24
  top_k = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Top k classes")
25
- model_selection = gr.Dropdown(model_names, label="Pick a model")
 
 
26
 
27
  # Output
28
  labels = gr.Label()
@@ -36,9 +49,9 @@ iface = gr.Interface(
36
  inputs=[image, top_k, model_selection],
37
  outputs=[labels],
38
  examples=[
39
- ["bike.jpg", 5, "google/vit-base-patch16-224"],
40
- ["car.jpg", 5, "microsoft/swin-base-patch4-window7-224"],
41
- ["food.jpg", 5, "facebook/convnext-base-224"]
42
  ],
43
  allow_flagging="never",
44
  )
 
1
+ import subprocess
2
+ import sys
3
+
4
  import gradio as gr
5
  from transformers import pipeline
6
 
7
+
8
+ def install(package, index):
9
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package, index])
10
+
11
+
12
+ install("natten", "-f https://shi-labs.com/natten/wheels/cpu/torch1.13/index.html")
13
+
14
  model_names = [
15
  "facebook/deit-base-patch16-224",
16
  "facebook/convnext-base-224",
 
20
  "microsoft/beit-base-patch16-224",
21
  "nvidia/mit-b0",
22
  "shi-labs/nat-base-in1k-224",
23
+ "shi-labs/dinat-base-in1k-224",
24
  ]
25
 
26
 
 
29
  pred = p(image_file)
30
  return {x["label"]: x["score"] for x in pred[:top_k]}
31
 
32
+
33
  # Inputs
34
  image = gr.Image(type="filepath", label="Upload an image")
35
  top_k = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Top k classes")
36
+ model_selection = gr.Dropdown(
37
+ model_names, value="google/vit-base-patch16-224", label="Pick a model"
38
+ )
39
 
40
  # Output
41
  labels = gr.Label()
 
49
  inputs=[image, top_k, model_selection],
50
  outputs=[labels],
51
  examples=[
52
+ ["bike.jpg", 5, "google/vit-base-patch16-224"],
53
+ ["car.jpg", 5, "microsoft/swin-base-patch4-window7-224"],
54
+ ["food.jpg", 5, "facebook/convnext-base-224"],
55
  ],
56
  allow_flagging="never",
57
  )
requirements.txt CHANGED
@@ -1,3 +1,2 @@
1
  torch==1.13.1
2
  transformers>=4.25.1
3
- natten -f https://shi-labs.com/natten/wheels/cpu/torch1.13/index.html
 
1
  torch==1.13.1
2
  transformers>=4.25.1