File size: 1,828 Bytes
af82581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1cd55c
af82581
 
 
 
 
 
 
 
 
b1cd55c
af82581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import functools
import gc

import albumentations
import cv2
import gradio as gr
import numpy as np
import timm
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Class index -> disease name; index i names the model's i-th output score.
labels = dict(
    enumerate(
        [
            "bacterial_leaf_blight",
            "bacterial_leaf_streak",
            "bacterial_panicle_blight",
            "blast",
            "brown_spot",
            "dead_heart",
            "downy_mildew",
            "hispa",
            "normal",
            "tungro",
        ]
    )
)
 
def inference_fn(model, image=None):
    """Run ``model`` on a single un-batched input and return sigmoid scores.

    Args:
        model: a ``torch.nn.Module``; it is switched to eval mode in place.
        image: one input tensor without a batch dimension (for the classifier
            in this file, a CHW float tensor). A batch dimension of size 1 is
            added before the forward pass.

    Returns:
        1-D numpy array holding ``sigmoid`` of every model output.

    Raises:
        ValueError: if ``image`` is None.
    """
    if image is None:
        raise ValueError("inference_fn() requires an image tensor")
    model.eval()
    # Put the input wherever the model's weights live; a parameter-less
    # module imposes no device, so keep the image where it already is.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else image.device
    with torch.no_grad():
        output = model(image.to(target_device).unsqueeze(0))
    # Tensors produced under no_grad() carry no autograd history, so no
    # .detach() is needed before converting to numpy.
    return output.sigmoid().cpu().numpy().flatten()
    
    
# ImageNet channel statistics (RGB order) used by Normalize below.
_MEAN = (0.485, 0.456, 0.406)
_STD = (0.229, 0.224, 0.225)


@functools.lru_cache(maxsize=1)
def _get_transform():
    """Return the inference preprocessing pipeline, built once per process.

    Resize to 256x256, then normalise with ImageNet mean/std. There are
    deliberately no random flips: test-time randomness would make repeated
    requests for the same image return different scores.
    """
    return albumentations.Compose(
        [
            albumentations.Resize(256, 256),
            albumentations.Normalize(_MEAN, _STD, max_pixel_value=255.0, always_apply=True),
        ]
    )


@functools.lru_cache(maxsize=1)
def _get_model():
    """Build the EfficientNet-B0 classifier and load its weights once per process."""
    model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=len(labels))
    # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
    model.load_state_dict(torch.load("paddy_model.pth", map_location=device))
    model.to(device)
    model.eval()
    return model


def predict(image=None):
    """Score a rice-leaf photo against every paddy disease class.

    Args:
        image: HxWx3 uint8 array from the Gradio Image input. Assumed RGB
            (Gradio decodes uploads with PIL); NOTE(review): confirm this
            matches the colour order the model was trained on.

    Returns:
        dict mapping each class name in ``labels`` to its sigmoid score, the
        format consumed by the Label output.

    Raises:
        ValueError: if no image was supplied.
    """
    if image is None:
        raise ValueError("No image provided")
    # Gradio already supplies RGB pixels, so no BGR->RGB swap is applied:
    # swapping would feed BGR data into RGB ImageNet normalisation.
    pixels = _get_transform()(image=image)["image"]
    # HWC float array -> contiguous CHW float32 tensor for the model.
    tensor = torch.as_tensor(pixels, dtype=torch.float32).permute(2, 0, 1).contiguous()
    predicted = inference_fn(_get_model(), tensor)
    return {name: float(predicted[idx]) for idx, name in labels.items()}
    

# Build the UI object at import time, but start the web server only when this
# file is executed as a script, so merely importing the module has no side
# effect beyond constructing the interface.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(256, 256)),
    # Show a score for every known class rather than a hard-coded count.
    outputs=gr.outputs.Label(num_top_classes=len(labels)),
    examples=["200001.jpg", "100028.jpg"],
)

if __name__ == "__main__":
    demo.launch()