File size: 1,743 Bytes
af82581
 
 
 
 
30e0f6c
7cca923
 
af82581
24bad90
af82581
160cb15
 
af82581
 
 
 
 
 
 
 
160cb15
 
af82581
b1cd55c
af82581
 
 
 
 
 
 
 
160cb15
af82581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24bad90
af82581
 
 
9017094
af82581
 
 
 
21637d3
af82581
8a41460
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import functools
import os
import random

import albumentations
import cv2
import gradio as gr
import numpy as np
import timm
import torch

# All tensors and the model are placed on the CPU.
device = torch.device('cpu')

# Output index -> disease name. predict() reads score i from the model output
# and reports it under labels[i], so list order must match the model's classes.
labels = dict(
    enumerate(
        [
            'bacterial_leaf_blight',
            'bacterial_leaf_streak',
            'bacterial_panicle_blight',
            'blast',
            'brown_spot',
            'dead_heart',
            'downy_mildew',
            'hispa',
            'normal',
            'tungro',
        ]
    )
)
 
def inference_fn(model, image=None):
    """Score a single image with ``model`` and return one value per output.

    Args:
        model: A torch module; it is switched to eval mode before use.
        image: A single image tensor without a batch axis. It is moved to the
            module-level ``device`` and given a leading batch dimension of 1.

    Returns:
        A flat numpy array with the sigmoid of every model output.
    """
    model.eval()
    on_device = image.to(device)
    with torch.no_grad():
        # [None, ...] prepends the batch axis, same as unsqueeze(0).
        logits = model(on_device[None, ...])
    # NOTE(review): sigmoid scores each class independently. If the model was
    # trained with cross-entropy (single-label), softmax would give scores that
    # sum to 1 — confirm against the training code before changing.
    scores = torch.sigmoid(logits).detach().cpu().numpy()
    return scores.flatten()
    
    
@functools.lru_cache(maxsize=1)
def _preprocessor():
    """Build (once) the deterministic resize + normalize pipeline used at inference."""
    # ImageNet channel statistics, applied to 0-255 pixel values.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    # No random flips here: random test-time augmentation made repeated
    # predictions on the same image disagree. Normalize defaults to p=1.0,
    # so it is always applied without needing always_apply=True.
    return albumentations.Compose(
        [
            albumentations.Resize(256, 256),
            albumentations.Normalize(mean, std, max_pixel_value=255.0),
        ]
    )


@functools.lru_cache(maxsize=1)
def _load_model():
    """Create the EfficientNet-B0 classifier and load its weights once per process."""
    model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=len(labels))
    # Relative path: resolved against the current working directory.
    model.load_state_dict(torch.load("paddy_model.pth", map_location=device))
    model.to(device)
    return model


def _to_tensor(image):
    """Preprocess ``image`` and return it as a channels-first float32 tensor."""
    # Assumes an HxWxC array with 0-255 values (implied by the (2, 0, 1)
    # transpose and max_pixel_value=255.0) — confirm for non-Gradio callers.
    processed = _preprocessor()(image=image)["image"]
    return torch.tensor(np.transpose(processed, (2, 0, 1)), dtype=torch.float32)


def predict(image=None) -> dict:
    """Classify a paddy image and return a score for every disease class.

    The model and preprocessing pipeline are built on the first call and
    reused afterwards, so only the first request pays the loading cost.

    Args:
        image: Input image array (see ``_to_tensor`` for the assumed layout).

    Returns:
        Dict mapping each disease name in ``labels`` to the model's score for
        that class, as produced by ``inference_fn``.

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("predict() requires an input image")
    scores = inference_fn(_load_model(), _to_tensor(image))
    return {name: float(scores[idx]) for idx, name in labels.items()}
    

# Web UI: one image in, confidences for all 10 classes out, plus two bundled
# example images.
# NOTE(review): gr.inputs / gr.outputs and interpretation= look like the
# legacy (pre-Gradio-4) API — confirm the pinned gradio version before upgrading.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(),
    outputs=gr.outputs.Label(num_top_classes=10),
    examples=["200005.jpg", "200006.jpg"],
    interpretation='default',
)
demo.launch()