# Source: ShAnSantosh / app.py (Hugging Face Space), commit b1cd55c, 1.83 kB.
# Last commit message: "Update app.py".
import gc
from functools import lru_cache

import albumentations
import cv2
import gradio as gr
import numpy as np
import timm
import torch
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class-index -> class-name table. predict() pairs entry i with the i-th
# model output, so the list order below must match the classifier head.
labels = dict(
    enumerate(
        [
            "bacterial_leaf_blight",
            "bacterial_leaf_streak",
            "bacterial_panicle_blight",
            "blast",
            "brown_spot",
            "dead_heart",
            "downy_mildew",
            "hispa",
            "normal",
            "tungro",
        ]
    )
)
def inference_fn(model, image=None):
    """Run one image through *model* and return its sigmoid-activated outputs.

    Args:
        model: a ``torch.nn.Module``; it is switched to eval mode here.
        image: a single, un-batched input tensor. A leading batch dimension
            of size 1 is added before the forward pass.

    Returns:
        A 1-D numpy array with the sigmoid of every model output (the
        ``(1, C)`` result flattened to ``(C,)``).
    """
    model.eval()
    # Send the input to whichever device holds the model's weights so the
    # forward pass never mixes devices; parameter-free models fall back to
    # the module-level default device.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device
    batch = image.to(target).unsqueeze(0)
    with torch.no_grad():
        # NOTE(review): sigmoid scores each class independently; if the
        # network was trained with a softmax / cross-entropy head,
        # torch.softmax(..., dim=1) may be what was intended — confirm.
        scores = model(batch).sigmoid()
    return scores.cpu().numpy().flatten()
# ImageNet channel statistics, in RGB order, used to normalise inputs.
_MEAN = (0.485, 0.456, 0.406)
_STD = (0.229, 0.224, 0.225)


def _preprocess(image):
    """Resize and normalise an RGB uint8 image into a CHW float32 tensor."""
    # Deterministic inference-time pipeline. The previous version also applied
    # random horizontal/vertical flips (p=0.5 each), so the same upload could
    # receive different scores on each call; flips are a training augmentation.
    pipeline = albumentations.Compose(
        [
            albumentations.Resize(256, 256),
            # Normalize runs with p=1.0 by default, so the deprecated
            # always_apply flag is unnecessary.
            albumentations.Normalize(_MEAN, _STD, max_pixel_value=255.0),
        ]
    )
    pixels = pipeline(image=image)["image"]
    # HWC -> CHW, the layout PyTorch convolution layers expect.
    return torch.tensor(np.transpose(pixels, (2, 0, 1)), dtype=torch.float32)


@lru_cache(maxsize=1)
def _load_model(weights_path="paddy_model.pth"):
    """Build the EfficientNet-B0 classifier and load its weights (cached).

    Caching avoids re-reading the checkpoint from disk on every request.
    """
    model = timm.create_model(
        "efficientnet_b0", pretrained=False, num_classes=len(labels)
    )
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()
    return model


def predict(image=None):
    """Classify a paddy image and return a score for every label.

    Args:
        image: the numpy array delivered by the Gradio image input. Gradio
            supplies pixels in RGB channel order (assumed H x W x 3 uint8).

    Returns:
        dict mapping each label name in ``labels`` to the model's score.

    Raises:
        ValueError: if no image was supplied.
    """
    if image is None:
        raise ValueError("No image provided")
    # No BGR->RGB conversion: Gradio already provides RGB, and the previous
    # cv2.COLOR_BGR2RGB call swapped it to BGR before RGB-ordered
    # normalisation.
    tensor = _preprocess(image)
    scores = inference_fn(_load_model(), tensor)
    return {labels[i]: float(score) for i, score in enumerate(scores)}
# Web UI: upload an image (resized to 256x256 by the input component) and show
# every label's score. Bound to `demo` so tooling can find the interface.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(256, 256)),
    outputs=gr.outputs.Label(num_top_classes=len(labels)),
    examples=["200001.jpg", "100028.jpg"],
)

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()