PhuongPhan commited on
Commit
c0c1175
·
verified ·
1 Parent(s): 0696e86

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pretrained Resnet-18 mode
2
+ import torch
3
+ model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
4
+
5
+ # define a function that takes in the user input, which in this case is an image, and returns the prediction.
6
+ '''The prediction should be returned as a dictionary whose keys are class name and values are confidence probabilities.
7
+ We will load the class names from this text file.
8
+ '''
9
+ import requests
10
+ from PIL import Image
11
+ from torchvision import transforms
12
+
13
+ # Download human-readable labels for ImageNet.
14
+ response = requests.get("https://git.io/JJkYN")
15
+ labels = response.text.split("\n")
16
+
17
+ def predict(inp):
18
+ inp = transforms.ToTensor()(inp).unsqueeze(0)
19
+ with torch.no_grad():
20
+ prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
21
+ confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
22
+ return confidences
23
+
24
+ '''The function converts the input image into a PIL Image and subsequently into a PyTorch tensor.
25
+ After processing the tensor through the model, it returns the predictions in the form of a dictionary named confidences.
26
+ The dictionary's keys are the class labels, and its values are the corresponding confidence probabilities.
27
+
28
+ In this section, we define a predict function that processes an input image to return prediction probabilities.
29
+ The function first converts the image into a PyTorch tensor and then forwards it through the pretrained model.
30
+
31
+ We use the softmax function in the final step to calculate the probabilities of each class.
32
+ The softmax function is crucial because it converts the raw output logits from the model, which can be any real number, into probabilities that sum up to 1.
33
+ This makes it easier to interpret the model’s outputs as confidence levels for each class.'''
34
+
35
+ # Creating a Gradio interface
36
+ import gradio as gr
37
+
38
+ gr.Interface(fn=predict,
39
+ inputs=gr.Image(type="pil"), # creates the component and handles the preprocessing to convert that to a PIL image
40
+ outputs=gr.Label(num_top_classes=3), # a Label, which displays the top labels in a nice form. Since we don't want to show all 1,000 class labels, we will customize it to show only the top 3 images by constructing it as
41
+ examples=["/lion.jpg", "/cheetah.jpg"]).launch()