File size: 2,625 Bytes
be05fd1
 
 
 
 
 
 
3b16162
 
be05fd1
 
3b16162
be05fd1
 
3b16162
 
be05fd1
 
 
 
 
 
 
 
9ed6175
 
 
3b16162
 
be05fd1
3b16162
9ed6175
3b16162
 
 
 
 
9ed6175
be05fd1
 
 
865f10a
be05fd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b16162
 
 
7bd876a
9ed6175
 
 
 
1e947cf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms


def img2label(left, right):
    left = Image.fromarray(left.astype('uint8'), 'RGB')
    right = Image.fromarray(right.astype('uint8'), 'RGB')
    # 将右眼底镜像反转
    r2l = transforms.RandomHorizontalFlip(p=1)
    right = r2l(right)

    # 调整图片
    left_img = my_transforms(left).to(device)
    right_img = my_transforms(right).to(device)

    # 读取模型
    model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device)

    with torch.no_grad():
        output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0))

    output = torch.sigmoid(output.squeeze(0))
    output_ = output.cpu().numpy().tolist()
    res_dict = {LABELS[i]: output_[i] for i in range(len(output_))}

    pred = torch.nonzero(output > 0.4).view(-1)
    pred = pred.cpu().numpy().tolist()

    if len(pred) == 0 or (len(pred) == 1 and pred[0] == 0):
        return LABELS[0], res_dict
    res = ''
    for i in pred:
        if i == 0:
            continue
        res += ', ' + LABELS[i]
    return '目前的身体状态:' + res[2:], res_dict


if __name__ == '__main__':
    device = torch.device("cpu")

    # 标题
    title = "基于眼底图像的智能健康诊断分析系统"
    # 标题下的描述,支持md格式
    description = "上传并输入左右眼底图像后,点击 submit 按钮,可根据双目眼底图像智能分析出可能有的疾病!" \
                  "包含的疾病种类有:糖尿病、青光眼、白内障、年龄性黄斑变性、高血压、病理性近视、其他疾病以及正常共计8类"

    # transforms设置
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    my_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std)
    ])

    LABELS = {0: '正常',
              1: '糖尿病',
              2: '青光眼',
              3: '白内障',
              4: '年龄性黄斑变性',
              5: '高血压',
              6: '病理性近视',
              7: '其他疾病'}

    left_img_dir = 'left.jpg'
    right_img_dir = 'right.jpg'
    examples = [[left_img_dir, right_img_dir]]
    # r = img2label(left_img_dir, right_img_dir)
    demo = gr.Interface(fn=img2label,
                        inputs=[gr.inputs.Image(), gr.inputs.Image()],
                        outputs=["text", "label"],
                        examples=examples, title=title, description=description)
    demo.launch()