yuhe6 commited on
Commit
666b18f
1 Parent(s): 9c2ac28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -4,7 +4,7 @@ from PIL import Image
4
  from torchvision import transforms
5
  import gradio as gr
6
  #https://huggingface.co/spaces/yuhe6/final_project/blob/main/CIFAR10_cnn.pth
7
- os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")
8
 
9
  #model = torch.hub.load('huawei-noah/ghostnet', 'ghostnet_1x', pretrained=True)
10
  model = torch.hub.load('/', 'CIFAR10_cnn', pretrained=True)
@@ -36,8 +36,10 @@ def inference(input_image):
36
  # Read the categories
37
  with open("imagenet_classes.txt", "r") as f:
38
  categories = [s.strip() for s in f.readlines()]
 
 
39
  # Show top categories per image
40
- top5_prob, top5_catid = torch.topk(probabilities, 5)
41
  result = {}
42
  for i in range(top5_prob.size(0)):
43
  result[categories[top5_catid[i]]] = top5_prob[i].item()
 
4
  from torchvision import transforms
5
  import gradio as gr
6
  #https://huggingface.co/spaces/yuhe6/final_project/blob/main/CIFAR10_cnn.pth
7
+ #os.system("wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt")
8
 
9
  #model = torch.hub.load('huawei-noah/ghostnet', 'ghostnet_1x', pretrained=True)
10
  model = torch.hub.load('/', 'CIFAR10_cnn', pretrained=True)
 
36
  # Read the categories
37
  with open("imagenet_classes.txt", "r") as f:
38
  categories = [s.strip() for s in f.readlines()]
39
+ with open("dog_cat.txt", "r") as f:
40
+ categories = [s.strip() for s in f.readlines()]
41
  # Show top categories per image
42
+ top5_prob, top5_catid = torch.topk(probabilities, 1)
43
  result = {}
44
  for i in range(top5_prob.size(0)):
45
  result[categories[top5_catid[i]]] = top5_prob[i].item()