fkonovalenko commited on
Commit
4e11364
β€’
1 Parent(s): 0b91798

load model

Browse files
Files changed (5) hide show
  1. README.md +2 -2
  2. app.py +106 -0
  3. pretrained_yolo.pt +3 -0
  4. requirements.txt +8 -0
  5. utils.py +100 -0
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  title: Hertz
3
- emoji: 🐒
4
- colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.28.2
 
1
  ---
2
  title: Hertz
3
+ emoji: πŸ™
4
+ colorFrom: green
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.28.2
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import shutil
4
+ from utils import segment, plotter, writer
5
+
6
+
7
+ class GlobalState:
8
+ """
9
+ Class to store global variables
10
+ """
11
+ heart_area = [.54, .5, -.14, .16]
12
+ smooth_factor = 3
13
+ video_file_path = os.path.join(os.path.dirname(__file__), 'videos/')
14
+ input_path = None
15
+ line = None
16
+ result_file_path = os.path.join(os.path.dirname(__file__), 'result/result.mp4')
17
+ result_folder = os.path.join(os.path.dirname(__file__), 'result/')
18
+ yolo_path = os.path.join(os.path.dirname(__file__), 'pretrained_yolo.pt')
19
+
20
+
21
+ def upload_video_file(fid):
22
+ """
23
+ uploads and save video to workdir
24
+ """
25
+ raw_path = os.path.join(GlobalState.video_file_path, os.path.basename(fid.name))
26
+ shutil.move(fid.name, raw_path)
27
+ GlobalState.input_path = raw_path
28
+ gr.Info("Video uploaded")
29
+ return gr.update('Run!')
30
+
31
+
32
+ def processing(sl):
33
+ GlobalState.smooth_factor = int(sl)
34
+ graph, frames, message = segment(GlobalState.input_path, GlobalState.yolo_path, start=0, fstep=1,
35
+ crop=GlobalState.heart_area)
36
+ gr.Info(message)
37
+
38
+ if message == 'Video processing succeeded':
39
+ GlobalState.line = graph
40
+ writer(GlobalState.result_file_path, frames)
41
+ gr.Info('Processed video saved!')
42
+ return gr.update(visible=True), gr.update(visible=True)
43
+ else:
44
+ return gr.update(visible=False), gr.update(visible=False)
45
+
46
+
47
+ def plot_graph(sl):
48
+ sl = int(sl)
49
+ result, text = plotter(GlobalState.line, sl)
50
+ return result, gr.update(value=text)
51
+
52
+ def show_video(btn):
53
+ return gr.update(label="Segmented Echo", value=GlobalState.result_file_path)
54
+
55
+ def main():
56
+
57
+ shutil.rmtree(os.path.join(os.path.dirname(__file__), 'videos/'), ignore_errors=True)
58
+ shutil.rmtree(os.path.join(os.path.dirname(__file__), 'result/'), ignore_errors=True)
59
+ os.mkdir(os.path.join(os.path.dirname(__file__), 'videos/'))
60
+ os.mkdir(os.path.join(os.path.dirname(__file__), 'result/'))
61
+
62
+ with gr.Blocks() as demo:
63
+ with gr.Tab("Load"):
64
+ with gr.Row():
65
+ gr.Markdown(
66
+ """
67
+ # Load video file πŸ«€
68
+ # Then press **Run!**
69
+ # Have fun:)
70
+ """)
71
+ with gr.Row():
72
+ with gr.Column():
73
+ with gr.Row():
74
+ video_upload = gr.File(label="Upload heart Echo", file_types=["video"], file_count="single")
75
+ with gr.Row():
76
+ process_button = gr.Button("Run!")
77
+ with gr.Row():
78
+ player = gr.Video(label="Segmented Echo", value=None, format='mp4')
79
+
80
+ with gr.Column():
81
+ with gr.Row():
82
+ smoother = gr.Slider(1, 50, 5, 1, label="Rolling Mean Window")
83
+ with gr.Row():
84
+ messenger = gr.Textbox(label='Ejection Fracture', value=None)
85
+ with gr.Row():
86
+ plot = gr.LinePlot(x="Frame", y="Left ventricle visible area, px*px",
87
+ overlay_point=False,
88
+ tooltip=["Frame", "Left ventricle visible area, px*px"],
89
+ width=500, height=300)
90
+ with gr.Row():
91
+ show_graph = gr.Button('Plot', visible=False)
92
+ with gr.Row():
93
+ show_button = gr.Button("Show result!")
94
+
95
+ video_upload.upload(upload_video_file, video_upload, outputs=[process_button], show_progress='full')
96
+ process_button.click(processing, inputs=[smoother], outputs=[show_graph, show_button], show_progress='full')
97
+ show_graph.click(plot_graph, inputs=[smoother], outputs=[plot, messenger])
98
+ show_button.click(show_video, outputs=[player])
99
+ player.change(show_video, outputs=[player])
100
+
101
+
102
+ demo.launch(allowed_paths=[GlobalState.result_folder])
103
+
104
+
105
+ if __name__ == "__main__":
106
+ main()
pretrained_yolo.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3d873d2d48c135652359220459043186b8cb276a18ed2f3c39a27af7d51e02c
3
+ size 54836747
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ opencv-python
2
+ tqdm
3
+ matplotlib
4
+ ultralytics
5
+ numpy
6
+ pandas
7
+ gradio
8
+ ffmpeg-python
utils.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ultralytics
2
+ from ultralytics import YOLO
3
+ import torch
4
+ import numpy as np
5
+ import pandas as pd
6
+ from tqdm import tqdm
7
+ from matplotlib import pyplot as plt
8
+ import cv2
9
+ import warnings
10
+ import ffmpeg
11
+
12
+ warnings.filterwarnings("ignore")
13
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
14
+
15
+
16
+
17
+
18
+ def segment(video_path: str, model_path: str, start: int, fstep: int, crop: list) -> tuple:
19
+ """
20
+ runs YOLO segmentation model and calculates the LV Area
21
+ """
22
+ model = YOLO(model_path)#.to(device)
23
+ cap = cv2.VideoCapture(video_path)
24
+ stop = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
25
+ lv_area = []
26
+ frames = []
27
+ message = 'Video processing succeeded'
28
+ for fr in tqdm(range(start, stop, fstep), desc=f'processing ECHO'):
29
+ cap.set(cv2.CAP_PROP_POS_FRAMES, fr)
30
+ _, frame = cap.read()
31
+ new_w = int(frame.shape[1] * crop[1])
32
+ new_h = int(frame.shape[0] * crop[0])
33
+ new_left = int(frame.shape[1] / 2 + crop[3] * new_w - new_w / 2)
34
+ new_top = int(frame.shape[0] / 2 + crop[2] * new_h - new_h / 2)
35
+ frame = frame[new_top:new_top + new_h, new_left:new_left + new_w]
36
+ frame_m = frame
37
+ inputs = frame #torch.Tensor(frame).to(device)
38
+ with torch.no_grad():
39
+ result = model(inputs, verbose=False)
40
+ result = result#.to('cpu')
41
+ classes = result[0].names
42
+ if len(classes) == 0:
43
+ pass
44
+ overlay = frame.copy()
45
+ color_list = [(255, 0, 0),
46
+ (255, 255, 0),
47
+ (255, 0, 255),
48
+ (0, 255, 0),
49
+ (0, 0, 255),
50
+ (128, 128, 128)]
51
+ for i, res in enumerate(result[0]):
52
+ bx = res.boxes
53
+ m = res.masks.xy
54
+ label = int(bx.cls.squeeze().cpu())
55
+ if label == 1:
56
+ lv_area.append(cv2.contourArea(m[0]))
57
+ box = list(map(int, bx.xyxy.squeeze().cpu().tolist()))
58
+ cv2.rectangle(overlay, (box[0], box[1]), (box[2], box[3]), (36, 255, 12), 2)
59
+ cv2.putText(overlay, classes[label], (box[0], box[1] - 5), cv2.FONT_HERSHEY_TRIPLEX, 1, (0, 0, 255), 2)
60
+ cv2.fillPoly(overlay, pts=np.int32([m]), color=color_list[i % 6])
61
+ alpha = 0.4
62
+ frame_m = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)
63
+ frames.append(frame_m)
64
+ if len(lv_area) == 0:
65
+ message = 'Video processing failed'
66
+ return lv_area, frames, message
67
+
68
+
69
+ def plotter(lv_data: list, window: int) -> tuple:
70
+ """
71
+ plots the rolling mean graph for LV area.
72
+ calculates the average ejection fracture
73
+ """
74
+ lv_rolling = pd.Series(lv_data).rolling(window=window).mean().dropna()
75
+ ef = (max(lv_rolling) - min(lv_rolling)) / max(lv_rolling)
76
+ dataframe = pd.DataFrame({
77
+ 'Frame': np.array(range(len(lv_rolling))),
78
+ 'Left ventricle visible area, px*px': lv_rolling.values
79
+ }).astype('int32')
80
+ txt = f'Ejection fraction - {ef:.1%}'
81
+ return dataframe, txt
82
+
83
+
84
+ def writer(fn, images, framerate=25, vcodec='libx264'):
85
+ if not isinstance(images, np.ndarray):
86
+ images = np.asarray(images)
87
+ n, height, width, channels = images.shape
88
+ process = (
89
+ ffmpeg
90
+ .input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width, height))
91
+ .output(fn, pix_fmt='yuv420p', vcodec=vcodec, r=framerate)
92
+ .overwrite_output()
93
+ .run_async(pipe_stdin=True)
94
+ )
95
+ for frame in images:
96
+ process.stdin.write(
97
+ frame.astype(np.uint8).tobytes()
98
+ )
99
+ process.stdin.close()
100
+ process.wait()