File size: 1,837 Bytes
7389856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585a0fb
7389856
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr
import torch 
from torchvision import models,transforms
# Run on the GPU when one is available. The model and every input tensor fed
# to it must live on this same device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# DenseNet-121 backbone. ImageNet weights are NOT downloaded: the fine-tuned
# checkpoint below is loaded with strict key matching (load_state_dict's
# default), so it overwrites every parameter and buffer anyway.
backbone = models.densenet121(pretrained=False)

# Freeze everything except the last dense block. This mirrors the fine-tuning
# setup; requires_grad does not change forward-pass outputs.
for param in backbone.parameters():
    param.requires_grad = False
for param in backbone.features.denseblock4.parameters():
    param.requires_grad = True

# Replace the default classifier with a 4-way head, one logit per label in
# label_2id. Keep the Sequential layout (Dropout at index 0, Linear at index 1):
# the checkpoint's "classifier.1.*" keys depend on it.
backbone.classifier = torch.nn.Sequential(
    torch.nn.Dropout(0.2, inplace=True),
    torch.nn.Linear(1024, 4),
)

model = backbone.to(device)
# map_location='cpu' lets a checkpoint saved on a GPU load on any host;
# load_state_dict then copies the tensors onto the model's device.
model.load_state_dict(torch.load('Alzahimer_weights.pt', map_location=torch.device('cpu')))
# Inference only: disable dropout / use BatchNorm running statistics.
model.eval()

# Class index <-> label name. The indices must line up with the output units
# of the trained 4-way head.
label_2id = {'MildDemented': 0, 'ModerateDemented': 1, 'NonDemented': 2, 'VeryMildDemented': 3}
id2_label = {idx: name for name, idx in label_2id.items()}

# Resize to the 224x224 input size and normalize with the standard ImageNet
# per-channel (RGB) mean/std.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])
def detect_symptom(image, top_k=3):
    """Classify an image into the dementia-stage labels and return the top predictions.

    Args:
        image: PIL image, as delivered by ``gr.Image(type='pil')``.
        top_k: number of highest-probability classes to return; clamped to
            the number of classes the model outputs. Defaults to 3.

    Returns:
        dict mapping label name (from ``id2_label``) to softmax probability
        for the ``top_k`` most likely classes, the format ``gr.Label`` accepts.
    """
    # preprocess normalizes with 3-channel (RGB) statistics; converting first
    # keeps grayscale ('L') or RGBA inputs from failing or being mis-normalized.
    img = preprocess(image.convert("RGB"))
    # The input must be on the same device as the model; otherwise a
    # CUDA-hosted model raises a device-mismatch error.
    batch = img.unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits.float(), dim=-1).cpu()
    top = probs.topk(min(top_k, probs.shape[-1]))
    return {
        id2_label[idx.item()]: score.item()
        for score, idx in zip(top.values[0], top.indices[0])
    }

  
# Gradio UI: one PIL image in, the top-3 stage probabilities out.
# The example paths are relative; presumably the images ship next to this
# script (confirm).
demo = gr.Interface(
    fn=detect_symptom,
    inputs=gr.Image(type='pil'),
    outputs=gr.Label(num_top_classes=3),
    examples=[
        '0020ed3a-2b5f-4e46-9b96-97484c10a88c.jpg',
        '003304b2-5091-4342-93d8-1720afca671e.jpg',
        'mild_dimer.jpg',
    ],
    title="Alzheimer's disease stage detection",
)

# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module does not launch the UI.
if __name__ == "__main__":
    demo.launch(debug=False)