ShAnSantosh's picture
Update app.py
7cca923
raw
history blame
2.1 kB
import functools
import os
import random

import albumentations
import cv2
import gradio as gr
import numpy as np
import timm
import torch
# Single device used for both the model weights and the input tensors;
# everything in this app runs on the CPU.
device = torch.device("cpu")
def seed_everything(seed: int = 42) -> None:
    """Seed every random number generator the app touches.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    the current CUDA device, which is a no-op without CUDA), and configures
    cuDNN to use deterministic kernels.

    Args:
        seed: Value used to seed every generator.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. subprocesses):
    # hash randomisation of the running process is fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and choose kernels per run, which
    # defeats the determinism requested above, so it must stay off.
    torch.backends.cudnn.benchmark = False


seed_everything(seed=42)
# Model output index -> class name. predict() pairs output position i with
# labels[i], so this order must match the class order of the checkpoint.
labels = dict(
    enumerate(
        (
            "bacterial_leaf_blight",
            "bacterial_leaf_streak",
            "bacterial_panicle_blight",
            "blast",
            "brown_spot",
            "dead_heart",
            "downy_mildew",
            "hispa",
            "normal",
            "tungro",
        )
    )
)
def inference_fn(model, image=None, *, target_device=None):
    """Run one forward pass on a single image and return per-class scores.

    Args:
        model: A ``torch.nn.Module``; assumed to emit one logit per class.
            It is switched to eval mode here.
        image: A single, unbatched image tensor. A batch dimension of size 1
            is added before the forward pass.
        target_device: Device to move ``image`` to. Defaults to the
            module-level ``device``; the model is expected to live there too.

    Returns:
        1-D NumPy array of sigmoid-activated scores, one per model output.

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("inference_fn() requires an image tensor, got None.")
    if target_device is None:
        target_device = device
    model.eval()
    batch = image.to(target_device).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # NOTE(review): sigmoid scores each class independently. If the
    # checkpoint was trained with softmax cross-entropy, softmax would give
    # single-label probabilities instead — TODO confirm against training.
    return logits.sigmoid().cpu().numpy().flatten()
_MODEL_PATH = "paddy_model.pth"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the classifier and load its weights once; later calls reuse it.

    Returns:
        A timm ``efficientnet_b0`` with one output per entry in ``labels``,
        weights loaded from ``paddy_model.pth``, moved to ``device`` and put
        in eval mode.
    """
    model = timm.create_model(
        "efficientnet_b0", pretrained=False, num_classes=len(labels)
    )
    model.load_state_dict(torch.load(_MODEL_PATH, map_location=device))
    model.to(device)
    model.eval()
    return model


def _preprocess(image):
    """Resize to 256x256, normalise, and convert to a CHW float32 tensor.

    Assumes ``image`` is an H x W x 3 array with 0-255 pixel values (as
    implied by ``max_pixel_value=255.0``) — TODO confirm for all inputs.
    """
    transform = albumentations.Compose(
        [
            albumentations.Resize(256, 256),
            # Standard ImageNet per-channel mean / std.
            albumentations.Normalize(
                (0.485, 0.456, 0.406),
                (0.229, 0.224, 0.225),
                max_pixel_value=255.0,
                always_apply=True,
            ),
        ]
    )
    hwc = transform(image=image)["image"]
    return torch.tensor(np.transpose(hwc, (2, 0, 1)), dtype=torch.float32)


def predict(image=None):
    """Score an image against every class in ``labels``.

    Args:
        image: Image array passed in by the Gradio image input.

    Returns:
        Mapping of class name to its score, in the format expected by the
        Gradio ``Label`` output.

    Raises:
        ValueError: If no image was supplied.
    """
    if image is None:
        raise ValueError("No image was provided.")
    # NOTE(review): Gradio's numpy image input is documented as RGB, so this
    # call actually swaps channels to BGR. Keep it only if the checkpoint was
    # trained on BGR input — TODO confirm against the training pipeline.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Deliberately no random flips here: random augmentation at inference
    # makes repeated predictions on the same image disagree.
    scores = inference_fn(_load_model(), _preprocess(image))
    return {name: float(scores[idx]) for idx, name in labels.items()}
# Gradio UI: a 256x256 image input, a label output showing all ten class
# scores, and two example images; the server is started immediately.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(256, 256)),
    outputs=gr.outputs.Label(num_top_classes=10),
    examples=["200005.jpg", "200006.jpg"],
)
demo.launch()