dongsheng commited on
Commit
be05fd1
·
1 Parent(s): 6986027

Upload 6 files

Browse files
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ import torchvision.transforms as transforms
5
+
6
+
7
+ def img2label(left, right):
8
+ left_img = Image.open(left).convert('RGB')
9
+ right_img = Image.open(right).convert('RGB')
10
+ # 将右眼底镜像反转
11
+ r2l = transforms.RandomHorizontalFlip(p=1)
12
+ right_img = r2l(right_img)
13
+
14
+ # 调整图片
15
+ left_img = my_transforms(left_img).to(device)
16
+ right_img = my_transforms(right_img).to(device)
17
+
18
+ # 读取模型
19
+ model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device)
20
+
21
+ with torch.no_grad():
22
+ output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0))
23
+
24
+ output = torch.sigmoid(output.squeeze(0))
25
+ pred = output.cpu().numpy().tolist()
26
+
27
+ return {LABELS[i]: pred[i] for i in range(len(pred))}
28
+
29
+
30
+ if __name__ == '__main__':
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+
33
+ # 标题
34
+ title = "基于眼底图像的智能健康诊断分析系统"
35
+ # 标题下的描述,支持md格式
36
+ description = "上传并输入左右眼底图像后,点击 submit 按钮,可根据双目眼底图像智能分析出可能有的疾病!" \
37
+ "包含的疾病种类有:糖尿病、青光眼、白内障、年龄性黄斑变性、高血压、病理性近视、其他疾病以及正常共计8类"
38
+
39
+ # transforms设置
40
+ norm_mean = [0.485, 0.456, 0.406]
41
+ norm_std = [0.229, 0.224, 0.225]
42
+ my_transforms = transforms.Compose([
43
+ transforms.Resize((224, 224)),
44
+ transforms.ToTensor(),
45
+ transforms.Normalize(norm_mean, norm_std)
46
+ ])
47
+
48
+ LABELS = {0: '正常',
49
+ 1: '糖尿病',
50
+ 2: '青光眼',
51
+ 3: '白内障',
52
+ 4: '年龄性黄斑变性',
53
+ 5: '高血压',
54
+ 6: '病理性近视',
55
+ 7: '其他疾病'}
56
+
57
+ left_img_dir = 'left.jpg'
58
+ right_img_dir = 'right.jpg'
59
+ r = img2label(left_img_dir, right_img_dir)
60
+ demo = gr.Interface(fn=img2label, inputs=[gr.inputs.Image(), gr.inputs.Image()], outputs='label',
61
+ title=title, description=description)
62
+ demo.launch(share=True)
config/__pycache__/finetune_config.cpython-36.pyc ADDED
Binary file (1.38 kB). View file
 
config/finetune_config.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+
4
+ def set_args():
5
+ parser = argparse.ArgumentParser()
6
+ parser.add_argument('--save_name', default='densenet_ce_e4_b32_lr1e-4.pkl', type=str,
7
+ help='保存模型的名字,默认路径在 ./model_parameters/f')
8
+ parser.add_argument('--model_selection', default='MX', type=str,
9
+ help='模型选择,M 表示模改过的,没有则表示原来的')
10
+ parser.add_argument('--pt', default='FD', type=str,
11
+ help='有 FD 和 IN 两种,FD 表示预训练是眼底图像,IN 表示预训练是 imagenet')
12
+ parser.add_argument('--finetune_path', default='model_parameters/p/resnext_ce_e4_b32_lr1e-4.pkl', type=str,
13
+ help='所选用预训练模型的路径')
14
+ parser.add_argument('--feature_module', default='cat', type=str,
15
+ help='特征融合的方式,有 cat、mul、sum 三种方式,只在原始模型有效')
16
+ parser.add_argument('--opt', default='adamw', type=str, help='优化器')
17
+ parser.add_argument('--warmup_select', default='linear', type=str,
18
+ help='')
19
+ parser.add_argument('--MAX_EPOCH', default=4, type=int,
20
+ help='')
21
+ parser.add_argument('--BATCH_SIZE', default=32, type=int,
22
+ help='')
23
+ parser.add_argument('--start_epoch', default=0, type=int,
24
+ help='')
25
+ parser.add_argument('--LR', default=5e-4, type=float,
26
+ help='')
27
+ parser.add_argument('--WD', default=1e-2, type=float,
28
+ help='')
29
+ parser.add_argument('--adam_epsilon', default=1e-8, type=float,
30
+ help='')
31
+ parser.add_argument('--warmup_proportion', default=0.1, type=float,
32
+ help='')
33
+ return parser.parse_args()
densenet_FD_e4_l5e-4_b32.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:746019148ee61f72424b1f866b28ec220815413f3c5ca52436c6de24d9054b9a
3
+ size 732694486
models/__pycache__/modified_dual_densenet.cpython-36.pyc ADDED
Binary file (2.68 kB). View file
 
