Commit
•
7389856
1
Parent(s):
db3c865
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from torchvision import models,transforms
|
4 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
5 |
+
backbone=models.densenet121(pretrained=True)
|
6 |
+
for param in backbone.parameters():
|
7 |
+
#print(param)
|
8 |
+
param.requires_grad=False
|
9 |
+
for param in backbone.features.denseblock4.parameters():
|
10 |
+
param.requires_grad=True
|
11 |
+
backbone.classifier=torch.nn.Sequential(
|
12 |
+
torch.nn.Dropout(0.2,inplace=True),
|
13 |
+
torch.nn.Linear(1024,4),
|
14 |
+
|
15 |
+
)
|
16 |
+
model=backbone.to(device)
|
17 |
+
model.load_state_dict(torch.load('Alzahimer_weights.pt',map_location=torch.device('cpu')))
|
18 |
+
label_2id={'MildDemented':0,'ModerateDemented':1,'NonDemented':2,'VeryMildDemented':3}
|
19 |
+
id2_label={i:v for v,i in label_2id.items()}
|
20 |
+
preprocess=transforms.Compose([ transforms.Resize((224,224)),
|
21 |
+
transforms.ToTensor(),
|
22 |
+
transforms.Normalize([0.485, 0.456, 0.406],
|
23 |
+
[0.229, 0.224, 0.225])])
|
24 |
+
def detect_symptom(image):
|
25 |
+
img=preprocess(image)
|
26 |
+
with torch.no_grad():
|
27 |
+
model.eval()
|
28 |
+
pred=model(img.unsqueeze(0))
|
29 |
+
pred_max=torch.argmax(pred,dim=1)
|
30 |
+
label=id2_label[pred_max.detach().cpu().numpy()[0]]
|
31 |
+
confidence=torch.nn.functional.softmax(pred.float(),dim=-1).detach().cpu().topk(3)
|
32 |
+
#print(confidence.indices[0][0].item())
|
33 |
+
predictions={id2_label[confidence.indices[0][i].item()]:confidence.values[0][i].item() for i in range(3)}
|
34 |
+
return predictions
|
35 |
+
|
36 |
+
|
37 |
+
demo=gr.Interface(fn=detect_symptom,inputs=gr.Image(type='pil'),
|
38 |
+
outputs=gr.Label(num_top_classes=3),
|
39 |
+
examples=['0020ed3a-2b5f-4e46-9b96-97484c10a88c.jpg','003304b2-5091-4342-93d8-1720afca671e.jpg',
|
40 |
+
'mild_dimer.jpg'],
|
41 |
+
themes=gr.themes.Soft(primary_hue=gr.themes.colors.amber,secondary_hue=gr.themes.colors.blue))
|
42 |
+
demo.launch(debug=False)
|