File size: 1,950 Bytes
7389856 585a0fb 7389856 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import gradio as gr
import torch
from torchvision import models,transforms
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the DenseNet-121 architecture only. No ImageNet download is needed:
# every parameter is overwritten by the fine-tuned checkpoint loaded below
# (load_state_dict is strict by default, so all keys must be present in it).
backbone = models.densenet121()
# Replacement head. Its layout (Dropout at index 0, Linear at index 1) must
# stay exactly as the checkpoint was saved, otherwise the strict
# load_state_dict call fails on mismatched keys.
backbone.classifier = torch.nn.Sequential(
    torch.nn.Dropout(0.2, inplace=True),
    torch.nn.Linear(1024, 4),
)
# Load on CPU first so the checkpoint opens on machines without a GPU, then
# move the populated model to the selected device.
# NOTE: torch.load unpickles the file; only ship/load trusted checkpoints.
backbone.load_state_dict(torch.load('Alzahimer_weights.pt', map_location=torch.device('cpu')))
model = backbone.to(device)
# This app only runs inference: disable dropout and gradient tracking.
model.eval()
model.requires_grad_(False)
# Class name -> classifier output index.
label_2id = {
    'MildDemented': 0,
    'ModerateDemented': 1,
    'NonDemented': 2,
    'VeryMildDemented': 3,
}
# Inverse lookup: classifier output index -> class name.
id2_label = dict(zip(label_2id.values(), label_2id.keys()))
# Resize to 224x224, convert to a float tensor in [0, 1], then normalise each
# channel with the usual ImageNet mean/std values.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def detect_symptom(image):
    """Classify an uploaded scan and return the top class probabilities.

    Args:
        image: PIL image supplied by the ``gr.Image(type='pil')`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        dict mapping class name -> softmax probability for the (up to) three
        most likely classes, the format ``gr.Label`` consumes.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Normalize in `preprocess` uses 3-channel statistics; grayscale or RGBA
    # uploads would otherwise fail with a channel-count mismatch.
    img = preprocess(image.convert('RGB'))
    # The batch must be on the same device as the model weights, otherwise a
    # CUDA host raises a device-mismatch error.
    batch = img.unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(batch)
    # Single image in the batch: take row 0 of the class probabilities.
    probs = torch.nn.functional.softmax(logits.float(), dim=-1)[0].cpu()
    top = probs.topk(min(3, probs.numel()))
    return {
        id2_label[idx]: prob
        for prob, idx in zip(top.values.tolist(), top.indices.tolist())
    }
# Gradio UI: one PIL image in, top-3 class probabilities out.
demo = gr.Interface(
    fn=detect_symptom,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(num_top_classes=3),
    examples=[
        '0020ed3a-2b5f-4e46-9b96-97484c10a88c.jpg',
        '003304b2-5091-4342-93d8-1720afca671e.jpg',
        'mild_dimer.jpg',
    ],
    # The Interface keyword is `theme` (singular); the previous `themes=`
    # is not an Interface parameter, so the Soft theme was not applied.
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.amber,
        secondary_hue=gr.themes.colors.blue,
    ),
    title="Alzheimer's stage detection",
)
demo.launch(debug=False)