models/modified_dual_densenet.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torchvision.models import densenet169
4
+ from config.finetune_config import set_args
5
+
6
+ args = set_args()
7
+
8
+
9
+ class Classifier(nn.Module):
10
+ def __init__(self, num_classes):
11
+ super(Classifier, self).__init__()
12
+ self.GDConv1 = nn.Conv2d(1664 * 2, 1024, kernel_size=4, padding=0, dilation=2)
13
+ self.GDConv2 = nn.Conv2d(1664 * 2, 1024, kernel_size=5, padding=1, dilation=2)
14
+ self.GDConv3 = nn.Conv2d(1664 * 2, 1024, kernel_size=3, padding=0, dilation=3)
15
+ self.LN1 = nn.LayerNorm([1024, 1, 1])
16
+ self.LN2 = nn.LayerNorm([1024, 1, 1])
17
+ self.LN3 = nn.LayerNorm([1024, 1, 1])
18
+ self.gelu = nn.GELU()
19
+ self.fc_dropout = nn.Dropout(0.2)
20
+ self.fc = nn.Linear(1024 * 3, num_classes)
21
+
22
+ for m in self.modules():
23
+ if isinstance(m, nn.Conv2d):
24
+ nn.init.kaiming_normal_(m.weight)
25
+ elif isinstance(m, nn.BatchNorm2d):
26
+ nn.init.constant_(m.weight, 1)
27
+ nn.init.constant_(m.bias, 0)
28
+ elif isinstance(m, nn.Linear):
29
+ nn.init.constant_(m.bias, 0)
30
+
31
+ def forward(self, x):
32
+ x1 = self.GDConv1(x)
33
+ x1 = self.LN1(x1)
34
+ x1 = x1.view(x1.size(0), -1)
35
+
36
+ x2 = self.GDConv2(x)
37
+ x2 = self.LN2(x2)
38
+ x2 = x2.view(x2.size(0), -1)
39
+
40
+ x3 = self.GDConv3(x)
41
+ x3 = self.LN3(x3)
42
+ x3 = x3.view(x3.size(0), -1)
43
+
44
+ X = torch.cat((x1, x2, x3), 1)
45
+ X = self.gelu(X)
46
+ output = self.fc(self.fc_dropout(X))
47
+
48
+ return output
49
+
50
+
51
+ class M_DenseNet(nn.Module):
52
+ def __init__(self, pretrain='IN', num_classes=8):
53
+ super(M_DenseNet, self).__init__()
54
+ # feature layer
55
+ if pretrain == 'IN':
56
+ model = densenet169(pretrained=True) # 此处的model参数是已经加载了预训练参数的模型
57
+ self.feature = nn.Sequential(*list(model.children())[:-1])
58
+ else:
59
+ model = torch.load(args.finetune_path)
60
+ self.feature = nn.Sequential(*list(model.children())[:-2])
61
+ self.classifier = Classifier(num_classes)
62
+
63
+ def forward(self, left, right):
64
+ left = self.feature(left)
65
+ right = self.feature(right)
66
+ x = torch.cat((left, right), 1)
67
+
68
+ X = self.classifier(x)
69
+
70
+ return X
71
+
72
+
73
+ if __name__ == '__main__':
74
+ model = M_DenseNet()
75
+ input1 = torch.normal(0, 1, size=(4, 3, 224, 224))
76
+ input2 = torch.normal(0, 1, size=(4, 3, 224, 224))
77
+ output = model(input1, input2)
78
+ print(output)