abidlabs HF staff commited on
Commit
c592b6f
1 Parent(s): d6891e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -3
app.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
4
 
5
  import requests
6
- from PIL import Image
7
  from torchvision import transforms
8
 
9
  # Download human-readable labels for ImageNet.
@@ -11,7 +10,6 @@ response = requests.get("https://git.io/JJkYN")
11
  labels = response.text.split("\n")
12
 
13
  def predict(inp):
14
- inp = Image.fromarray(inp.astype('uint8'), 'RGB')
15
  inp = transforms.ToTensor()(inp).unsqueeze(0)
16
  with torch.no_grad():
17
  prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
@@ -21,7 +19,7 @@ def predict(inp):
21
  import gradio as gr
22
 
23
  gr.Interface(fn=predict,
24
- inputs="image",
25
  outputs=gr.outputs.Label(num_top_classes=3),
26
  examples=["lion.jpg", "cheetah.jpg"],
27
  theme="default",
 
3
  model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
4
 
5
  import requests
 
6
  from torchvision import transforms
7
 
8
  # Download human-readable labels for ImageNet.
 
10
  labels = response.text.split("\n")
11
 
12
  def predict(inp):
 
13
  inp = transforms.ToTensor()(inp).unsqueeze(0)
14
  with torch.no_grad():
15
  prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
 
19
  import gradio as gr
20
 
21
  gr.Interface(fn=predict,
22
+ inputs=gr.inputs.Image(type="pil"),
23
  outputs=gr.outputs.Label(num_top_classes=3),
24
  examples=["lion.jpg", "cheetah.jpg"],
25
  theme="default",