wogh2012 commited on
Commit
b7914f1
Β·
1 Parent(s): aefacda

refactor: remove aitiautils

Browse files
__pycache__/model.cpython-311.pyc ADDED
Binary file (1.78 kB). View file
 
app.py CHANGED
@@ -1,13 +1,14 @@
1
- # from io import BytesIO
 
 
 
2
  import gradio as gr
3
  import pandas as pd
4
  import numpy as np
5
  import torch
6
- import os
7
- from aitiautils.model_loader import ModelLoader
8
- import tempfile
9
  import matplotlib.pyplot as plt
10
- import traceback as tb
 
11
 
12
  # True 이면, tmp directory 에 파일 쑴재 μœ λ¬΄μ™€ 상관없이 항상 μƒˆλ‘œμš΄ 이미지 생성
13
  ALWAYS_RECREATE_IMAGE = os.getenv("ALWAYS_RECREATE_IMAGE", "False").lower() == "true"
@@ -24,11 +25,14 @@ test_df = pd.read_csv("./res/ludb/dataset/test_for_public.csv").drop_duplicates(
24
  cutoffs = [0.001163482666015625, 0.15087890625, -0.587890625]
25
  lead_names = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
26
 
 
 
27
 
28
  def gen_seg(subject_id):
29
  input = np.load(f"./res/ludb/ecg_np/{subject_id}.npy")
30
- network = ModelLoader("./res/models/hrnetv2/checkpoint.pth").get_network()
31
- output: torch.Tensor = network(torch.from_numpy(input)).detach().numpy()
 
32
  seg = [(output[:, i, :] >= cutoffs[i]).astype(int) for i in range(len(cutoffs))]
33
  return input, np.stack(seg, axis=1)
34
 
 
1
+ import os
2
+ import tempfile
3
+ import traceback as tb
4
+
5
  import gradio as gr
6
  import pandas as pd
7
  import numpy as np
8
  import torch
 
 
 
9
  import matplotlib.pyplot as plt
10
+
11
+ from model import HRNetV2Wrapper
12
 
13
  # True 이면, tmp directory 에 파일 쑴재 μœ λ¬΄μ™€ 상관없이 항상 μƒˆλ‘œμš΄ 이미지 생성
14
  ALWAYS_RECREATE_IMAGE = os.getenv("ALWAYS_RECREATE_IMAGE", "False").lower() == "true"
 
25
  cutoffs = [0.001163482666015625, 0.15087890625, -0.587890625]
26
  lead_names = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
27
 
28
+ hrnetv2_wrapper = HRNetV2Wrapper()
29
+
30
 
31
  def gen_seg(subject_id):
32
  input = np.load(f"./res/ludb/ecg_np/{subject_id}.npy")
33
+ output: torch.Tensor = (
34
+ hrnetv2_wrapper.model(torch.from_numpy(input)).detach().numpy()
35
+ )
36
  seg = [(output[:, i, :] >= cutoffs[i]).astype(int) for i in range(len(cutoffs))]
37
  return input, np.stack(seg, axis=1)
38
 
model.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from res.impl.HRNetV2 import HRNetV2
2
+ import torch
3
+
4
+
5
+ class Config:
6
+ pass
7
+
8
+
9
+ class HRNetV2Wrapper:
10
+ def __init__(self):
11
+ config = Config()
12
+ config.data_len = 5000
13
+ config.kernel_size = 5
14
+ config.dilation = 1
15
+ config.num_stages = 3
16
+ config.num_blocks = 6
17
+ config.num_modules = [1, 1, 1, 4, 3]
18
+ config.use_bottleneck = [1, 0, 0, 0, 0]
19
+ config.stage1_channels = 128
20
+ config.num_channels_init = 48
21
+ config.interpolate_mode = "linear"
22
+ config.output_size = 3
23
+ self.model = HRNetV2(config)
24
+ weights = torch.load("./res/models/hrnetv2/weights.pth")
25
+ self.model.load_state_dict(weights)
26
+ self.model = self.model.to("cpu").eval()
requirements.txt CHANGED
@@ -1,8 +1,6 @@
1
  --extra-index-url https://download.pytorch.org/whl/cpu
2
-
3
- # ./res/whl/aitiautils-1.1.59-py3-none-any.whl[torch]
4
- https://huggingface.co/spaces/MedicalAI-DP/ECG_Delineation/resolve/main/res/whl/aitiautils-1.1.59-py3-none-any.whl
5
  torch==2.0.0+cpu
6
  matplotlib==3.3.0
7
- netcal==1.3.5
 
8
  gradio==5.12.0
 
1
  --extra-index-url https://download.pytorch.org/whl/cpu
 
 
 
2
  torch==2.0.0+cpu
3
  matplotlib==3.3.0
4
+ pandas==2.2.3
5
+ numpy==1.26.4
6
  gradio==5.12.0
res/impl/__pycache__/HRNetV2.cpython-311.pyc ADDED
Binary file (17.8 kB). View file
 
res/models/hrnetv2/{checkpoint.pth β†’ weights.pth} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2fd6e031d55efa48257225ebb2dce0e861fc02080830e655c61a618a227e1537
3
- size 44641506
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d4d77219df74e4e69dc2feb33322caaf1b273353183e21c60385c9141d57bec
3
+ size 22389961
res/whl/aitiautils-1.1.59-py3-none-any.whl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:99bb506c65d15b19d19691156efaa2fd15dcc78ea31913f817417c31d008100b
3
- size 14134456
 
 
 
 
utils.py DELETED
@@ -1,124 +0,0 @@
1
- def show_segmentation(ti=0, target_dir: str = "./seg"):
2
- fig = plt.figure(figsize=(15, 2 * len(Alg) * 2))
3
- fig.subplots_adjust(hspace=0, wspace=0.1)
4
-
5
- # for lead_cnt, ti in enumerate(range(12 * target_idx, 12 * (target_idx + 1))):
6
- test_df = pd.read_csv(test_path)
7
- lead_type = ast.literal_eval(test_df.iloc[ti]["lead_type"])[0]
8
- seq = test_df.iloc[ti]["seq"]
9
- object_id: str = test_df.iloc[ti]["objectid"]
10
- ecg_data_path = f"/bfai/data/ecg_data/{object_id[18:22]}/{object_id}.json"
11
- # ecg_data_path = test_df.iloc[ti]["file_path"]
12
- ecg_data = []
13
- with open(ecg_data_path) as ecg_data_file:
14
- ecg_json = json.load(ecg_data_file)
15
- ecg_data = (
16
- np.array(ecg_json["waveform"]["data"][lead_type])
17
- * ecg_json["study"]["mv_unit"]
18
- )
19
- if seq != "1/1":
20
- seq_idx, seq_range = [int(str_seq) for str_seq in seq.split("/")]
21
- use_lead_length = 5000 # (500 * 10)
22
- total_use_length = use_lead_length * seq_range
23
- front_idx = int((len(ecg_data) - total_use_length) / 2)
24
- ecg_data = ecg_data[front_idx:]
25
-
26
- for alg, alg_idx in ALG_ORDER.items():
27
- for pp_type_idx, pp_type in enumerate(["ori", "pp"]):
28
- sub_fig = fig.add_subplot(len(Alg) * 2, 1, 2 * alg_idx + pp_type_idx + 1)
29
- # sub_fig.set_title(f"{alg} - Lead {lead_type}")
30
- sub_fig.text(
31
- 0.02,
32
- 0.5,
33
- # f"{alg} {pp_type} - {lead_type}",
34
- f"{alg} {pp_type}\n{lead_type}",
35
- fontsize=9,
36
- fontweight="bold",
37
- ha="center",
38
- va="center",
39
- rotation=90,
40
- transform=sub_fig.transAxes,
41
- )
42
- sub_fig.set_xticks([])
43
- sub_fig.set_yticks([])
44
-
45
- sub_fig.plot(range(len(ecg_data)), ecg_data, color="black", linewidth=1.0)
46
- sub_fig.plot(
47
- range(len(output[alg][ti][0])),
48
- (output[alg][ti][0] >= cutoff[alg][0]).astype(int) / 2
49
- if pp_type_idx == 0
50
- else (pp_output[alg][ti][0]).astype(int) / 2,
51
- label="P",
52
- color="red",
53
- linewidth=0.7,
54
- )
55
- sub_fig.plot(
56
- range(len(output[alg][ti][1])),
57
- (output[alg][ti][1] >= cutoff[alg][1]).astype(int)
58
- if pp_type_idx == 0
59
- else (pp_output[alg][ti][1]).astype(int),
60
- label="QRS",
61
- color="green",
62
- linewidth=0.7,
63
- )
64
- sub_fig.plot(
65
- (output[alg][ti][2] >= cutoff[alg][2]).astype(int) / 2
66
- if pp_type_idx == 0
67
- else (pp_output[alg][ti][2]).astype(int) / 2,
68
- label="T",
69
- color="blue",
70
- linewidth=0.7,
71
- )
72
-
73
- sub_fig.plot(
74
- range(len(origin_seg[ti][0])),
75
- ((origin_seg[ti][0] > 0).astype(int) * (-1)) / 2,
76
- label="P Label",
77
- color="salmon",
78
- linewidth=0.7,
79
- )
80
- sub_fig.plot(
81
- range(len(origin_seg[ti][1])),
82
- ((origin_seg[ti][1] > 0).astype(int) * (-1)),
83
- label="QRS Label",
84
- color="seagreen",
85
- linewidth=0.7,
86
- )
87
- sub_fig.plot(
88
- range(len(origin_seg[ti][2])),
89
- ((origin_seg[ti][2] > 0).astype(int) * (-1)) / 2,
90
- label="T Label",
91
- color="darkslateblue",
92
- linewidth=0.7,
93
- )
94
-
95
- sub_fig.plot(
96
- range(len(origin_seg[ti][0])),
97
- ((origin_seg[ti][0] < 0).astype(int) * (-1)),
98
- label="P UnLabeled",
99
- linestyle=":",
100
- color="salmon",
101
- linewidth=0.5,
102
- )
103
- sub_fig.plot(
104
- range(len(origin_seg[ti][1])),
105
- ((origin_seg[ti][1] < 0).astype(int) * (-1)),
106
- label="QRS UnLabeled",
107
- linestyle=":",
108
- color="seagreen",
109
- linewidth=0.5,
110
- )
111
- sub_fig.plot(
112
- range(len(origin_seg[ti][2])),
113
- ((origin_seg[ti][2] < 0).astype(int) * (-1)),
114
- label="T UnLabeled",
115
- linestyle=":",
116
- color="darkslateblue",
117
- linewidth=0.5,
118
- )
119
- # sub_fig.legend()
120
- plt.savefig(
121
- f"./{target_dir}/{test_df.iloc[ti]['lead_cnt']}_{object_id}_{lead_type}_{test_df.iloc[ti]['seq'].replace('/', '_')}.png",
122
- dpi=150,
123
- )
124
- plt.close()