whyu commited on
Commit
f9a49d7
1 Parent(s): 3641e30

To support zero zero-gpu

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -5,9 +5,11 @@ from PIL import Image
5
  from timm.data import create_transform
6
 
7
 
 
 
8
  # Prepare the model.
9
  import models
10
- model = models.mambaout_femto(pretrained=True) # can change different model name
11
  model.eval()
12
 
13
  # Prepare the transform.
@@ -17,9 +19,9 @@ transform = create_transform(input_size=224, crop_pct=model.default_cfg['crop_pc
17
  response = requests.get("https://git.io/JJkYN")
18
  labels = response.text.split("\n")
19
 
 
20
  def predict(inp):
21
- inp = transform(inp).unsqueeze(0)
22
-
23
  with torch.no_grad():
24
  prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
25
  confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
 
5
  from timm.data import create_transform
6
 
7
 
8
+ device = "cuda"
9
+
10
  # Prepare the model.
11
  import models
12
+ model = models.mambaout_femto(pretrained=True).to(device=device) # can change different model name
13
  model.eval()
14
 
15
  # Prepare the transform.
 
19
  response = requests.get("https://git.io/JJkYN")
20
  labels = response.text.split("\n")
21
 
22
+ +@spaces.GPU
23
  def predict(inp):
24
+ inp = transform(inp).unsqueeze(0).to(device=device)
 
25
  with torch.no_grad():
26
  prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
27
  confidences = {labels[i]: float(prediction[i]) for i in range(1000)}