umyuu commited on
Commit
da8bdb9
·
1 Parent(s): 271d94c

refactoring

Browse files

- App.pyの処理内容をSaliencyMapクラスに分割。
- 実行ログ出力用としてreporter.pyを追加。

Files changed (4) hide show
  1. src/app.py +80 -46
  2. src/launch.py +7 -4
  3. src/reporter.py +49 -0
  4. src/saliency.py +75 -0
src/app.py CHANGED
@@ -6,54 +6,73 @@ import argparse
6
  from datetime import datetime
7
  import sys
8
 
9
- import cv2
10
  import gradio as gr
11
  import numpy as np
12
 
13
  import utils
14
-
 
 
15
  PROGRAM_NAME = 'SaliencyMapDemo'
16
  __version__ = utils.get_package_version()
 
17
 
18
- def compute_saliency(image: np.ndarray):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  """
20
- 入力画像から顕著性マップを作成しJET画像を返します。
21
- Parameters
22
- ----------
23
- param1 : np.ndarray
24
- 入力画像
25
-
26
- Returns
27
- -------
28
- np.ndarray
29
- カラーマップのJET画像
30
  """
31
- # OpenCVのsaliencyを作成
32
- saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
33
- # 画像の顕著性を計算
 
34
  success, saliencyMap = saliency.computeSaliency(image)
 
35
 
36
- if success:
37
- # 顕著性マップをカラーマップに変換
38
- saliencyMap = (saliencyMap * 255).astype("uint8")
39
- saliencyMap = cv2.applyColorMap(saliencyMap, cv2.COLORMAP_JET)
40
-
41
- #overlay = saliencyMap
42
- # 元の画像とカラーマップを重ね合わせ
43
- overlay = cv2.addWeighted(image, 0.5, saliencyMap, 0.5, 0)
44
-
45
- return overlay
46
- else:
47
- return image # エラーが発生した場合は元の画像を返す
48
 
49
  def run(args: argparse.Namespace, watch: utils.Stopwatch) -> None:
50
  """
51
- アプリの画面を作成し、Gradioサービスを起動します。
52
- ----------
53
- param1 : argparse.Namespace
54
- コマンドライン引数
55
- param2 : utils.Stopwatch
56
- 起動したスタート時間
57
  """
58
  # analytics_enabled=False
59
  # https://github.com/gradio-app/gradio/issues/4226
@@ -68,22 +87,32 @@ def run(args: argparse.Namespace, watch: utils.Stopwatch) -> None:
68
  gr.Markdown(
69
  """
70
  # Saliency Map demo.
71
- 1. inputタブで画像を選択します。
72
- 2. Submitボタンを押します。
73
- ※画像は外部送信していません。ローカルで処理が完結します。
74
- 3. 結果は、overlayタブに表示します。
75
  """)
 
 
 
 
 
 
 
 
 
 
76
 
77
  submit_button = gr.Button("submit")
78
-
79
  with gr.Row():
80
- with gr.Tab("input"):
81
- image_input = gr.Image()
82
- with gr.Tab("overlay"):
83
- image_overlay = gr.Image(interactive=False)
84
 
 
 
 
 
 
 
 
85
 
86
- submit_button.click(compute_saliency, inputs=image_input, outputs=image_overlay)
87
 
