alibidaran commited on
Commit
7389856
1 Parent(s): db3c865

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import models,transforms
4
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
5
+ backbone=models.densenet121(pretrained=True)
6
+ for param in backbone.parameters():
7
+ #print(param)
8
+ param.requires_grad=False
9
+ for param in backbone.features.denseblock4.parameters():
10
+ param.requires_grad=True
11
+ backbone.classifier=torch.nn.Sequential(
12
+ torch.nn.Dropout(0.2,inplace=True),
13
+ torch.nn.Linear(1024,4),
14
+
15
+ )
16
+ model=backbone.to(device)
17
+ model.load_state_dict(torch.load('Alzahimer_weights.pt',map_location=torch.device('cpu')))
18
+ label_2id={'MildDemented':0,'ModerateDemented':1,'NonDemented':2,'VeryMildDemented':3}
19
+ id2_label={i:v for v,i in label_2id.items()}
20
+ preprocess=transforms.Compose([ transforms.Resize((224,224)),
21
+ transforms.ToTensor(),
22
+ transforms.Normalize([0.485, 0.456, 0.406],
23
+ [0.229, 0.224, 0.225])])
24
+ def detect_symptom(image):
25
+ img=preprocess(image)
26
+ with torch.no_grad():
27
+ model.eval()
28
+ pred=model(img.unsqueeze(0))
29
+ pred_max=torch.argmax(pred,dim=1)
30
+ label=id2_label[pred_max.detach().cpu().numpy()[0]]
31
+ confidence=torch.nn.functional.softmax(pred.float(),dim=-1).detach().cpu().topk(3)
32
+ #print(confidence.indices[0][0].item())
33
+ predictions={id2_label[confidence.indices[0][i].item()]:confidence.values[0][i].item() for i in range(3)}
34
+ return predictions
35
+
36
+
37
+ demo=gr.Interface(fn=detect_symptom,inputs=gr.Image(type='pil'),
38
+ outputs=gr.Label(num_top_classes=3),
39
+ examples=['0020ed3a-2b5f-4e46-9b96-97484c10a88c.jpg','003304b2-5091-4342-93d8-1720afca671e.jpg',
40
+ 'mild_dimer.jpg'],
41
+ themes=gr.themes.Soft(primary_hue=gr.themes.colors.amber,secondary_hue=gr.themes.colors.blue))
42
+ demo.launch(debug=False)