narugo1992 commited on
Commit
e155f7f
1 Parent(s): 21e723b

dev(narugo): simplify code

Browse files
Files changed (7) hide show
  1. aicheck.py +0 -42
  2. app.py +12 -89
  3. base.py +65 -0
  4. chsex.py +0 -42
  5. cls.py +0 -41
  6. monochrome.py +0 -42
  7. rating.py +0 -42
aicheck.py DELETED
@@ -1,42 +0,0 @@
1
- import json
2
- import os
3
- from functools import lru_cache
4
- from typing import Mapping, List
5
-
6
- from huggingface_hub import HfFileSystem
7
- from huggingface_hub import hf_hub_download
8
- from imgutils.data import ImageTyping, load_image
9
- from natsort import natsorted
10
-
11
- from onnx_ import _open_onnx_model
12
- from preprocess import _img_encode
13
-
14
- hfs = HfFileSystem()
15
-
16
- _REPO = 'deepghs/anime_ai_check'
17
- _AICHECK_MODELS = natsorted([
18
- os.path.dirname(os.path.relpath(file, _REPO))
19
- for file in hfs.glob(f'{_REPO}/*/model.onnx')
20
- ])
21
- _DEFAULT_AICHECK_MODEL = 'mobilenetv3_sce_dist'
22
-
23
-
24
- @lru_cache()
25
- def _open_anime_aicheck_model(model_name):
26
- return _open_onnx_model(hf_hub_download(_REPO, f'{model_name}/model.onnx'))
27
-
28
-
29
- @lru_cache()
30
- def _get_tags(model_name) -> List[str]:
31
- with open(hf_hub_download(_REPO, f'{model_name}/meta.json'), 'r') as f:
32
- return json.load(f)['labels']
33
-
34
-
35
- def _gr_aicheck(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
36
- image = load_image(image, mode='RGB')
37
- input_ = _img_encode(image, size=(size, size))[None, ...]
38
- output, = _open_anime_aicheck_model(model_name).run(['output'], {'input': input_})
39
-
40
- labels = _get_tags(model_name)
41
- values = dict(zip(labels, map(lambda x: x.item(), output[0])))
42
- return values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -2,98 +2,21 @@ import os
2
 
3
  import gradio as gr
4
 
5
- from aicheck import _gr_aicheck, _DEFAULT_AICHECK_MODEL, _AICHECK_MODELS
6
- from chsex import _gr_chsex, _CHSEX_MODELS, _DEFAULT_CHSEX_MODEL
7
- from cls import _CLS_MODELS, _DEFAULT_CLS_MODEL, _gr_classification
8
- from monochrome import _gr_monochrome, _DEFAULT_MONO_MODEL, _MONO_MODELS
9
- from rating import _RATING_MODELS, _DEFAULT_RATING_MODEL, _gr_rating
 
 
 
 
 
10
 
11
  if __name__ == '__main__':
12
  with gr.Blocks() as demo:
13
  with gr.Tabs():
14
- with gr.Tab('Classification'):
15
- with gr.Row():
16
- with gr.Column():
17
- gr_cls_input_image = gr.Image(type='pil', label='Original Image')
18
- gr_cls_model = gr.Dropdown(_CLS_MODELS, value=_DEFAULT_CLS_MODEL, label='Model')
19
- gr_cls_infer_size = gr.Slider(224, 640, value=384, step=32, label='Infer Size')
20
- gr_cls_submit = gr.Button(value='Submit', variant='primary')
21
-
22
- with gr.Column():
23
- gr_cls_output = gr.Label(label='Classes')
24
-
25
- gr_cls_submit.click(
26
- _gr_classification,
27
- inputs=[gr_cls_input_image, gr_cls_model, gr_cls_infer_size],
28
- outputs=[gr_cls_output],
29
- )
30
-
31
- with gr.Tab('Monochrome'):
32
- with gr.Row():
33
- with gr.Column():
34
- gr_mono_input_image = gr.Image(type='pil', label='Original Image')
35
- gr_mono_model = gr.Dropdown(_MONO_MODELS, value=_DEFAULT_MONO_MODEL, label='Model')
36
- gr_mono_infer_size = gr.Slider(224, 640, value=384, step=32, label='Infer Size')
37
- gr_mono_submit = gr.Button(value='Submit', variant='primary')
38
-
39
- with gr.Column():
40
- gr_mono_output = gr.Label(label='Classes')
41
-
42
- gr_mono_submit.click(
43
- _gr_monochrome,
44
- inputs=[gr_mono_input_image, gr_mono_model, gr_mono_infer_size],
45
- outputs=[gr_mono_output],
46
- )
47
-
48
- with gr.Tab('AI Check'):
49
- with gr.Row():
50
- with gr.Column():
51
- gr_aicheck_input_image = gr.Image(type='pil', label='Original Image')
52
- gr_aicheck_model = gr.Dropdown(_AICHECK_MODELS, value=_DEFAULT_AICHECK_MODEL, label='Model')
53
- gr_aicheck_infer_size = gr.Slider(224, 640, value=384, step=32, label='Infer Size')
54
- gr_aicheck_submit = gr.Button(value='Submit', variant='primary')
55
-
56
- with gr.Column():
57
- gr_aicheck_output = gr.Label(label='Classes')
58
-
59
- gr_aicheck_submit.click(
60
- _gr_aicheck,
61
- inputs=[gr_aicheck_input_image, gr_aicheck_model, gr_aicheck_infer_size],
62
- outputs=[gr_aicheck_output],
63
- )
64
-
65
- with gr.Tab('Rating'):
66
- with gr.Row():
67
- with gr.Column():
68
- gr_rating_input_image = gr.Image(type='pil', label='Original Image')
69
- gr_rating_model = gr.Dropdown(_RATING_MODELS, value=_DEFAULT_RATING_MODEL, label='Model')
70
- gr_rating_infer_size = gr.Slider(224, 640, value=384, step=32, label='Infer Size')
71
- gr_rating_submit = gr.Button(value='Submit', variant='primary')
72
-
73
- with gr.Column():
74
- gr_rating_output = gr.Label(label='Classes')
75
-
76
- gr_rating_submit.click(
77
- _gr_rating,
78
- inputs=[gr_rating_input_image, gr_rating_model, gr_rating_infer_size],
79
- outputs=[gr_rating_output],
80
- )
81
-
82
- with gr.Tab('Character Sex'):
83
- with gr.Row():
84
- with gr.Column():
85
- gr_chsex_input_image = gr.Image(type='pil', label='Original Image')
86
- gr_chsex_model = gr.Dropdown(_CHSEX_MODELS, value=_DEFAULT_CHSEX_MODEL, label='Model')
87
- gr_chsex_infer_size = gr.Slider(224, 640, value=384, step=32, label='Infer Size')
88
- gr_chsex_submit = gr.Button(value='Submit', variant='primary')
89
-
90
- with gr.Column():
91
- gr_chsex_output = gr.Label(label='Classes')
92
-
93
- gr_chsex_submit.click(
94
- _gr_chsex,
95
- inputs=[gr_chsex_input_image, gr_chsex_model, gr_chsex_infer_size],
96
- outputs=[gr_chsex_output],
97
- )
98
 
99
  demo.queue(os.cpu_count()).launch()
 
2
 
3
  import gradio as gr
4
 
5
+ from base import Classification
6
+
7
+ apps = [
8
+ Classification('Classification', 'deepghs/anime_classification', 'mobilenetv3_sce_dist'),
9
+ Classification('Monochrome', 'deepghs/monochrome_detect', 'mobilenetv3_large_100_dist'),
10
+ Classification('AI Check', 'deepghs/anime_ai_check', 'mobilenetv3_sce_dist'),
11
+ Classification('Rating', 'deepghs/anime_rating', 'mobilenetv3_sce_dist'),
12
+ Classification('Character Sex', 'deepghs/anime_ch_sex', 'caformer_s36_v1'),
13
+ Classification('Character Skin', 'deepghs/anime_ch_skin_color', 'caformer_s36'),
14
+ ]
15
 
16
  if __name__ == '__main__':
17
  with gr.Blocks() as demo:
18
  with gr.Tabs():
19
+ for cls in apps:
20
+ cls.create_gr()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  demo.queue(os.cpu_count()).launch()
base.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from functools import lru_cache
4
+ from typing import Mapping
5
+
6
+ import gradio as gr
7
+ from huggingface_hub import HfFileSystem, hf_hub_download
8
+ from imgutils.data import ImageTyping, load_image
9
+ from natsort import natsorted
10
+
11
+ from onnx_ import _open_onnx_model
12
+ from preprocess import _img_encode
13
+
14
+ hfs = HfFileSystem()
15
+
16
+
17
+ @lru_cache()
18
+ def open_model_from_repo(repository, model):
19
+ runtime = _open_onnx_model(hf_hub_download(repository, f'{model}/model.onnx'))
20
+ with open(hf_hub_download(repository, f'{model}/meta.json'), 'r') as f:
21
+ labels = json.load(f)['labels']
22
+
23
+ return runtime, labels
24
+
25
+
26
+ class Classification:
27
+ def __init__(self, title: str, repository: str, default_model=None, imgsize: int = 384):
28
+ self.title = title
29
+ self.repository = repository
30
+ self.models = natsorted([
31
+ os.path.dirname(os.path.relpath(file, self.repository))
32
+ for file in hfs.glob(f'{self.repository}/*/model.onnx')
33
+ ])
34
+ self.default_model = default_model or self.models[0]
35
+ self.imgsize = imgsize
36
+
37
+ def _open_onnx_model(self, model):
38
+ return open_model_from_repo(self.repository, model)
39
+
40
+ def _gr_classification(self, image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
41
+ image = load_image(image, mode='RGB')
42
+ input_ = _img_encode(image, size=(size, size))[None, ...]
43
+ model, labels = self._open_onnx_model(model_name)
44
+ output, = model.run(['output'], {'input': input_})
45
+
46
+ values = dict(zip(labels, map(lambda x: x.item(), output[0])))
47
+ return values
48
+
49
+ def create_gr(self):
50
+ with gr.Tab(self.title):
51
+ with gr.Row():
52
+ with gr.Column():
53
+ gr_input_image = gr.Image(type='pil', label='Original Image')
54
+ gr_model = gr.Dropdown(self.models, value=self.default_model, label='Model')
55
+ gr_infer_size = gr.Slider(224, 640, value=384, step=32, label='Infer Size')
56
+ gr_submit = gr.Button(value='Submit', variant='primary')
57
+
58
+ with gr.Column():
59
+ gr_output = gr.Label(label='Classes')
60
+
61
+ gr_submit.click(
62
+ self._gr_classification,
63
+ inputs=[gr_input_image, gr_model, gr_infer_size],
64
+ outputs=[gr_output],
65
+ )
chsex.py DELETED
@@ -1,42 +0,0 @@
1
- import json
2
- import os
3
- from functools import lru_cache
4
- from typing import Mapping, List
5
-
6
- from huggingface_hub import HfFileSystem
7
- from huggingface_hub import hf_hub_download
8
- from imgutils.data import ImageTyping, load_image
9
- from natsort import natsorted
10
-
11
- from onnx_ import _open_onnx_model
12
- from preprocess import _img_encode
13
-
14
- hfs = HfFileSystem()
15
-
16
- _REPO = 'deepghs/anime_ch_sex'
17
- _CHSEX_MODELS = natsorted([
18
- os.path.dirname(os.path.relpath(file, _REPO))
19
- for file in hfs.glob(f'{_REPO}/*/model.onnx')
20
- ])
21
- _DEFAULT_CHSEX_MODEL = 'caformer_s36_v1'
22
-
23
-
24
- @lru_cache()
25
- def _open_anime_chsex_model(model_name):
26
- return _open_onnx_model(hf_hub_download(_REPO, f'{model_name}/model.onnx'))
27
-
28
-
29
- @lru_cache()
30
- def _get_tags(model_name) -> List[str]:
31
- with open(hf_hub_download(_REPO, f'{model_name}/meta.json'), 'r') as f:
32
- return json.load(f)['labels']
33
-
34
-
35
- def _gr_chsex(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
36
- image = load_image(image, mode='RGB')
37
- input_ = _img_encode(image, size=(size, size))[None, ...]
38
- output, = _open_anime_chsex_model(model_name).run(['output'], {'input': input_})
39
-
40
- labels = _get_tags(model_name)
41
- values = dict(zip(labels, map(lambda x: x.item(), output[0])))
42
- return values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cls.py DELETED
@@ -1,41 +0,0 @@
1
- import json
2
- import os
3
- from functools import lru_cache
4
- from typing import Mapping, List
5
-
6
- from huggingface_hub import hf_hub_download, HfFileSystem
7
- from imgutils.data import ImageTyping, load_image
8
- from natsort import natsorted
9
-
10
- from onnx_ import _open_onnx_model
11
- from preprocess import _img_encode
12
-
13
- hfs = HfFileSystem()
14
-
15
- _REPO = 'deepghs/anime_classification'
16
- _CLS_MODELS = natsorted([
17
- os.path.dirname(os.path.relpath(file, _REPO))
18
- for file in hfs.glob(f'{_REPO}/*/model.onnx')
19
- ])
20
- _DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'
21
-
22
-
23
- @lru_cache()
24
- def _open_anime_classify_model(model_name):
25
- return _open_onnx_model(hf_hub_download(_REPO, f'{model_name}/model.onnx'))
26
-
27
-
28
- @lru_cache()
29
- def _get_tags(model_name) -> List[str]:
30
- with open(hf_hub_download(_REPO, f'{model_name}/meta.json'), 'r') as f:
31
- return json.load(f)['labels']
32
-
33
-
34
- def _gr_classification(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
35
- image = load_image(image, mode='RGB')
36
- input_ = _img_encode(image, size=(size, size))[None, ...]
37
- output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_})
38
-
39
- labels = _get_tags(model_name)
40
- values = dict(zip(labels, map(lambda x: x.item(), output[0])))
41
- return values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
monochrome.py DELETED
@@ -1,42 +0,0 @@
1
- import json
2
- import os
3
- from functools import lru_cache
4
- from typing import Mapping, List
5
-
6
- from huggingface_hub import HfFileSystem
7
- from huggingface_hub import hf_hub_download
8
- from imgutils.data import ImageTyping, load_image
9
- from natsort import natsorted
10
-
11
- from onnx_ import _open_onnx_model
12
- from preprocess import _img_encode
13
-
14
- hfs = HfFileSystem()
15
-
16
- _REPO = 'deepghs/monochrome_detect'
17
- _MONO_MODELS = natsorted([
18
- os.path.dirname(os.path.relpath(file, _REPO))
19
- for file in hfs.glob(f'{_REPO}/*/model.onnx')
20
- ])
21
- _DEFAULT_MONO_MODEL = 'mobilenetv3_large_100_dist'
22
-
23
-
24
- @lru_cache()
25
- def _open_anime_monochrome_model(model_name):
26
- return _open_onnx_model(hf_hub_download(_REPO, f'{model_name}/model.onnx'))
27
-
28
-
29
- @lru_cache()
30
- def _get_tags(model_name) -> List[str]:
31
- with open(hf_hub_download(_REPO, f'{model_name}/meta.json'), 'r') as f:
32
- return json.load(f)['labels']
33
-
34
-
35
- def _gr_monochrome(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
36
- image = load_image(image, mode='RGB')
37
- input_ = _img_encode(image, size=(size, size))[None, ...]
38
- output, = _open_anime_monochrome_model(model_name).run(['output'], {'input': input_})
39
-
40
- labels = _get_tags(model_name)
41
- values = dict(zip(labels, map(lambda x: x.item(), output[0])))
42
- return values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rating.py DELETED
@@ -1,42 +0,0 @@
1
- import json
2
- import os
3
- from functools import lru_cache
4
- from typing import Mapping, List
5
-
6
- from huggingface_hub import HfFileSystem
7
- from huggingface_hub import hf_hub_download
8
- from imgutils.data import ImageTyping, load_image
9
- from natsort import natsorted
10
-
11
- from onnx_ import _open_onnx_model
12
- from preprocess import _img_encode
13
-
14
- hfs = HfFileSystem()
15
-
16
- _REPO = 'deepghs/anime_rating'
17
- _RATING_MODELS = natsorted([
18
- os.path.dirname(os.path.relpath(file, _REPO))
19
- for file in hfs.glob(f'{_REPO}/*/model.onnx')
20
- ])
21
- _DEFAULT_RATING_MODEL = 'mobilenetv3_sce_dist'
22
-
23
-
24
- @lru_cache()
25
- def _open_anime_rating_model(model_name):
26
- return _open_onnx_model(hf_hub_download(_REPO, f'{model_name}/model.onnx'))
27
-
28
-
29
- @lru_cache()
30
- def _get_tags(model_name) -> List[str]:
31
- with open(hf_hub_download(_REPO, f'{model_name}/meta.json'), 'r') as f:
32
- return json.load(f)['labels']
33
-
34
-
35
- def _gr_rating(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
36
- image = load_image(image, mode='RGB')
37
- input_ = _img_encode(image, size=(size, size))[None, ...]
38
- output, = _open_anime_rating_model(model_name).run(['output'], {'input': input_})
39
-
40
- labels = _get_tags(model_name)
41
- values = dict(zip(labels, map(lambda x: x.item(), output[0])))
42
- return values