yo2266911 commited on
Commit
213ef40
1 Parent(s): 31cf972

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -0
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import html
9
+ import pathlib
10
+ import tarfile
11
+
12
+ import deepdanbooru as dd
13
+ import gradio as gr
14
+ import huggingface_hub
15
+ import numpy as np
16
+ import PIL.Image
17
+ import tensorflow as tf
18
+ import piexif
19
+ import piexif.helper
20
+
21
+ TITLE = 'DeepDanbooru String'
22
+
23
+ TOKEN = os.environ['TOKEN']
24
+ MODEL_REPO = 'NoCrypt/DeepDanbooru_string'
25
+ MODEL_FILENAME = 'model-resnet_custom_v3.h5'
26
+ LABEL_FILENAME = 'tags.txt'
27
+
28
+
29
+ def parse_args() -> argparse.Namespace:
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument('--score-slider-step', type=float, default=0.05)
32
+ parser.add_argument('--score-threshold', type=float, default=0.5)
33
+ parser.add_argument('--theme', type=str, default='dark-grass')
34
+ parser.add_argument('--live', action='store_true')
35
+ parser.add_argument('--share', action='store_true')
36
+ parser.add_argument('--port', type=int)
37
+ parser.add_argument('--disable-queue',
38
+ dest='enable_queue',
39
+ action='store_false')
40
+ parser.add_argument('--allow-flagging', type=str, default='never')
41
+ return parser.parse_args()
42
+
43
+
44
+ def load_sample_image_paths() -> list[pathlib.Path]:
45
+ image_dir = pathlib.Path('images')
46
+ if not image_dir.exists():
47
+ dataset_repo = 'hysts/sample-images-TADNE'
48
+ path = huggingface_hub.hf_hub_download(dataset_repo,
49
+ 'images.tar.gz',
50
+ repo_type='dataset',
51
+ use_auth_token=TOKEN)
52
+ with tarfile.open(path) as f:
53
+ f.extractall()
54
+ return sorted(image_dir.glob('*'))
55
+
56
+
57
+ def load_model() -> tf.keras.Model:
58
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
59
+ MODEL_FILENAME,
60
+ use_auth_token=TOKEN)
61
+ model = tf.keras.models.load_model(path)
62
+ return model
63
+
64
+
65
+ def load_labels() -> list[str]:
66
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
67
+ LABEL_FILENAME,
68
+ use_auth_token=TOKEN)
69
+ with open(path) as f:
70
+ labels = [line.strip() for line in f.readlines()]
71
+ return labels
72
+
73
+ def plaintext_to_html(text):
74
+ text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
75
+ return text
76
+
77
+ def predict(image: PIL.Image.Image, score_threshold: float,
78
+ model: tf.keras.Model, labels: list[str]) -> dict[str, float]:
79
+ rawimage = image
80
+ _, height, width, _ = model.input_shape
81
+ image = np.asarray(image)
82
+ image = tf.image.resize(image,
83
+ size=(height, width),
84
+ method=tf.image.ResizeMethod.AREA,
85
+ preserve_aspect_ratio=True)
86
+ image = image.numpy()
87
+ image = dd.image.transform_and_pad_image(image, width, height)
88
+ image = image / 255.
89
+ probs = model.predict(image[None, ...])[0]
90
+ probs = probs.astype(float)
91
+ res = dict()
92
+ for prob, label in zip(probs.tolist(), labels):
93
+ if prob < score_threshold:
94
+ continue
95
+ res[label] = prob
96
+ b = dict(sorted(res.items(),key=lambda item:item[1], reverse=True))
97
+ a = ', '.join(list(b.keys())).replace('_',' ').replace('(','\(').replace(')','\)')
98
+ c = ', '.join(list(b.keys()))
99
+
100
+ items = rawimage.info
101
+ geninfo = ''
102
+
103
+ if "exif" in rawimage.info:
104
+ exif = piexif.load(rawimage.info["exif"])
105
+ exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
106
+ try:
107
+ exif_comment = piexif.helper.UserComment.load(exif_comment)
108
+ except ValueError:
109
+ exif_comment = exif_comment.decode('utf8', errors="ignore")
110
+
111
+ items['exif comment'] = exif_comment
112
+ geninfo = exif_comment
113
+
114
+ for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
115
+ 'loop', 'background', 'timestamp', 'duration']:
116
+ items.pop(field, None)
117
+
118
+ geninfo = items.get('parameters', geninfo)
119
+
120
+ info = f"""
121
+ <p><h4>PNG Info</h4></p>
122
+ """
123
+ for key, text in items.items():
124
+ info += f"""
125
+ <div>
126
+ <p><b>{plaintext_to_html(str(key))}</b></p>
127
+ <p>{plaintext_to_html(str(text))}</p>
128
+ </div>
129
+ """.strip()+"\n"
130
+
131
+ if len(info) == 0:
132
+ message = "Nothing found in the image."
133
+ info = f"<div><p>{message}<p></div>"
134
+
135
+ return (a,c,res,info)
136
+
137
+
138
+ def main():
139
+ args = parse_args()
140
+ model = load_model()
141
+ labels = load_labels()
142
+
143
+ func = functools.partial(predict, model=model, labels=labels)
144
+ func = functools.update_wrapper(func, predict)
145
+
146
+ gr.Interface(
147
+ func,
148
+ [
149
+ gr.inputs.Image(type='pil', label='Input'),
150
+ gr.inputs.Slider(0,
151
+ 1,
152
+ step=args.score_slider_step,
153
+ default=args.score_threshold,
154
+ label='Score Threshold'),
155
+ ],
156
+ [
157
+ gr.outputs.Textbox(label='Output (string)'),
158
+ gr.outputs.Textbox(label='Output (raw string)'),
159
+ gr.outputs.Label(label='Output (label)'),
160
+ gr.outputs.HTML()
161
+ ],
162
+ examples=[
163
+ ['miku.jpg',0.5],
164
+ ['miku2.jpg',0.5]
165
+ ],
166
+ title=TITLE,
167
+ description='''
168
+ Demo for [KichangKim/DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) with "ready to copy" prompt and a prompt analyzer.
169
+
170
+ Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
171
+
172
+ PNG Info code forked from [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
173
+ ''',
174
+ theme=args.theme,
175
+ allow_flagging=args.allow_flagging,
176
+ live=args.live,
177
+ ).launch(
178
+ enable_queue=args.enable_queue,
179
+ server_port=args.port,
180
+ share=args.share,
181
+ )
182
+
183
+
184
+ if __name__ == '__main__':
185
+ main()