ShAnSantosh commited on
Commit
af82581
1 Parent(s): 1fa98fd

created new file

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations
2
+ import cv2
3
+ import torch
4
+ import timm
5
+ import gradio as gr
6
+
7
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
+
9
+ labels = {0: 'bacterial_leaf_blight',
10
+ 1: 'bacterial_leaf_streak',
11
+ 2: 'bacterial_panicle_blight',
12
+ 3: 'blast',
13
+ 4: 'brown_spot',
14
+ 5: 'dead_heart',
15
+ 6: 'downy_mildew',
16
+ 7: 'hispa',
17
+ 8: 'normal',
18
+ 9: 'tungro'}
19
+
20
+ def inference_fn(model, image=None):
21
+ model.eval()
22
+ image = image.to(device)
23
+ print(image.shape)
24
+ with torch.no_grad():
25
+ output = model(image.unsqueeze(0))
26
+ out = output.sigmoid().detach().cpu().numpy().flatten()
27
+ return out
28
+
29
+
30
+ def predict(image = None) :
31
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
32
+ mean = (0.485, 0.456, 0.406)
33
+ std = (0.229, 0.224, 0.225)
34
+
35
+ augmentations = albumentations.Compose(
36
+ [
37
+ albumentations.Resize(256, 256),
38
+ albumentations.HorizontalFlip(p=0.5),
39
+ albumentations.VerticalFlip(p=0.5),
40
+ albumentations.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
41
+ ]
42
+ )
43
+
44
+ augmented = augmentations(image=image)
45
+ image = augmented["image"]
46
+ image = np.transpose(image, (2, 0, 1))
47
+ image = torch.tensor(image, dtype=torch.float32)
48
+ model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=10)
49
+ model.load_state_dict(torch.load("paddy_model.pth"))
50
+ model.to(device)
51
+
52
+ predicted = inference_fn(model, image)
53
+
54
+ del model
55
+ gc.collect()
56
+ torch.cuda.empty_cache()
57
+
58
+ return {labels[i]: float(predicted[i]) for i in range(10)}
59
+
60
+
61
+ gr.Interface(fn=predict,
62
+ inputs=gr.inputs.Image(shape=(256, 256)),
63
+ outputs=gr.outputs.Label(num_top_classes=10),
64
+ examples=["200001.jpg", "100028.jpg"]).launch()