Spaces:
Sleeping
Sleeping
umyuu
commited on
Commit
·
da8bdb9
1
Parent(s):
271d94c
refactoring
Browse files- App.pyの処理内容をSaliencyMapクラスに分割。
- 実行ログ出力用としてreporter.pyを追加。
- src/app.py +80 -46
- src/launch.py +7 -4
- src/reporter.py +49 -0
- src/saliency.py +75 -0
src/app.py
CHANGED
@@ -6,54 +6,73 @@ import argparse
|
|
6 |
from datetime import datetime
|
7 |
import sys
|
8 |
|
9 |
-
import cv2
|
10 |
import gradio as gr
|
11 |
import numpy as np
|
12 |
|
13 |
import utils
|
14 |
-
|
|
|
|
|
15 |
PROGRAM_NAME = 'SaliencyMapDemo'
|
16 |
__version__ = utils.get_package_version()
|
|
|
17 |
|
18 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
"""
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
np.ndarray
|
29 |
-
カラーマップのJET画像
|
30 |
"""
|
31 |
-
#
|
32 |
-
|
33 |
-
|
|
|
34 |
success, saliencyMap = saliency.computeSaliency(image)
|
|
|
35 |
|
36 |
-
if success:
|
37 |
-
#
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
|
49 |
def run(args: argparse.Namespace, watch: utils.Stopwatch) -> None:
|
50 |
"""
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
コマンドライン引数
|
55 |
-
|
56 |
-
起動したスタート時間
|
57 |
"""
|
58 |
# analytics_enabled=False
|
59 |
# https://github.com/gradio-app/gradio/issues/4226
|
@@ -68,22 +87,32 @@ def run(args: argparse.Namespace, watch: utils.Stopwatch) -> None:
|
|
68 |
gr.Markdown(
|
69 |
"""
|
70 |
# Saliency Map demo.
|
71 |
-
1. inputタブで画像を選択します。
|
72 |
-
2. Submitボタンを押します。
|
73 |
-
※画像は外部送信していません。ローカルで処理が完結します。
|
74 |
-
3. 結果は、overlayタブに表示します。
|
75 |
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
submit_button = gr.Button("submit")
|
78 |
-
|
79 |
with gr.Row():
|
80 |
-
with gr.Tab("input"):
|
81 |
-
image_input = gr.Image()
|
82 |
-
with gr.Tab("overlay"):
|
83 |
-
image_overlay = gr.Image(interactive=False)
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
-
submit_button.click(
|
87 |
|
88 |
gr.Markdown(
|
89 |
f"""
|
@@ -93,7 +122,12 @@ def run(args: argparse.Namespace, watch: utils.Stopwatch) -> None:
|
|
93 |
|
94 |
demo.queue(default_concurrency_limit=5)
|
95 |
|
96 |
-
|
97 |
|
98 |
# https://www.gradio.app/docs/gradio/blocks#blocks-launch
|
99 |
-
demo.launch(
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from datetime import datetime
|
7 |
import sys
|
8 |
|
|
|
9 |
import gradio as gr
|
10 |
import numpy as np
|
11 |
|
12 |
import utils
|
13 |
+
from saliency import SaliencyMap, convertColorMap
|
14 |
+
from reporter import get_current_reporter
|
15 |
+
|
16 |
PROGRAM_NAME = 'SaliencyMapDemo'
|
17 |
__version__ = utils.get_package_version()
|
18 |
+
log = get_current_reporter()
|
19 |
|
20 |
+
def jetTab_Selected(image: np.ndarray):
|
21 |
+
#print(f"{datetime.now()}#jet")
|
22 |
+
saliency = SaliencyMap("SpectralResidual")
|
23 |
+
success, saliencyMap = saliency.computeSaliency(image)
|
24 |
+
retval = convertColorMap(image, saliencyMap, "jet")
|
25 |
+
#print(f"{datetime.now()}#jet")
|
26 |
+
|
27 |
+
return retval
|
28 |
+
|
29 |
+
def hotTab_Selected(image: np.ndarray):
|
30 |
+
#print(f"{datetime.now()}#hot")
|
31 |
+
saliency = SaliencyMap("SpectralResidual")
|
32 |
+
success, saliencyMap = saliency.computeSaliency(image)
|
33 |
+
retval = convertColorMap(image, saliencyMap, "hot")
|
34 |
+
#print(f"{datetime.now()}#hot")
|
35 |
+
|
36 |
+
return retval
|
37 |
+
|
38 |
+
def submit_Clicked(image: np.ndarray, algorithm: str):
|
39 |
"""
|
40 |
+
入力画像を元に顕著マップを計算します。
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
image: 入力画像
|
44 |
+
str: 顕著性マップのアルゴリズム
|
45 |
+
Returns:
|
46 |
+
np.ndarray: JET画像
|
47 |
+
np.ndarray: HOT画像
|
|
|
|
|
48 |
"""
|
49 |
+
log.info(f"#submit_Clicked")
|
50 |
+
watch = utils.Stopwatch.startNew()
|
51 |
+
|
52 |
+
saliency = SaliencyMap(algorithm)
|
53 |
success, saliencyMap = saliency.computeSaliency(image)
|
54 |
+
log.info(f"#SaliencyMap computeSaliency()")
|
55 |
|
56 |
+
if not success:
|
57 |
+
return image, image # エラーが発生した場合は入力画像を返します。
|
58 |
+
|
59 |
+
log.info(f"#jet")
|
60 |
+
jet = convertColorMap(image, saliencyMap, "jet")
|
61 |
+
#jet = None
|
62 |
+
log.info(f"#hot")
|
63 |
+
hot = convertColorMap(image, saliencyMap, "hot")
|
64 |
+
|
65 |
+
saliency = None
|
66 |
+
log.info(f"#submit_Clicked End{watch.stop():.3f}")
|
67 |
+
return jet, hot
|
68 |
|
69 |
def run(args: argparse.Namespace, watch: utils.Stopwatch) -> None:
|
70 |
"""
|
71 |
+
アプリの画面を作成し、Gradioサービスを起動します。
|
72 |
+
|
73 |
+
Parameters:
|
74 |
+
args: コマンドライン引数
|
75 |
+
watch: 起動したスタート時間
|
|
|
76 |
"""
|
77 |
# analytics_enabled=False
|
78 |
# https://github.com/gradio-app/gradio/issues/4226
|
|
|
87 |
gr.Markdown(
|
88 |
"""
|
89 |
# Saliency Map demo.
|
|
|
|
|
|
|
|
|
90 |
""")
|
91 |
+
with gr.Accordion("取り扱い説明書", open=False):
|
92 |
+
gr.Markdown(
|
93 |
+
"""
|
94 |
+
1. inputタブで画像を選択します。
|
95 |
+
2. Submitボタンを押します。
|
96 |
+
※画像は外部送信していません。ローカルで処理が完結します。
|
97 |
+
3. 結果は、JETタブとHOTタブに表示します。
|
98 |
+
""")
|
99 |
+
|
100 |
+
algorithmType = gr.Radio(["SpectralResidual", "FineGrained"], label="Saliency", value="SpectralResidual", interactive=True)
|
101 |
|
102 |
submit_button = gr.Button("submit")
|
103 |
+
|
104 |
with gr.Row():
|
105 |
+
with gr.Tab("input", id="input"):
|
|
|
|
|
|
|
106 |
|
107 |
+
image_input = gr.Image(sources = ["upload", "clipboard"], interactive=True)
|
108 |
+
with gr.Tab("overlay(JET)"):
|
109 |
+
image_overlay_jet = gr.Image(interactive=False)
|
110 |
+
#tab_jet.select(jetTab_Selected, inputs=[image_input], outputs=image_overlay_jet)
|
111 |
+
with gr.Tab("overlay(HOT)"):
|
112 |
+
image_overlay_hot = gr.Image(interactive=False)
|
113 |
+
#tab_hot.select(hotTab_Selected, inputs=[image_input], outputs=image_overlay_hot, api_name=False)
|
114 |
|
115 |
+
submit_button.click(submit_Clicked, inputs=[image_input, algorithmType], outputs=[image_overlay_jet, image_overlay_hot])
|
116 |
|
117 |
gr.Markdown(
|
118 |
f"""
|
|
|
122 |
|
123 |
demo.queue(default_concurrency_limit=5)
|
124 |
|
125 |
+
log.info(f"#アプリ起動完了({watch.stop():.3f}s)")
|
126 |
|
127 |
# https://www.gradio.app/docs/gradio/blocks#blocks-launch
|
128 |
+
demo.launch(
|
129 |
+
max_file_size=args.max_file_size,
|
130 |
+
server_port=args.server_port,
|
131 |
+
inbrowser=True,
|
132 |
+
share=False,
|
133 |
+
)
|
src/launch.py
CHANGED
@@ -6,15 +6,18 @@ import argparse
|
|
6 |
from datetime import datetime
|
7 |
|
8 |
from utils import get_package_version, Stopwatch
|
|
|
9 |
|
10 |
def main():
|
11 |
"""
|
12 |
エントリーポイント
|
13 |
-
コマンドライン引数の解析を行います
|
|
|
14 |
"""
|
15 |
-
|
|
|
16 |
watch = Stopwatch.startNew()
|
17 |
-
|
18 |
import app
|
19 |
|
20 |
parser = argparse.ArgumentParser(prog=app.PROGRAM_NAME, description="SaliencyMapDemo")
|
@@ -25,4 +28,4 @@ def main():
|
|
25 |
app.run(parser.parse_args(), watch)
|
26 |
|
27 |
if __name__ == "__main__":
|
28 |
-
|
|
|
6 |
from datetime import datetime
|
7 |
|
8 |
from utils import get_package_version, Stopwatch
|
9 |
+
from reporter import get_current_reporter
|
10 |
|
11 |
def main():
|
12 |
"""
|
13 |
エントリーポイント
|
14 |
+
1, コマンドライン引数の解析を行います
|
15 |
+
2, アプリを起動します。
|
16 |
"""
|
17 |
+
log = get_current_reporter()
|
18 |
+
log.info("#アプリ起動中")
|
19 |
watch = Stopwatch.startNew()
|
20 |
+
|
21 |
import app
|
22 |
|
23 |
parser = argparse.ArgumentParser(prog=app.PROGRAM_NAME, description="SaliencyMapDemo")
|
|
|
28 |
app.run(parser.parse_args(), watch)
|
29 |
|
30 |
if __name__ == "__main__":
|
31 |
+
main()
|
src/reporter.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Reporter
|
4 |
+
ログハンドラーが重複登録されるのを防ぐために1箇所で生成してログハンドラーを返します。
|
5 |
+
Example:
|
6 |
+
from reporter import get_current_reporter
|
7 |
+
|
8 |
+
logger = get_current_reporter()
|
9 |
+
logger.info("message");
|
10 |
+
"""
|
11 |
+
from logging import Logger, getLogger, Formatter, StreamHandler
|
12 |
+
from logging import DEBUG
|
13 |
+
|
14 |
+
_reporters = []
|
15 |
+
|
16 |
+
def get_current_reporter() -> Logger:
|
17 |
+
return _reporters[-1]
|
18 |
+
|
19 |
+
def __make_reporter(name: str='SaliencyMapDemo') -> None:
|
20 |
+
"""
|
21 |
+
ログハンドラーを生成します。
|
22 |
+
@see https://docs.python.jp/3/howto/logging-cookbook.html#logging-to-a-single-file-from-multiple-processes
|
23 |
+
|
24 |
+
Parameters:
|
25 |
+
name: アプリ名
|
26 |
+
"""
|
27 |
+
handler = StreamHandler() # コンソールに出力します。
|
28 |
+
formatter = Formatter('%(asctime)s%(message)s')
|
29 |
+
handler.setFormatter(formatter)
|
30 |
+
handler.setLevel(DEBUG)
|
31 |
+
|
32 |
+
logger = getLogger(name)
|
33 |
+
logger.setLevel(DEBUG)
|
34 |
+
logger.addHandler(handler)
|
35 |
+
_reporters.append(logger)
|
36 |
+
|
37 |
+
__make_reporter()
|
38 |
+
|
39 |
+
def main():
|
40 |
+
"""
|
41 |
+
Entry Point
|
42 |
+
"""
|
43 |
+
assert len(_reporters) == 1
|
44 |
+
|
45 |
+
logger = get_current_reporter()
|
46 |
+
logger.debug("main")
|
47 |
+
|
48 |
+
if __name__ == "__main__":
|
49 |
+
main()
|
src/saliency.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
from typing import Literal
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
|
7 |
+
|
8 |
+
class SaliencyMap:
|
9 |
+
"""
|
10 |
+
SaliencyMap 顕著性マップを計算するクラスです。
|
11 |
+
Example:
|
12 |
+
from lib.saliency import SaliencyMap
|
13 |
+
|
14 |
+
saliency = SaliencyMap("SpectralResidual")
|
15 |
+
success, saliencyMap = saliency.computeSaliency(image)
|
16 |
+
"""
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
algorithm: Literal["SpectralResidual", "FineGrained"] = "SpectralResidual",
|
20 |
+
):
|
21 |
+
self.algorithm = algorithm
|
22 |
+
# OpenCVのsaliencyを作成します。
|
23 |
+
if algorithm == "SpectralResidual":
|
24 |
+
self.saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
|
25 |
+
else:
|
26 |
+
self.saliency = cv2.saliency.StaticSaliencyFineGrained_create()
|
27 |
+
|
28 |
+
|
29 |
+
def computeSaliency(self, image: np.ndarray):
|
30 |
+
"""
|
31 |
+
入力画像から顕著性マップを作成します。
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
image: 入力画像
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
bool:
|
38 |
+
true: SaliencyMap computed, false:NG
|
39 |
+
np.ndarray: 顕著性マップ
|
40 |
+
"""
|
41 |
+
# 画像の顕著性を計算します。
|
42 |
+
return self.saliency.computeSaliency(image)
|
43 |
+
|
44 |
+
def convertColorMap(
|
45 |
+
image: np.ndarray,
|
46 |
+
saliencyMap: np.ndarray,
|
47 |
+
colormap_name: Literal["jet", "hot"] = "jet"):
|
48 |
+
"""
|
49 |
+
顕著性マップをカラーマップに変換後に、入力画像に重ね合わせします。
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
image: 入力画像
|
53 |
+
saliencyMap: 顕著性マップ
|
54 |
+
colormap_name: カラーマップの種類
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
np.ndarray: 重ね合わせた画像(RGBA形式)
|
58 |
+
"""
|
59 |
+
#image = (image * 255).astype("uint8")
|
60 |
+
#
|
61 |
+
#return cv2.applyColorMap(image, cv2.COLORMAP_JET)
|
62 |
+
|
63 |
+
|
64 |
+
# 顕著性マップをカラーマップに変換
|
65 |
+
saliencyMap = (saliencyMap * 255).astype("uint8")
|
66 |
+
if colormap_name == "jet":
|
67 |
+
saliencyMap = cv2.applyColorMap(saliencyMap, cv2.COLORMAP_JET)
|
68 |
+
else:
|
69 |
+
saliencyMap = cv2.applyColorMap(saliencyMap, cv2.COLORMAP_HOT)
|
70 |
+
#return saliencyMap
|
71 |
+
# 入力画像とカラーマップを重ね合わせ
|
72 |
+
overlay = cv2.addWeighted(image, 0.5, saliencyMap, 0.5, 0)
|
73 |
+
#return overlay
|
74 |
+
|
75 |
+
return cv2.cvtColor(overlay, cv2.COLOR_BGR2RGBA)
|