batched inference

#1
by xeph - opened
Files changed (1) hide show
  1. batched_inference.py +179 -0
batched_inference.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.multiprocessing as multiprocessing
2
+ import torchvision.transforms as transforms
3
+ from torch import autocast
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from PIL import Image
6
+ import torch
7
+ from torchvision.transforms import InterpolationMode
8
+ from tqdm import tqdm
9
+ import json
10
+ import os
11
+
12
# Process-wide PyTorch settings, applied once when this module is imported.

# Permit TF32 math for cuDNN convolutions and CUDA matmuls.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# Enable cuDNN's benchmark mode; every batch in this script is 448x448.
torch.backends.cudnn.benchmark = True
# Autograd anomaly detection is explicitly switched off.
torch.autograd.set_detect_anomaly(False)
# NOTE(review): the two calls below only construct profiler context managers
# that are never entered, so they presumably have no effect -- confirm before
# relying on them to disable anything.
torch.autograd.profiler.profile(enabled=False)
torch.autograd.profiler.emit_nvtx(enabled=False)
18
+
19
+
20
class ImageDataset(Dataset):
    """Dataset over every image found under a folder tree.

    Each item is a dict with:
        'image':      the image converted to RGB, resized to 448x448, turned
                      into a tensor and normalized with fixed per-channel
                      mean/std values.
        'image_name': the file name without its final extension.
        'image_root': the directory that contains the file.
    """

    def __init__(self, image_folder_path, allowed_extensions):
        """Scan *image_folder_path* recursively and build the transforms.

        Args:
            image_folder_path: root folder; sub-folders are included.
            allowed_extensions: collection of lower-case extensions without
                the leading dot (e.g. ``{"jpg", "png"}``).
        """
        self.allowed_extensions = allowed_extensions
        (
            self.all_image_paths,
            self.all_image_names,
            self.image_base_paths,
        ) = self.get_image_paths(image_folder_path)
        self.train_size = len(self.all_image_paths)
        print(f"Number of images to be tagged: {self.train_size}")
        # Used for very tall or very wide images: resize with a single target
        # size, then centre-crop to a 448x448 square.
        self.thin_transform = transforms.Compose([
            transforms.Resize(448, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(448),
            transforms.ToTensor(),
            # Normalize image
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Used for everything else: resize straight to 448x448.
        self.normal_transform = transforms.Compose([
            transforms.Resize((448, 448), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            # Normalize image
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def get_image_paths(self, folder_path):
        """Walk *folder_path* and collect every file with an allowed extension.

        Returns:
            Three parallel lists: absolute file paths, file names without the
            final extension, and the directory each file was found in.
        """
        image_paths = []
        image_file_names = []
        image_base_paths = []
        for root, _dirs, files in os.walk(folder_path):
            for file in files:
                # splitext only strips the last extension, so a name such as
                # "my.image.png" keeps its stem "my.image" and the tag file
                # written next to it gets the matching name.
                stem, extension = os.path.splitext(file)
                if extension[1:].lower() in self.allowed_extensions:
                    image_paths.append(os.path.abspath(os.path.join(root, file)))
                    image_file_names.append(stem)
                    image_base_paths.append(root)
        return image_paths, image_file_names, image_base_paths

    def __len__(self):
        """Return the number of images that were found."""
        return len(self.all_image_paths)

    def __getitem__(self, index):
        """Load, convert and transform the image at *index*."""
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the RGB image.
        with Image.open(self.all_image_paths[index]) as source:
            image = source.convert("RGB")
        ratio = image.height / image.width
        if ratio > 2.0 or ratio < 0.5:
            image = self.thin_transform(image)
        else:
            image = self.normal_transform(image)

        return {
            'image': image,
            "image_name": self.all_image_names[index],
            "image_root": self.image_base_paths[index],
        }
69
+
70
+
71
def prepare_model(model_path: str):
    """Load a fully pickled model from *model_path* and ready it for inference.

    The loaded module is converted to the channels-last memory format and
    switched to evaluation mode before being returned.

    SECURITY: ``torch.load`` with ``weights_only=False`` unpickles arbitrary
    Python objects. Only load model files that come from a trusted source.
    """
    # The checkpoint is a whole module object, not a state dict, so it cannot
    # be read with the restricted ``weights_only=True`` loader that newer
    # PyTorch releases use by default; request full unpickling explicitly.
    model = torch.load(model_path, weights_only=False)
    model.to(memory_format=torch.channels_last)
    return model.eval()
76
+
77
+
78
def train(tagging_is_running, model, dataloader, train_data, output_queue):
    """Run the model over every batch and queue the results for the writer.

    Despite its name this function only performs inference: the model is put
    in eval mode and everything runs under ``torch.no_grad()``.

    Args:
        tagging_is_running: queue holding one marker item while tagging is in
            progress; the item is removed when this function ends.
        model: the module to call on each image batch.
        dataloader: yields dicts with 'image', 'image_name' and 'image_root'.
        train_data: the dataset behind *dataloader*; kept for interface
            compatibility, the progress total now comes from the loader.
        output_queue: receives one ``(probabilities_on_cpu, image_names,
            image_roots)`` tuple per batch.
    """
    print('Begin tagging')
    model.eval()

    try:
        with torch.no_grad():
            # len(dataloader) counts the final partial batch as well, so the
            # progress bar total matches the number of iterations.
            for data in tqdm(dataloader, total=len(dataloader)):
                this_data = data['image'].to("cuda")
                with autocast(device_type='cuda', dtype=torch.bfloat16):
                    outputs = model(this_data)

                probabilities = torch.sigmoid(outputs)
                output_queue.put((probabilities.to("cpu"), data["image_name"], data["image_root"]))
    finally:
        # Always clear the "running" marker, even when inference fails, so the
        # consumer of output_queue can tell that no more batches will arrive.
        _ = tagging_is_running.get()
    print("Tagging finished!")
95
+
96
+
97
+ def tag_writer(tagging_is_running, output_queue, threshold):
98
+ with open("tags_8034.json", "r") as f:
99
+ tags = json.load(f)
100
+ tags.append("placeholder0")
101
+ tags = sorted(tags)
102
+ tag_count = len(tags)
103
+ assert tag_count == 8035, f"The length of tag list is not correct. Correct: 8035, current: {tag_count}"
104
+
105
+ while not (tagging_is_running.qsize() > 0 and output_queue.qsize() > 0):
106
+ tag_probabilities, image_names, image_roots = output_queue.get()
107
+ tag_probabilities = tag_probabilities.tolist()
108
+
109
+ for per_image_tag_probabilities, image_name, image_root in zip(tag_probabilities, image_names, image_roots,
110
+ strict=True):
111
+ this_image_tags = []
112
+ this_image_tag_probabilities = []
113
+ for index, per_tag_probability in enumerate(per_image_tag_probabilities):
114
+ if per_tag_probability > threshold:
115
+ tag = allowed_tags[index]
116
+ if "placeholder" not in tag:
117
+ this_image_tags.append(tag)
118
+ this_image_tag_probabilities.append(str(int(round(per_tag_probability, 3) * 1000)))
119
+ output_file = os.path.join(image_root, os.path.splitext(image_name)[0] + ".txt")
120
+ with open(output_file, "w", encoding="utf-8") as this_output:
121
+ # set this to true if you want tags separated with commas instead of spaces (will output "tag0, tag1...")
122
+ use_comma_sep = True
123
+ sep = " "
124
+ if use_comma_sep:
125
+ sep = ", "
126
+ # set this to true if you want to replace underscores with spaces
127
+ remove_underscores = True
128
+ if remove_underscores:
129
+ this_image_tags = map(lambda e: e.replace('_', ' '), this_image_tags)
130
+ this_output.write(sep.join(this_image_tags))
131
+ # change output_probabilities to True if you want probabilities
132
+ output_probabilities = False
133
+ if output_probabilities:
134
+ this_output.write("\n")
135
+ this_output.write(sep.join(this_image_tag_probabilities))
136
+
137
+
138
def main():
    """Tag every image under a folder using a producer and a writer process.

    The settings are hard-coded below. One process runs the model on the GPU
    and pushes probabilities onto a queue; a second process pops them and
    writes the tag files. Raises RuntimeError when CUDA is unavailable.
    """
    # all images should be in this folder and/or its subfolders.
    # I will generate a text file for every image.
    image_folder_path = "/path/to/img/folder"
    model_path = "/path/to/your/model.pth"
    allowed_extensions = {"jpg", "jpeg", "png", "webp"}
    # if you have a 24GB card, you can try 256
    batch_size = 64
    threshold = 0.3

    multiprocessing.set_start_method('spawn')
    output_queue = multiprocessing.Queue()
    # Holds a single marker item for as long as tagging is in progress.
    tagging_is_running = multiprocessing.Queue(maxsize=5)
    tagging_is_running.put("Running!")

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available!")

    model = prepare_model(model_path).to("cuda")
    dataset = ImageDataset(image_folder_path, allowed_extensions)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=12,  # if you have a big batch size, a good cpu, and enough cpu memory, try 12
        pin_memory=True,
        drop_last=False,
    )

    # The writer is started first, then the tagger that feeds it.
    writer = multiprocessing.Process(
        target=tag_writer,
        args=(tagging_is_running, output_queue, threshold),
    )
    writer.start()
    tagger = multiprocessing.Process(
        target=train,
        args=(tagging_is_running, model, loader, dataset, output_queue),
    )
    tagger.start()

    for process in (writer, tagger):
        process.join()
176
+
177
+
178
# Run only when executed as a script. main() selects the 'spawn' start method,
# under which child processes import this module again, so the guard keeps
# them from launching the pipeline a second time.
if __name__ == "__main__":
    main()