# alibidaran's picture
# Update app.py
# 221f1d7
import gradio as gr
import torch
from torchvision import models,transforms
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the bare DenseNet-121 architecture. No pretrained weights are
# requested: the strict load_state_dict() call below overwrites every
# parameter and buffer, so fetching them first would be wasted work (and the
# `pretrained=` flag is deprecated in recent torchvision releases).
backbone = models.densenet121()

# Freeze everything except the last dense block. These flags do not affect
# the predictions made in this file, because detect_symptom runs the model
# under torch.no_grad().
for param in backbone.parameters():
    param.requires_grad = False
for param in backbone.features.denseblock4.parameters():
    param.requires_grad = True

# Replace the stock classifier with dropout followed by a 4-output linear
# layer (one output per entry in the label map).
backbone.classifier = torch.nn.Sequential(
    torch.nn.Dropout(0.2, inplace=True),
    torch.nn.Linear(1024, 4),
)

model = backbone.to(device)
# The checkpoint is deserialized onto the CPU; load_state_dict then copies the
# tensors into the model's parameters on `device`.
model.load_state_dict(
    torch.load('Alzahimer_weights.pt', map_location=torch.device('cpu'))
)
# This script only performs inference, so disable dropout and use the stored
# batch-norm statistics from the start.
model.eval()
# Class name -> output index of the classifier head.
label_2id = {
    'MildDemented': 0,
    'ModerateDemented': 1,
    'NonDemented': 2,
    'VeryMildDemented': 3,
}
# Reverse lookup: output index -> class name.
id2_label = dict(zip(label_2id.values(), label_2id.keys()))
# Input pipeline applied to every image before it reaches the model:
# resize to 224x224, convert to a tensor, then normalize each of the three
# channels with a fixed mean and standard deviation (these look like the
# values commonly used with ImageNet-pretrained networks).
preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
preprocess = transforms.Compose(preprocess_steps)
def detect_symptom(image):
    """Classify an image and return the three most probable classes.

    Args:
        image: PIL image, as delivered by the ``gr.Image(type='pil')`` input.

    Returns:
        dict mapping each of the three highest-scoring class names to its
        softmax probability; this is what the ``gr.Label`` output displays.
    """
    # The preprocessing pipeline normalizes with three per-channel statistics,
    # so make sure the image really has three channels.
    img = preprocess(image.convert('RGB'))
    # The model was moved to `device`; the input batch has to live on the same
    # device or the forward pass fails on CUDA hosts.
    batch = img.unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(batch)

    # Softmax over the class dimension, then take the single batch row back to
    # the CPU before extracting plain Python values.
    probs = torch.nn.functional.softmax(logits.float(), dim=-1)[0].cpu()
    top = probs.topk(3)
    return {
        id2_label[idx]: score
        for idx, score in zip(top.indices.tolist(), top.values.tolist())
    }
# Sample images offered as one-click examples in the UI.
example_images = [
    '0020ed3a-2b5f-4e46-9b96-97484c10a88c.jpg',
    '003304b2-5091-4342-93d8-1720afca671e.jpg',
    'mild_dimer.jpg',
]

# Image in (as a PIL object), top-3 class probabilities out.
image_input = gr.Image(type='pil')
label_output = gr.Label(num_top_classes=3)

demo = gr.Interface(
    fn=detect_symptom,
    inputs=image_input,
    outputs=label_output,
    examples=example_images,
    title="Alzimer different stage detection",
)
demo.launch(debug=False)