Maria-Dolgaya commited on
Commit
b3f257f
1 Parent(s): fe51708

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from torchvision.models import resnet18, ResNet18_Weights
5
+ from torch import nn
6
+ from PIL import Image # pip install pillow
7
+
8
+ labels = ['Fractured','Non-fractured']
9
+
10
+ # Same data transformation that was used for inputs (except data augmentation)
11
+ data_transform = transforms.Compose([
12
+ transforms.Resize(size=(256, 256)),
13
+ transforms.ToTensor(),
14
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
15
+ std=[0.229, 0.224, 0.225])
16
+ ])
17
+
18
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html
19
+ # Loading Model for Inference with state_dict (recommended)
20
+ model = resnet18(weights=ResNet18_Weights.DEFAULT)
21
+ model.fc = nn.Linear(in_features=512, out_features=len(labels))
22
+ model.load_state_dict(torch.load("model.pth",map_location=torch.device('cpu')))
23
+ model.eval()
24
+
25
+ def predict(img):
26
+ X = data_transform(img).unsqueeze(0) # returns tensor
27
+ with torch.no_grad():
28
+ predictions = model(X).flatten()
29
+ predictions = torch.nn.functional.softmax(predictions)
30
+ confidences = {labels[i]: float(predictions[i]) for i in range(len(labels))}
31
+ return confidences
32
+
33
+ title = "Corn Leaf Diseases"
34
+ description = "A corn leaf disease classifier trained on the Kaggle dataset using Resnet18"
35
+
36
+ demo=gr.Interface(fn=predict,
37
+ inputs=gr.Image(type="pil"),
38
+ outputs=gr.Label(num_top_classes=len(labels)),
39
+ title=title,
40
+ description=description,
41
+ examples=["2.jpg", "Corn_Common_Rust.jpg"])
42
+
43
+ demo.launch('share=True')