hahafofo committed
Commit 39fbaa4
1 Parent(s): 48f4d16
utils/__init__.py ADDED
Empty file (no content)
utils/dbimutils.py ADDED
@@ -0,0 +1,54 @@
+ # DanBooru image utility functions
+
+ import cv2
+ import numpy as np
+ from PIL import Image
+
+
+ def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
+     # cv2.imread cannot decode GIFs, so route those through PIL
+     if img.endswith(".gif"):
+         img = Image.open(img)
+         img = img.convert("RGB")
+         img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+     else:
+         img = cv2.imread(img, flag)
+     return img
+
+
+ def smart_24bit(img):
+     # Scale 16-bit images down to 8-bit (65535 / 257 = 255)
+     if img.dtype == np.uint16:
+         img = (img / 257).astype(np.uint8)
+
+     if len(img.shape) == 2:
+         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+     elif img.shape[2] == 4:
+         # Turn fully transparent pixels white before dropping the alpha channel
+         trans_mask = img[:, :, 3] == 0
+         img[trans_mask] = [255, 255, 255, 255]
+         img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
+     return img
+
+
+ def make_square(img, target_size):
+     # Pad with white borders to a square of at least target_size per side
+     old_size = img.shape[:2]
+     desired_size = max(old_size)
+     desired_size = max(desired_size, target_size)
+
+     delta_w = desired_size - old_size[1]
+     delta_h = desired_size - old_size[0]
+     top, bottom = delta_h // 2, delta_h - (delta_h // 2)
+     left, right = delta_w // 2, delta_w - (delta_w // 2)
+
+     color = [255, 255, 255]
+     new_im = cv2.copyMakeBorder(
+         img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
+     )
+     return new_im
+
+
+ def smart_resize(img, size):
+     # Assumes the image has already gone through make_square
+     if img.shape[0] > size:
+         img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
+     elif img.shape[0] < size:
+         img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
+     return img
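
Together these helpers form a small preprocessing pipeline for the WD14 taggers (used by image2text.py below). A minimal usage sketch, assuming a hypothetical input file photo.png and a 448-pixel model input size:

    from utils import dbimutils

    img = dbimutils.smart_imread("photo.png")  # hypothetical path
    img = dbimutils.smart_24bit(img)           # normalize to 8-bit, 3-channel BGR
    img = dbimutils.make_square(img, 448)      # pad to a white-bordered square
    img = dbimutils.smart_resize(img, 448)     # scale to the model input size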
utils/exif.py ADDED
@@ -0,0 +1,54 @@
+ import piexif
+ import piexif.helper
+
+ from .html import plaintext_to_html
+
+
+ def get_image_info(rawimage):
+     items = rawimage.info
+     geninfo = ""
+
+     if "exif" in rawimage.info:
+         exif = piexif.load(rawimage.info["exif"])
+         exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b"")
+         try:
+             exif_comment = piexif.helper.UserComment.load(exif_comment)
+         except ValueError:
+             exif_comment = exif_comment.decode("utf8", errors="ignore")
+
+         items["exif comment"] = exif_comment
+         geninfo = exif_comment
+
+     # Drop format-level metadata that is not worth displaying
+     for field in [
+         "jfif",
+         "jfif_version",
+         "jfif_unit",
+         "jfif_density",
+         "dpi",
+         "exif",
+         "loop",
+         "background",
+         "timestamp",
+         "duration",
+     ]:
+         items.pop(field, None)
+
+     geninfo = items.get("parameters", geninfo)
+
+     if len(items) == 0:
+         message = "Nothing found in the image."
+         return f"<div><p>{message}</p></div>"
+
+     info = "<p><h4>PNG Info</h4></p>\n"
+     for key, text in items.items():
+         info += (
+             f"""
+ <div>
+ <p><b>{plaintext_to_html(str(key))}</b></p>
+ <p>{plaintext_to_html(str(text))}</p>
+ </div>
+ """.strip()
+             + "\n"
+         )
+     return info
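
A minimal sketch of calling get_image_info; the file name is hypothetical, and for images saved by Stable Diffusion WebUI the generation parameters surface via the "parameters" text chunk:

    from PIL import Image
    from utils.exif import get_image_info

    image = Image.open("generated.png")  # hypothetical file
    print(get_image_info(image))         # HTML fragment listing the metadata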
utils/html.py ADDED
@@ -0,0 +1,8 @@
+ import html
+
+
+ def plaintext_to_html(text):
+     # Escape HTML entities and turn newlines into <br> tags
+     text = (
+         "<p>" + "<br>\n".join(html.escape(x) for x in text.split("\n")) + "</p>"
+     )
+     return text
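
For reference, a quick sketch of what plaintext_to_html produces for input containing markup characters and a newline:

    from utils.html import plaintext_to_html

    print(plaintext_to_html("a<b\nc&d"))
    # -> '<p>a&lt;b<br>\nc&amp;d</p>'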
utils/image2text.py ADDED
@@ -0,0 +1,195 @@
+ from __future__ import annotations
+
+ import PIL.Image
+ import huggingface_hub
+ import numpy as np
+ import onnxruntime as rt
+ import pandas as pd
+ import torch
+ from clip_interrogator import Config, Interrogator
+ from transformers import AutoModelForCausalLM
+ from transformers import AutoProcessor
+
+ from . import dbimutils
+ from .singleton import Singleton
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ @Singleton
+ class Models(object):
+     # WD14 models
+     SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
+     CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
+     CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
+     VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
+
+     MODEL_FILENAME = "model.onnx"
+     LABEL_FILENAME = "selected_tags.csv"
+
+     # CLIP models
+     VIT_H_14_MODEL_REPO = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"  # Stable Diffusion 2.X
+     VIT_L_14_MODEL_REPO = "openai/clip-vit-large-patch14"  # Stable Diffusion 1.X
+
+     def __init__(self):
+         pass
+
+     @classmethod
+     def load_clip_model(cls, model_repo):
+         config = Config()
+         config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         config.blip_offload = not torch.cuda.is_available()
+         config.chunk_size = 2048
+         config.flavor_intermediate_count = 512
+         config.blip_num_beams = 64
+         config.clip_model_name = model_repo
+
+         ci = Interrogator(config)
+         return ci
+
+     def __getattr__(self, item):
+         # Lazily load models on first access and cache them on the instance
+         if item in self.__dict__:
+             return getattr(self, item)
+         print(f"Loading {item}...")
+         if item in ('clip_vit_h_14_model',):
+             self.clip_vit_h_14_model = self.load_clip_model(self.VIT_H_14_MODEL_REPO)
+         if item in ('clip_vit_l_14_model',):
+             self.clip_vit_l_14_model = self.load_clip_model(self.VIT_L_14_MODEL_REPO)
+         if item in ('swinv2_model',):
+             self.swinv2_model = self.load_model(self.SWIN_MODEL_REPO, self.MODEL_FILENAME)
+         if item in ('convnext_model',):
+             self.convnext_model = self.load_model(self.CONV_MODEL_REPO, self.MODEL_FILENAME)
+         if item in ('vit_model',):
+             self.vit_model = self.load_model(self.VIT_MODEL_REPO, self.MODEL_FILENAME)
+         if item in ('convnextv2_model',):
+             self.convnextv2_model = self.load_model(self.CONV2_MODEL_REPO, self.MODEL_FILENAME)
+         if item in ('git_model', 'git_processor'):
+             self.git_model, self.git_processor = self.load_git_model()
+         if item in ('tag_names', 'rating_indexes', 'general_indexes', 'character_indexes'):
+             self.tag_names, self.rating_indexes, self.general_indexes, self.character_indexes = self.load_w14_labels()
+
+         return getattr(self, item)
+
+     @classmethod
+     def load_git_model(cls):
+         model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
+         processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
+         return model, processor
+
+     @staticmethod
+     def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
+         path = huggingface_hub.hf_hub_download(model_repo, model_filename)
+         model = rt.InferenceSession(path)
+         return model
+
+     @classmethod
+     def load_w14_labels(cls) -> list:
+         path = huggingface_hub.hf_hub_download(cls.CONV2_MODEL_REPO, cls.LABEL_FILENAME)
+         df = pd.read_csv(path)
+
+         tag_names = df["name"].tolist()
+         rating_indexes = list(np.where(df["category"] == 9)[0])
+         general_indexes = list(np.where(df["category"] == 0)[0])
+         character_indexes = list(np.where(df["category"] == 4)[0])
+         return [tag_names, rating_indexes, general_indexes, character_indexes]
+
+
+ models = Models.instance()
+
+
+ def clip_image2text(image, mode_type='best', model_name='vit_h_14'):
+     image = image.convert('RGB')
+     model = getattr(models, f'clip_{model_name}_model')
+     if mode_type == 'classic':
+         prompt = model.interrogate_classic(image)
+     elif mode_type == 'fast':
+         prompt = model.interrogate_fast(image)
+     elif mode_type == 'negative':
+         prompt = model.interrogate_negative(image)
+     else:
+         prompt = model.interrogate(image)  # default to 'best'
+     return prompt
+
+
+ def git_image2text(input_image, max_length=50):
+     image = input_image.convert('RGB')
+     pixel_values = models.git_processor(images=image, return_tensors="pt").to(device).pixel_values
+
+     generated_ids = models.git_model.to(device).generate(pixel_values=pixel_values, max_length=max_length)
+     generated_caption = models.git_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     return generated_caption
+
+
+ def w14_image2text(
+     image: PIL.Image.Image,
+     model_name: str,
+     general_threshold: float,
+     character_threshold: float,
+ ):
+     tag_names: list[str] = models.tag_names
+     rating_indexes: list[np.int64] = models.rating_indexes
+     general_indexes: list[np.int64] = models.general_indexes
+     character_indexes: list[np.int64] = models.character_indexes
+     model_name = "{}_model".format(model_name.lower())
+     model = getattr(models, model_name)
+
+     _, height, width, _ = model.get_inputs()[0].shape
+
+     # Alpha to white
+     image = image.convert("RGBA")
+     new_image = PIL.Image.new("RGBA", image.size, "WHITE")
+     new_image.paste(image, mask=image)
+     image = new_image.convert("RGB")
+     image = np.asarray(image)
+
+     # PIL RGB to OpenCV BGR
+     image = image[:, :, ::-1]
+
+     image = dbimutils.make_square(image, height)
+     image = dbimutils.smart_resize(image, height)
+     image = image.astype(np.float32)
+     image = np.expand_dims(image, 0)
+
+     input_name = model.get_inputs()[0].name
+     label_name = model.get_outputs()[0].name
+     probs = model.run([label_name], {input_name: image})[0]
+
+     labels = list(zip(tag_names, probs[0].astype(float)))
+
+     # Rating labels form their own category: report all of them
+     ratings_names = [labels[i] for i in rating_indexes]
+     rating = dict(ratings_names)
+
+     # General tags: keep any with prediction confidence above the threshold
+     general_names = [labels[i] for i in general_indexes]
+     general_res = [x for x in general_names if x[1] > general_threshold]
+     general_res = dict(general_res)
+
+     # Character tags: keep any with prediction confidence above the threshold
+     character_names = [labels[i] for i in character_indexes]
+     character_res = [x for x in character_names if x[1] > character_threshold]
+     character_res = dict(character_res)
+
+     # Sort general tags by confidence and render them in several formats
+     b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True))
+     a = (
+         ", ".join(list(b.keys()))
+         .replace("_", " ")
+         .replace("(", "\\(")
+         .replace(")", "\\)")
+     )
+     c = ", ".join(list(b.keys()))
+     d = " ".join(list(b.keys()))
+
+     return a, c, d, rating, character_res, general_res
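
A minimal sketch of how the three entry points might be invoked together, e.g. from a UI handler; the file name and threshold values here are hypothetical:

    from PIL import Image
    from utils import image2text

    img = Image.open("example.png")  # hypothetical input

    caption = image2text.git_image2text(img, max_length=50)
    prompt = image2text.clip_image2text(img, mode_type='best', model_name='vit_l_14')
    escaped, plain, spaced, rating, characters, general = image2text.w14_image2text(
        img, "convnextv2", general_threshold=0.35, character_threshold=0.85
    )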
utils/singleton.py ADDED
@@ -0,0 +1,37 @@
+ class Singleton:
+     """
+     A non-thread-safe helper class to ease implementing singletons.
+     This should be used as a decorator -- not a metaclass -- on the
+     class that should be a singleton.
+
+     The decorated class can define one `__init__` function that
+     takes only the `self` argument. Also, the decorated class cannot be
+     inherited from. Other than that, there are no restrictions that apply
+     to the decorated class.
+
+     To get the singleton instance, use the `instance` method. Trying
+     to use `__call__` will result in a `TypeError` being raised.
+     """
+
+     def __init__(self, decorated):
+         self._decorated = decorated
+
+     def instance(self):
+         """
+         Returns the singleton instance. Upon its first call, it creates a
+         new instance of the decorated class and calls its `__init__` method.
+         On all subsequent calls, the already created instance is returned.
+         """
+         try:
+             return self._instance
+         except AttributeError:
+             self._instance = self._decorated()
+             return self._instance
+
+     def __call__(self):
+         raise TypeError('Singletons must be accessed through `instance()`.')
+
+     def __instancecheck__(self, inst):
+         return isinstance(inst, self._decorated)
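
The docstring above spells out the contract; a brief usage sketch with a hypothetical Config class:

    from utils.singleton import Singleton

    @Singleton
    class Config:  # hypothetical example class
        def __init__(self):
            self.value = 42

    config = Config.instance()           # created lazily on first call
    assert config is Config.instance()   # same object every time
    assert isinstance(config, Config)    # __instancecheck__ delegates
    # Config() would raise TypeError: use instance() instead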
utils/translate.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ from .singleton import Singleton
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ @Singleton
+ class Models(object):
+
+     def __getattr__(self, item):
+         # Lazily load the translation models on first access
+         if item in self.__dict__:
+             return getattr(self, item)
+
+         if item in ('zh2en_model', 'zh2en_tokenizer',):
+             self.zh2en_model, self.zh2en_tokenizer = self.load_zh2en_model()
+
+         if item in ('en2zh_model', 'en2zh_tokenizer',):
+             self.en2zh_model, self.en2zh_tokenizer = self.load_en2zh_model()
+
+         return getattr(self, item)
+
+     @classmethod
+     def load_en2zh_model(cls):
+         en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
+         en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
+         return en2zh_model, en2zh_tokenizer
+
+     @classmethod
+     def load_zh2en_model(cls):
+         zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
+         zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
+         return zh2en_model, zh2en_tokenizer
+
+
+ models = Models.instance()
+
+
+ def zh2en(text):
+     with torch.no_grad():
+         encoded = models.zh2en_tokenizer([text], return_tensors="pt")
+         sequences = models.zh2en_model.generate(**encoded)
+         return models.zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+
+
+ def en2zh(text):
+     with torch.no_grad():
+         encoded = models.en2zh_tokenizer([text], return_tensors="pt")
+         sequences = models.en2zh_model.generate(**encoded)
+         return models.en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+
+
+ if __name__ == "__main__":
+     # Round-trip demo: Chinese -> English -> Chinese
+     text = "青春不能回头,所以青春没有终点。 ——《火影忍者》"  # "Youth cannot turn back, so youth has no end." (Naruto)
+     en = zh2en(text)
+     print(text, en)
+     zh = en2zh(en)
+     print(en, zh)