|
import gradio as gr |
|
import torch |
|
from torchvision import models,transforms |
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# DenseNet-121 with its classifier replaced by a 4-way head. The module layout
# must stay exactly like this so the checkpoint keys (e.g. 'classifier.1.weight')
# line up with the model's parameters.
# No ImageNet weights are requested here. load_state_dict below runs in strict
# mode, so every parameter and buffer is overwritten by the checkpoint anyway,
# and a pretrained download at startup would be wasted network I/O.
backbone = models.densenet121()
backbone.classifier = torch.nn.Sequential(
    torch.nn.Dropout(0.2, inplace=True),
    torch.nn.Linear(1024, 4),
)

model = backbone.to(device)
# Load the fine-tuned weights directly onto the device the model lives on.
model.load_state_dict(torch.load('Alzahimer_weights.pt', map_location=device))

# This app only runs inference. Freeze every parameter and switch Dropout and
# BatchNorm to evaluation behavior once, up front.
model.requires_grad_(False)
model.eval()
|
# Class name -> classifier output index.
# NOTE(review): presumably this is the index order used at training time
# (alphabetical, as torchvision's ImageFolder would assign it) — confirm.
label_2id = {
    'MildDemented': 0,
    'ModerateDemented': 1,
    'NonDemented': 2,
    'VeryMildDemented': 3,
}

# Inverse lookup: classifier output index -> class name.
id2_label = dict(zip(label_2id.values(), label_2id.keys()))
|
# Standard ImageNet per-channel mean / std, as used by torchvision's
# pretrained models.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> normalized float tensor:
# 1. resize to 224x224;
# 2. convert to a CHW tensor with values in [0, 1];
# 3. standardize each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
|
def detect_symptom(image):
    """Classify an uploaded image into the model's four classes.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A dict mapping the three most probable class names (via ``id2_label``)
        to their softmax probabilities, in the form ``gr.Label`` accepts.
        Returns ``None`` (which clears the label) when no image was given.
    """
    if image is None:
        return None

    # The 3-value Normalize in `preprocess` and the network's first conv both
    # expect 3 channels. Grayscale ('L') or RGBA uploads would otherwise fail.
    # This mirrors torchvision's default image loader, which also converts
    # to RGB.
    img = preprocess(image.convert('RGB'))

    model.eval()
    with torch.no_grad():
        # Put the batch on the same device as the model weights; a CPU tensor
        # fed to a CUDA model raises a device-mismatch error.
        logits = model(img.unsqueeze(0).to(device))

    probs = torch.nn.functional.softmax(logits.float(), dim=-1).cpu()
    top = probs.topk(3)
    return {
        id2_label[idx.item()]: prob.item()
        for prob, idx in zip(top.values[0], top.indices[0])
    }
|
|
|
|
|
# Gradio UI: one PIL image input, a label output showing the top-3 classes
# returned by detect_symptom, and bundled example images.
demo = gr.Interface(
    fn=detect_symptom,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(num_top_classes=3),
    examples=[
        '0020ed3a-2b5f-4e46-9b96-97484c10a88c.jpg',
        '003304b2-5091-4342-93d8-1720afca671e.jpg',
        'mild_dimer.jpg',
    ],
    # `theme` (singular) is the Interface keyword for styling. A misspelled
    # `themes=` would be swallowed by **kwargs and the theme never applied.
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.amber,
        secondary_hue=gr.themes.colors.blue,
    ),
)

demo.launch(debug=False)