neggles committed
Commit c24a176
Parent: dc9c12b

make the thing work
app.py ADDED
@@ -0,0 +1,146 @@
+ from os import getenv
+ from pathlib import Path
+
+ import gradio as gr
+ from PIL import Image
+ from rich.traceback import install as traceback_install
+
+ from tagger.common import Heatmap, ImageLabels, LabelData, load_labels_hf, preprocess_image
+ from tagger.model import load_model_and_transform, process_heatmap
+
+ TITLE = "WD Tagger Heatmap"
+ DESCRIPTION = """WD Tagger v3 Heatmap Generator."""
+ # get HF token
+ HF_TOKEN = getenv("HF_TOKEN", None)
+
+ # model repo and cache
+ MODEL_REPO = "SmilingWolf/wd-vit-tagger-v3"
+ # get the repo root (or the current working directory if running in ipython)
+ WORK_DIR = Path(__file__).parent.resolve() if "__file__" in globals() else Path().resolve()
+ # allowed extensions
+ IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
+
+ _ = traceback_install(show_locals=True, locals_max_length=0)
+
+ # get the example images
+ example_images = sorted(
+     [
+         str(x.relative_to(WORK_DIR))
+         for x in WORK_DIR.joinpath("examples").iterdir()
+         if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
+     ]
+ )
+
+
+ def predict(
+     image: Image.Image,
+     threshold: float = 0.5,
+ ):
+     # load the model and its matching transform (cached per repo)
+     model, transform = load_model_and_transform(MODEL_REPO)
+     # load labels (pass the HF token so private/gated repos work)
+     labels: LabelData = load_labels_hf(MODEL_REPO, token=HF_TOKEN)
+     # preprocess image
+     image = preprocess_image(image, (448, 448))
+     image = transform(image).unsqueeze(0)
+
+     # get the model output
+     heatmaps: list[Heatmap]
+     image_labels: ImageLabels
+     heatmaps, heatmap_grid, image_labels = process_heatmap(model, image, labels, threshold)
+
+     heatmap_images = [(x.image, x.label) for x in heatmaps]
+
+     return (
+         heatmap_images,
+         heatmap_grid,
+         image_labels.caption,
+         image_labels.booru,
+         image_labels.rating,
+         image_labels.character,
+         image_labels.general,
+     )
+
+
+ css = """
+ #use_mcut, #char_mcut {
+     padding-top: var(--scale-3);
+ }
+ #threshold.dimmed {
+     filter: brightness(75%);
+ }
+ """
+
+ with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
+     with gr.Row(equal_height=False):
+         with gr.Column(min_width=720):
+             with gr.Group():
+                 img_input = gr.Image(
+                     label="Input",
+                     type="pil",
+                     image_mode="RGB",
+                     sources=["upload", "clipboard"],
+                 )
+             with gr.Group():
+                 with gr.Row():
+                     threshold = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=0.35,
+                         step=0.01,
+                         label="Tag Threshold",
+                         scale=5,
+                         elem_id="threshold",
+                     )
+                 with gr.Row():
+                     clear = gr.ClearButton(
+                         components=[],
+                         variant="secondary",
+                         size="lg",
+                     )
+                     submit = gr.Button(value="Submit", variant="primary", size="lg")
+
+         with gr.Column(min_width=720):
+             with gr.Tab(label="Heatmaps"):
+                 heatmap_gallery = gr.Gallery(columns=3, show_label=False)
+             with gr.Tab(label="Grid"):
+                 heatmap_grid = gr.Image(show_label=False)
+             with gr.Tab(label="Tags"):
+                 with gr.Group():
+                     rating = gr.Label(label="Rating")
+                 with gr.Group():
+                     character = gr.Label(label="Character")
+                 with gr.Group():
+                     general = gr.Label(label="General")
+
+                 with gr.Group():
+                     caption = gr.Textbox(label="Caption", show_copy_button=True)
+                     tags = gr.Textbox(label="Tags", show_copy_button=True)
+
+     with gr.Row():
+         examples = [[imgpath, 0.35] for imgpath in example_images]
+         examples = gr.Examples(
+             examples=examples,
+             inputs=[img_input, threshold],
+         )
+
+     # tell clear button which components to clear
+     clear.add([img_input, heatmap_gallery, heatmap_grid, caption, tags, rating, character, general])
+
+     submit.click(
+         predict,
+         inputs=[img_input, threshold],
+         outputs=[heatmap_gallery, heatmap_grid, caption, tags, rating, character, general],
+         api_name="predict",
+     )
+
+ if __name__ == "__main__":
+     demo.queue(max_size=10)
+     if getenv("SPACE_ID", None) is not None:
+         demo.launch()
+     else:
+         demo.launch(
+             server_name="0.0.0.0",
+             server_port=7871,
+             debug=True,
+         )
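
Because submit.click registers the endpoint under api_name="predict", the app can also be driven remotely. A minimal sketch using gradio_client, assuming the app was launched locally on port 7871 as in the __main__ block above; the URL and example path are assumptions, not part of this commit:

from gradio_client import Client, handle_file

# connect to the locally launched demo (port 7871, per the __main__ block)
client = Client("http://127.0.0.1:7871/")

# positional args mirror predict()'s inputs (image, threshold);
# one return value per registered output component
gallery, grid, caption, tags, rating, character, general = client.predict(
    handle_file("examples/img-01.png"),  # any local image path
    0.35,                                # tag threshold
    api_name="/predict",
)
print(caption)

Older gradio_client versions accept a plain filepath where newer ones want handle_file().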
examples/img-01.png ADDED

Git LFS Details

  • SHA256: 37a2bec1c653272457c6b6e5fec6da8ac4676d973f7cd87c545a6e1ab6be288c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.53 MB
examples/img-02.png ADDED

Git LFS Details

  • SHA256: 90ee6035ce0caec46bbda3a9d48bdcd2cd7384487781615c4251301ab5422d45
  • Pointer size: 131 Bytes
  • Size of remote file: 434 kB
examples/img-03.jpg ADDED
tagger/__init__.py ADDED
(empty file)
tagger/common.py ADDED
@@ -0,0 +1,180 @@
+ import math
+ from dataclasses import dataclass
+ from functools import lru_cache
+ from pathlib import Path
+ from typing import Optional
+
+ import numpy as np
+ import pandas as pd
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import HfHubHTTPError
+ from PIL import Image
+
+
+ @dataclass
+ class Heatmap:
+     label: str
+     score: float
+     image: Image.Image
+
+
+ @dataclass
+ class LabelData:
+     names: list[str]
+     rating: list[np.int64]
+     general: list[np.int64]
+     character: list[np.int64]
+
+
+ @dataclass
+ class ImageLabels:
+     caption: str
+     booru: str
+     rating: dict[str, float]
+     general: dict[str, float]
+     character: dict[str, float]
+
+
+ @lru_cache(maxsize=5)
+ def load_labels_hf(
+     repo_id: str,
+     revision: Optional[str] = None,
+     token: Optional[str] = None,
+ ) -> LabelData:
+     try:
+         csv_path = hf_hub_download(
+             repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
+         )
+         csv_path = Path(csv_path).resolve()
+     except HfHubHTTPError as e:
+         raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
+
+     df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
+     tag_data = LabelData(
+         names=df["name"].tolist(),
+         rating=list(np.where(df["category"] == 9)[0]),
+         general=list(np.where(df["category"] == 0)[0]),
+         character=list(np.where(df["category"] == 4)[0]),
+     )
+
+     return tag_data
+
+
+ def mcut_threshold(probs: np.ndarray) -> float:
+     """
+     Maximum Cut Thresholding (MCut)
+     Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
+     for Multi-label Classification. In 11th International Symposium, IDA 2012
+     (pp. 172-183).
+     """
+     # sort descending, find the largest gap between adjacent scores,
+     # and put the threshold at the midpoint of that gap
+     probs = probs[probs.argsort()[::-1]]
+     diffs = probs[:-1] - probs[1:]
+     idx = diffs.argmax()
+     thresh = (probs[idx] + probs[idx + 1]) / 2
+     return float(thresh)
+
+
+ def pil_ensure_rgb(image: Image.Image) -> Image.Image:
+     # convert to RGB/RGBA if not already (deals with palette images etc.)
+     if image.mode not in ["RGB", "RGBA"]:
+         image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
+     # convert RGBA to RGB with white background
+     if image.mode == "RGBA":
+         canvas = Image.new("RGBA", image.size, (255, 255, 255))
+         canvas.alpha_composite(image)
+         image = canvas.convert("RGB")
+     return image
+
+
+ def pil_pad_square(
+     image: Image.Image,
+     fill: tuple[int, int, int] = (255, 255, 255),
+ ) -> Image.Image:
+     w, h = image.size
+     # get the largest dimension so we can pad to a square
+     px = max(image.size)
+     # pad to square with white background
+     canvas = Image.new("RGB", (px, px), fill)
+     canvas.paste(image, ((px - w) // 2, (px - h) // 2))
+     return canvas
+
+
+ def preprocess_image(
+     image: Image.Image,
+     size_px: int | tuple[int, int],
+     upscale: bool = True,
+ ) -> Image.Image:
+     """
+     Preprocess an image to be square and centered on a white background.
+     """
+     if isinstance(size_px, int):
+         size_px = (size_px, size_px)
+
+     # ensure RGB and pad to square
+     image = pil_ensure_rgb(image)
+     image = pil_pad_square(image)
+
+     # resize to target size
+     if image.size[0] < size_px[0] or image.size[1] < size_px[1]:
+         if upscale is False:
+             raise ValueError("Image is smaller than target size, and upscaling is disabled")
+         image = image.resize(size_px, Image.LANCZOS)
+     if image.size[0] > size_px[0] or image.size[1] > size_px[1]:
+         image.thumbnail(size_px, Image.BICUBIC)
+
+     return image
+
+
+ def pil_make_grid(
+     images: list[Image.Image],
+     max_cols: int = 8,
+     padding: int = 4,
+     bg_color: tuple[int, int, int] = (40, 42, 54),  # dracula background color
+     partial_rows: bool = True,
+ ) -> Image.Image:
+     n_cols = min(math.floor(math.sqrt(len(images))), max_cols)
+     n_rows = math.ceil(len(images) / n_cols)
+
+     # if the final row is not full and partial_rows is False, remove a row
+     if n_cols * n_rows > len(images) and not partial_rows:
+         n_rows -= 1
+
+     # assumes a non-empty list of equally-sized images
+     image_width, image_height = images[0].size
+
+     canvas_width = ((image_width + padding) * n_cols) + padding
+     canvas_height = ((image_height + padding) * n_rows) + padding
+
+     canvas = Image.new("RGB", (canvas_width, canvas_height), bg_color)
+     for i, img in enumerate(images):
+         x = (i % n_cols) * (image_width + padding) + padding
+         y = (i // n_cols) * (image_height + padding) + padding
+         canvas.paste(img, (x, y))
+
+     return canvas
+
+
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
+ kaomojis = [
+     "0_0",
+     "(o)_(o)",
+     "+_+",
+     "+_-",
+     "._.",
+     "<o>_<o>",
+     "<|>_<|>",
+     "=_=",
+     ">_<",
+     "3_3",
+     "6_9",
+     ">_o",
+     "@_@",
+     "^_^",
+     "o_o",
+     "u_u",
+     "x_x",
+     "|_|",
+     "||_||",
+ ]
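
A quick worked example of mcut_threshold above: scores are sorted descending, the largest gap between neighbours is located, and the threshold lands at the midpoint of that gap. The array below is made up for illustration.

import numpy as np

from tagger.common import mcut_threshold

# the largest gap is between 0.81 and 0.34, so the cutoff is their midpoint
scores = np.array([0.95, 0.88, 0.81, 0.34, 0.22, 0.05])
print(mcut_threshold(scores))  # (0.81 + 0.34) / 2 = 0.575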
tagger/model.py ADDED
@@ -0,0 +1,208 @@
+ import math
+ from pathlib import Path
+
+ import colorcet as cc
+ import cv2
+ import numpy as np
+ import timm
+ import torch
+ from matplotlib.colors import LinearSegmentedColormap
+ from PIL import Image
+ from timm.data import create_transform, resolve_data_config
+ from timm.models import VisionTransformer
+ from torch import Tensor, nn
+ from torch.nn import functional as F
+ from torchvision import transforms as T
+
+ from .common import Heatmap, ImageLabels, LabelData, pil_make_grid
+
+ # working dir, either file parent dir or cwd if interactive
+ work_dir = (Path(__file__).parent if "__file__" in locals() else Path.cwd()).resolve()
+ temp_dir = work_dir.joinpath("temp")
+ temp_dir.mkdir(exist_ok=True, parents=True)
+
+ # model cache
+ model_cache: dict[str, VisionTransformer] = {}
+ transform_cache: dict[str, T.Compose] = {}
+
+ # device to use
+ torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ class RGBtoBGR(nn.Module):
+     """Swap channel order; the WD taggers expect BGR input."""
+
+     def forward(self, x: Tensor) -> Tensor:
+         if x.ndim == 4:
+             return x[:, [2, 1, 0], :, :]
+         return x[[2, 1, 0], :, :]
+
+
+ def model_device(model: nn.Module) -> torch.device:
+     return next(model.parameters()).device
+
+
+ def load_model(repo_id: str) -> VisionTransformer:
+     global model_cache
+
+     if model_cache.get(repo_id, None) is None:
+         # save model to cache
+         model_cache[repo_id] = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval().to(torch_device)
+
+     return model_cache[repo_id]
+
+
+ def load_model_and_transform(repo_id: str) -> tuple[VisionTransformer, T.Compose]:
+     global transform_cache
+     global model_cache
+
+     if model_cache.get(repo_id, None) is None:
+         # save model to cache (moved to the active device, matching load_model)
+         model_cache[repo_id] = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval().to(torch_device)
+     model = model_cache[repo_id]
+
+     if transform_cache.get(repo_id, None) is None:
+         transforms = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
+         # hack in the RGBtoBGR transform, save to cache
+         transform_cache[repo_id] = T.Compose(transforms.transforms + [RGBtoBGR()])
+     transform = transform_cache[repo_id]
+
+     return model, transform
+
+
+ def get_tags(
+     probs: Tensor,
+     labels: LabelData,
+     gen_threshold: float,
+     char_threshold: float,
+ ):
+     # Convert indices+probs to labels
+     probs = list(zip(labels.names, probs.numpy()))
+
+     # First 4 labels are actually ratings
+     rating_labels = dict([probs[i] for i in labels.rating])
+
+     # General labels, pick any where prediction confidence > threshold
+     gen_labels = [probs[i] for i in labels.general]
+     gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])
+     gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
+
+     # Character labels, pick any where prediction confidence > threshold
+     char_labels = [probs[i] for i in labels.character]
+     char_labels = dict([x for x in char_labels if x[1] > char_threshold])
+     char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
+
+     # Combine general and character labels, sort by confidence
+     combined_names = [x for x in gen_labels]
+     combined_names.extend([x for x in char_labels])
+
+     # Convert to a string suitable for use as a training caption
+     # (escape parentheses so they survive prompt parsing)
+     caption = ", ".join(combined_names).replace("(", "\\(").replace(")", "\\)")
+     booru = caption.replace("_", " ")
+
+     return caption, booru, rating_labels, char_labels, gen_labels
+
+
+ @torch.no_grad()
+ def render_heatmap(
+     image: Tensor,
+     gradients: Tensor,
+     image_feats: Tensor,
+     image_probs: Tensor,
+     image_labels: list[str],
+     cmap: LinearSegmentedColormap = cc.m_linear_bmy_10_95_c71,
+     pos_embed_dim: int = 784,
+     image_size: tuple[int, int] = (448, 448),
+     font_args: dict = {
+         "fontFace": cv2.FONT_HERSHEY_SIMPLEX,
+         "fontScale": 1,
+         "color": (255, 255, 255),
+         "thickness": 2,
+         "lineType": cv2.LINE_AA,
+     },
+     partial_rows: bool = True,
+ ) -> tuple[list[Heatmap], Image.Image]:
+     # 784 patch positions -> 28x28 heatmap (448px input, 16px patches)
+     hmap_dim = int(math.sqrt(pos_embed_dim))
+
+     # Grad-CAM style: weight patch features by channel-mean gradients, keep the positive part
+     image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
+     image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), hmap_dim, hmap_dim)
+     image_hmaps = image_hmaps.max(torch.zeros_like(image_hmaps))
+
+     image_hmaps /= image_hmaps.reshape(image_hmaps.shape[0], -1).max(-1)[0].unsqueeze(-1).unsqueeze(-1)
+     # normalize to 0-1
+     image_hmaps = torch.stack([(x - x.min()) / (x.max() - x.min()) for x in image_hmaps]).unsqueeze(1)
+     # interpolate to input image size
+     image_hmaps = F.interpolate(image_hmaps, size=image_size, mode="bilinear").squeeze(1)
+
+     hmap_imgs: list[Heatmap] = []
+     for tag, hmap, score in zip(image_labels, image_hmaps, image_probs.cpu()):
+         # undo the [-1, 1] normalization to recover displayable pixels
+         image_pixels = image.add(1).mul(127.5).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)
+         hmap_pixels = cmap(hmap.cpu().numpy(), bytes=True)[:, :, :3]
+
+         # the input tensor is BGR after RGBtoBGR, so blend in BGR and convert once at the end
+         hmap_cv2 = cv2.cvtColor(hmap_pixels, cv2.COLOR_RGB2BGR)
+         hmap_image = cv2.addWeighted(image_pixels, 0.5, hmap_cv2, 0.5, 0)
+         if tag is not None:
+             cv2.putText(hmap_image, tag, (10, 30), **font_args)
+             cv2.putText(hmap_image, f"{score:.3f}", org=(10, 60), **font_args)
+
+         hmap_pil = Image.fromarray(cv2.cvtColor(hmap_image, cv2.COLOR_BGR2RGB))
+         hmap_imgs.append(Heatmap(tag, score.item(), hmap_pil))
+
+     hmap_imgs = sorted(hmap_imgs, key=lambda x: x.score, reverse=True)
+     hmap_grid = pil_make_grid([x.image for x in hmap_imgs], partial_rows=partial_rows)
+
+     return hmap_imgs, hmap_grid
+
+
+ def process_heatmap(
+     model: VisionTransformer,
+     image: Tensor,
+     labels: LabelData,
+     threshold: float = 0.5,
+     partial_rows: bool = True,
+ ) -> tuple[list[Heatmap], Image.Image, ImageLabels]:
+     torch_device = model_device(model)
+
+     with torch.set_grad_enabled(True):
+         features = model.forward_features(image.to(torch_device))
+         probs = model.forward_head(features)
+         probs = F.sigmoid(probs).squeeze(0)
+
+         probs_mask = probs > threshold
+         heatmap_probs = probs[probs_mask]
+
+         label_indices = torch.nonzero(probs_mask, as_tuple=False).squeeze(1)
+         image_labels = [labels.names[label_indices[i]] for i in range(len(label_indices))]
+
+         # one gradient per kept label, batched via an identity matrix of grad outputs
+         eye = torch.eye(heatmap_probs.shape[0], device=torch_device)
+         grads = torch.autograd.grad(
+             outputs=heatmap_probs,
+             inputs=features,
+             grad_outputs=eye,
+             is_grads_batched=True,
+             retain_graph=True,
+         )
+         grads = grads[0].detach().requires_grad_(False)[:, 0, :, :].unsqueeze(1)
+
+     with torch.set_grad_enabled(False):
+         hmap_imgs, hmap_grid = render_heatmap(
+             image=image,
+             gradients=grads,
+             image_feats=features,
+             image_probs=heatmap_probs,
+             image_labels=image_labels,
+             partial_rows=partial_rows,
+         )
+
+         caption, booru, ratings, character, general = get_tags(
+             probs=probs.cpu(),
+             labels=labels,
+             gen_threshold=threshold,
+             char_threshold=threshold,
+         )
+         labels = ImageLabels(caption, booru, ratings, general, character)
+
+     return hmap_imgs, hmap_grid, labels
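
For completeness, the pipeline can be exercised without the Gradio UI. This sketch mirrors predict() in app.py; the input path and output filename are hypothetical.

from PIL import Image

from tagger.common import load_labels_hf, preprocess_image
from tagger.model import load_model_and_transform, process_heatmap

MODEL_REPO = "SmilingWolf/wd-vit-tagger-v3"

model, transform = load_model_and_transform(MODEL_REPO)
labels = load_labels_hf(MODEL_REPO)

# hypothetical input path; any RGB image works
image = preprocess_image(Image.open("examples/img-01.png"), (448, 448))
batch = transform(image).unsqueeze(0)

heatmaps, grid, image_labels = process_heatmap(model, batch, labels, threshold=0.35)
grid.save("heatmap_grid.png")  # hypothetical output name
print(image_labels.caption)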