adityapathakk commited on
Commit
4a401bd
·
1 Parent(s): 8c5b605

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ import torch
4
+ import gradio
5
+ import gradio as gr
6
+ from PIL import Image
7
+ import torch.nn as nn
8
+ from torchvision import transforms, models
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+ os.system("wget https://www.dropbox.com/s/3us120bz5lhoh0t/model_best.pt?dl=0")
12
+
13
+ model = models.resnet50(pretrained=True)
14
+ num_ftrs = model.fc.in_features
15
+ # Here the size of each output sample is set to 2.
16
+ # Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
17
+ model.fc = nn.Linear(num_ftrs, 7)
18
+
19
+ model.load_state_dict(torch.load("./model_best.pt?dl=0", map_location=device))
20
+
21
+ # img = Image.open(path).convert('RGB')
22
+ # from torchvision import transforms
23
+
24
+ transforms2 = transforms.Compose([
25
+ transforms.Resize(256),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
28
+ ])
29
+
30
+ # img = transforms(img)
31
+ # img = img.unsqueeze(0)
32
+ # model.eval()
33
+
34
+ labels = ["Bacterialblight",
35
+ "Blast",
36
+ "Brownspot",
37
+ "Healthy",
38
+ "Hispa",
39
+ "LeafBlast",
40
+ "Tungro"]
41
+ # with torch.no_grad():
42
+ # # preds =
43
+ # preds = model(img)
44
+ # score, indices = torch.max(preds, 1)
45
+
46
+ def recognize_digit(image):
47
+ image = transforms2(image)
48
+ image = image.unsqueeze(0)
49
+ # image = image.unsqueeze(0)
50
+ # image = image.reshape(1, -1)
51
+ # with torch.no_grad():
52
+ # preds =
53
+ # img = image.reshape((-1, 3, 256, 256))
54
+ preds = model(image).flatten()
55
+ # prediction = model.predict(image).tolist()[0]
56
+ # score, indices = torch.max(preds, 1)
57
+ # return {str(indices.item())}
58
+ return {labels[i]: float(preds[i]) for i in range(7)}
59
+
60
+
61
+ im = gradio.inputs.Image(
62
+ shape=(256, 256), image_mode="RGB", type="pil")
63
+
64
+ iface = gr.Interface(
65
+ recognize_digit,
66
+ im,
67
+ gradio.outputs.Label(num_top_classes=3),
68
+ live=True,
69
+ interpretation="default",
70
+ # examples=[["images/cheetah1.jpg"], ["images/lion.jpg"]],
71
+ capture_session=True,
72
+ )
73
+
74
+ iface.test_launch()
75
+ iface.launch()