88
  gr.Markdown(
89
  f"""
@@ -93,7 +122,12 @@ def run(args: argparse.Namespace, watch: utils.Stopwatch) -> None:
93
 
94
  demo.queue(default_concurrency_limit=5)
95
 
96
- print(f"{datetime.now()}:アプリ起動完了({watch.stop():.3f}s)")
97
 
98
  # https://www.gradio.app/docs/gradio/blocks#blocks-launch
99
- demo.launch(max_file_size=args.max_file_size, server_port=args.server_port, inbrowser=True, share=False)
 
 
 
 
 
 
6
  from datetime import datetime
7
  import sys
8
 
 
9
  import gradio as gr
10
  import numpy as np
11
 
12
  import utils
13
+ from saliency import SaliencyMap, convertColorMap
14
+ from reporter import get_current_reporter
15
+
16
  PROGRAM_NAME = 'SaliencyMapDemo'
17
  __version__ = utils.get_package_version()
18
+ log = get_current_reporter()
19
 
20
+ def jetTab_Selected(image: np.ndarray):
21
+ #print(f"{datetime.now()}#jet")
22
+ saliency = SaliencyMap("SpectralResidual")
23
+ success, saliencyMap = saliency.computeSaliency(image)
24
+ retval = convertColorMap(image, saliencyMap, "jet")
25
+ #print(f"{datetime.now()}#jet")
26
+
27
+ return retval
28
+
29
+ def hotTab_Selected(image: np.ndarray):
30
+ #print(f"{datetime.now()}#hot")
31
+ saliency = SaliencyMap("SpectralResidual")
32
+ success, saliencyMap = saliency.computeSaliency(image)
33
+ retval = convertColorMap(image, saliencyMap, "hot")
34
+ #print(f"{datetime.now()}#hot")
35
+
36
+ return retval
37
+
38
+ def submit_Clicked(image: np.ndarray, algorithm: str):
39
  """
40
+ 入力画像を元に顕著マップを計算します。
41
+
42
+ Parameters:
43
+ image: 入力画像
44
+ str: 顕著性マップのアルゴリズム
45
+ Returns:
46
+ np.ndarray: JET画像
47
+ np.ndarray: HOT画像
 
 
48
  """
49
+ log.info(f"#submit_Clicked")
50
+ watch = utils.Stopwatch.startNew()
51
+
52
+ saliency = SaliencyMap(algorithm)
53
  success, saliencyMap = saliency.computeSaliency(image)
54
+ log.info(f"#SaliencyMap computeSaliency()")
55
 
56
+ if not success:
57
+ return image, image # エラーが発生した場合は入力画像を返します。
58
+
59
+ log.info(f"#jet")
60
+ jet = convertColorMap(image, saliencyMap, "jet")
61
+ #jet = None
62
+ log.info(f"#hot")
63
+ hot = convertColorMap(image, saliencyMap, "hot")
64
+
65
+ saliency = None
66
+ log.info(f"#submit_Clicked End{watch.stop():.3f}")
67
+ return jet, hot
68
 
69
  def run(args: argparse.Namespace, watch: utils.Stopwatch) -> None:
70
  """
71
+ アプリの画面を作成し、Gradioサービスを起動します。
72
+
73
+ Parameters:
74
+ args: コマンドライン引数
75
+ watch: 起動したスタート時間
 
76
  """
77
  # analytics_enabled=False
78
  # https://github.com/gradio-app/gradio/issues/4226
 
87
  gr.Markdown(
88
  """
89
  # Saliency Map demo.
 
 
 
 
90
  """)
91
+ with gr.Accordion("取り扱い説明書", open=False):
92
+ gr.Markdown(
93
+ """
94
+ 1. inputタブで画像を選択します。
95
+ 2. Submitボタンを押します。
96
+ ※画像は外部送信していません。ローカルで処理が完結します。
97
+ 3. 結果は、JETタブとHOTタブに表示します。
98
+ """)
99
+
100
+ algorithmType = gr.Radio(["SpectralResidual", "FineGrained"], label="Saliency", value="SpectralResidual", interactive=True)
101
 
102
  submit_button = gr.Button("submit")
103
+
104
  with gr.Row():
105
+ with gr.Tab("input", id="input"):
 
 
 
106
 
107
+ image_input = gr.Image(sources = ["upload", "clipboard"], interactive=True)
108
+ with gr.Tab("overlay(JET)"):
109
+ image_overlay_jet = gr.Image(interactive=False)
110
+ #tab_jet.select(jetTab_Selected, inputs=[image_input], outputs=image_overlay_jet)
111
+ with gr.Tab("overlay(HOT)"):
112
+ image_overlay_hot = gr.Image(interactive=False)
113
+ #tab_hot.select(hotTab_Selected, inputs=[image_input], outputs=image_overlay_hot, api_name=False)
114
 
115
+ submit_button.click(submit_Clicked, inputs=[image_input, algorithmType], outputs=[image_overlay_jet, image_overlay_hot])
116
 
117
  gr.Markdown(
118
  f"""
 
122
 
123
  demo.queue(default_concurrency_limit=5)
124
 
125
+ log.info(f"#アプリ起動完了({watch.stop():.3f}s)")
126
 
127
  # https://www.gradio.app/docs/gradio/blocks#blocks-launch
128
+ demo.launch(
129
+ max_file_size=args.max_file_size,
130
+ server_port=args.server_port,
131
+ inbrowser=True,
132
+ share=False,
133
+ )
src/launch.py CHANGED
@@ -6,15 +6,18 @@ import argparse
6
  from datetime import datetime
7
 
8
  from utils import get_package_version, Stopwatch
 
9
 
10
  def main():
11
  """
12
  エントリーポイント
13
- コマンドライン引数の解析を行います
 
14
  """
15
- print(f"{datetime.now()}:アプリ起動中")
 
16
  watch = Stopwatch.startNew()
17
-
18
  import app
19
 
20
  parser = argparse.ArgumentParser(prog=app.PROGRAM_NAME, description="SaliencyMapDemo")
@@ -25,4 +28,4 @@ def main():
25
  app.run(parser.parse_args(), watch)
26
 
27
  if __name__ == "__main__":
28
- main()
 
6
  from datetime import datetime
7
 
8
  from utils import get_package_version, Stopwatch
9
+ from reporter import get_current_reporter
10
 
11
  def main():
12
  """
13
  エントリーポイント
14
+ 1, コマンドライン引数の解析を行います
15
+ 2, アプリを起動します。
16
  """
17
+ log = get_current_reporter()
18
+ log.info("#アプリ起動中")
19
  watch = Stopwatch.startNew()
20
+
21
  import app
22
 
23
  parser = argparse.ArgumentParser(prog=app.PROGRAM_NAME, description="SaliencyMapDemo")
 
28
  app.run(parser.parse_args(), watch)
29
 
30
  if __name__ == "__main__":
31
+ main()
src/reporter.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Reporter
4
+ ログハンドラーが重複登録されるのを防ぐために1箇所で生成してログハンドラーを返します。
5
+ Example:
6
+ from reporter import get_current_reporter
7
+
8
+ logger = get_current_reporter()
9
+ logger.info("message");
10
+ """
11
+ from logging import Logger, getLogger, Formatter, StreamHandler
12
+ from logging import DEBUG
13
+
14
+ _reporters = []
15
+
16
+ def get_current_reporter() -> Logger:
17
+ return _reporters[-1]
18
+
19
+ def __make_reporter(name: str='SaliencyMapDemo') -> None:
20
+ """
21
+ ログハンドラーを生成します。
22
+ @see https://docs.python.jp/3/howto/logging-cookbook.html#logging-to-a-single-file-from-multiple-processes
23
+
24
+ Parameters:
25
+ name: アプリ名
26
+ """
27
+ handler = StreamHandler() # コンソールに出力します。
28
+ formatter = Formatter('%(asctime)s%(message)s')
29
+ handler.setFormatter(formatter)
30
+ handler.setLevel(DEBUG)
31
+
32
+ logger = getLogger(name)
33
+ logger.setLevel(DEBUG)
34
+ logger.addHandler(handler)
35
+ _reporters.append(logger)
36
+
37
+ __make_reporter()
38
+
39
+ def main():
40
+ """
41
+ Entry Point
42
+ """
43
+ assert len(_reporters) == 1
44
+
45
+ logger = get_current_reporter()
46
+ logger.debug("main")
47
+
48
+ if __name__ == "__main__":
49
+ main()
src/saliency.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ from typing import Literal
3
+
4
+ import numpy as np
5
+ import cv2
6
+
7
+
8
+ class SaliencyMap:
9
+ """
10
+ SaliencyMap 顕著性マップを計算するクラスです。
11
+ Example:
12
+ from lib.saliency import SaliencyMap
13
+
14
+ saliency = SaliencyMap("SpectralResidual")
15
+ success, saliencyMap = saliency.computeSaliency(image)
16
+ """
17
+ def __init__(
18
+ self,
19
+ algorithm: Literal["SpectralResidual", "FineGrained"] = "SpectralResidual",
20
+ ):
21
+ self.algorithm = algorithm
22
+ # OpenCVのsaliencyを作成します。
23
+ if algorithm == "SpectralResidual":
24
+ self.saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
25
+ else:
26
+ self.saliency = cv2.saliency.StaticSaliencyFineGrained_create()
27
+
28
+
29
+ def computeSaliency(self, image: np.ndarray):
30
+ """
31
+ 入力画像から顕著性マップを作成します。
32
+
33
+ Parameters:
34
+ image: 入力画像
35
+
36
+ Returns:
37
+ bool:
38
+ true: SaliencyMap computed, false:NG
39
+ np.ndarray: 顕著性マップ
40
+ """
41
+ # 画像の顕著性を計算します。
42
+ return self.saliency.computeSaliency(image)
43
+
44
+ def convertColorMap(
45
+ image: np.ndarray,
46
+ saliencyMap: np.ndarray,
47
+ colormap_name: Literal["jet", "hot"] = "jet"):
48
+ """
49
+ 顕著性マップをカラーマップに変換後に、入力画像に重ね合わせします。
50
+
51
+ Parameters:
52
+ image: 入力画像
53
+ saliencyMap: 顕著性マップ
54
+ colormap_name: カラーマップの種類
55
+
56
+ Returns:
57
+ np.ndarray: 重ね合わせた画像(RGBA形式)
58
+ """
59
+ #image = (image * 255).astype("uint8")
60
+ #
61
+ #return cv2.applyColorMap(image, cv2.COLORMAP_JET)
62
+
63
+
64
+ # 顕著性マップをカラーマップに変換
65
+ saliencyMap = (saliencyMap * 255).astype("uint8")
66
+ if colormap_name == "jet":
67
+ saliencyMap = cv2.applyColorMap(saliencyMap, cv2.COLORMAP_JET)
68
+ else:
69
+ saliencyMap = cv2.applyColorMap(saliencyMap, cv2.COLORMAP_HOT)
70
+ #return saliencyMap
71
+ # 入力画像とカラーマップを重ね合わせ
72
+ overlay = cv2.addWeighted(image, 0.5, saliencyMap, 0.5, 0)
73
+ #return overlay
74
+
75
+ return cv2.cvtColor(overlay, cv2.COLOR_BGR2RGBA)