masapasa commited on
Commit
78e7c40
1 Parent(s): 837482c

Create new file

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations
2
+ import cv2
3
+ import torch
4
+ import timm
5
+ import gradio as gr
6
+ import numpy as np
7
+ import os
8
+ import random
9
+
10
+ device = torch.device('cpu')
11
+
12
+ labels = {
13
+ 0: 'bacterial_leaf_blight',
14
+ 1: 'bacterial_leaf_streak',
15
+ 2: 'bacterial_panicle_blight',
16
+ 3: 'blast',
17
+ 4: 'brown_spot',
18
+ 5: 'dead_heart',
19
+ 6: 'downy_mildew',
20
+ 7: 'hispa',
21
+ 8: 'normal',
22
+ 9: 'tungro'
23
+ }
24
+
25
+ def inference_fn(model, image=None):
26
+ model.eval()
27
+ image = image.to(device)
28
+ with torch.no_grad():
29
+ output = model(image.unsqueeze(0))
30
+ out = output.sigmoid().detach().cpu().numpy().flatten()
31
+ return out
32
+
33
+
34
+ def predict(image=None) -> dict:
35
+ mean = (0.485, 0.456, 0.406)
36
+ std = (0.229, 0.224, 0.225)
37
+
38
+ augmentations = albumentations.Compose(
39
+ [
40
+ albumentations.Resize(256, 256),
41
+ albumentations.HorizontalFlip(p=0.5),
42
+ albumentations.VerticalFlip(p=0.5),
43
+ albumentations.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
44
+ ]
45
+ )
46
+
47
+ augmented = augmentations(image=image)
48
+ image = augmented["image"]
49
+ image = np.transpose(image, (2, 0, 1))
50
+ image = torch.tensor(image, dtype=torch.float32)
51
+ model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=10)
52
+ model.load_state_dict(torch.load("/home/aswin/Downloads/paddy_model.pth", map_location=torch.device(device)))
53
+ model.to(device)
54
+
55
+ predicted = inference_fn(model, image)
56
+
57
+ return {labels[i]: float(predicted[i]) for i in range(10)}
58
+
59
+
60
+ interface = gr.Interface(fn=predict,
61
+ inputs=gr.inputs.Image(),
62
+ outputs=gr.outputs.Label(num_top_classes=10),
63
+ interpretation='default').launch()
64
+ interface.launch()