VNAT committed on
Commit
daba3f8
1 Parent(s): a142643

Add batched_inference.py with some goodies

Files changed (1)
  1. batched_inference.py +180 -0
batched_inference.py ADDED
@@ -0,0 +1,180 @@
import torch.multiprocessing as multiprocessing
import torchvision.transforms as transforms
from torch import autocast
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
from torchvision.transforms import InterpolationMode
from tqdm import tqdm
import json
import os
import queue  # for queue.Empty when polling the output queue in tag_writer

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.autograd.set_detect_anomaly(False)
torch.autograd.profiler.emit_nvtx(enabled=False)
torch.autograd.profiler.profile(enabled=False)
torch.backends.cudnn.benchmark = True


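# ImageDataset walks image_folder_path recursively, keeps files with the allowed
# extensions, and normalizes everything to 448x448. Very tall or very wide images
# (height/width above 2.0 or below 0.5) are resized on the short side and
# center-cropped rather than stretched.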
class ImageDataset(Dataset):
    def __init__(self, image_folder_path, allowed_extensions):
        self.allowed_extensions = allowed_extensions
        self.all_image_paths, self.all_image_names, self.image_base_paths = self.get_image_paths(image_folder_path)
        self.train_size = len(self.all_image_paths)
        print(f"Number of images to be tagged: {self.train_size}")
        self.thin_transform = transforms.Compose([
            transforms.Resize(448, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(448),
            transforms.ToTensor(),
            # Normalize image
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.normal_transform = transforms.Compose([
            transforms.Resize((448, 448), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            # Normalize image
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def get_image_paths(self, folder_path):
        image_paths = []
        image_file_names = []
        image_base_paths = []
        for root, dirs, files in os.walk(folder_path):
            for file in files:
                if file.lower().split(".")[-1] in self.allowed_extensions:
                    image_paths.append(os.path.abspath(os.path.join(root, file)))
                    image_file_names.append(file.split(".")[0])
                    image_base_paths.append(root)
        return image_paths, image_file_names, image_base_paths

    def __len__(self):
        return len(self.all_image_paths)

    def __getitem__(self, index):
        image = Image.open(self.all_image_paths[index]).convert("RGB")
        ratio = image.height / image.width
        if ratio > 2.0 or ratio < 0.5:
            image = self.thin_transform(image)
        else:
            image = self.normal_transform(image)

        return {
            'image': image,
            "image_name": self.all_image_names[index],
            "image_root": self.image_base_paths[index]
        }


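# prepare_model() assumes the checkpoint holds a full pickled nn.Module saved with
# torch.save(model), not a state_dict; it loads it as-is, switches it to
# channels_last memory format, and puts it into eval mode.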
def prepare_model(model_path: str):
    model = torch.load(model_path)
    model.to(memory_format=torch.channels_last)
    model = model.eval()
    return model


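# Despite the name, train() only runs inference. It is the producer half of the
# pipeline: each batch goes through the model under bfloat16 autocast, and the
# sigmoid probabilities are moved to the CPU and pushed into output_queue together
# with the image names and folders.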
def train(tagging_is_running, model, dataloader, train_data, output_queue):
    print('Begin tagging')
    model.eval()
    counter = 0

    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=int(len(train_data) / dataloader.batch_size)):
            this_data = data['image'].to("cuda")
            with autocast(device_type='cuda', dtype=torch.bfloat16):
                outputs = model(this_data)

            probabilities = torch.nn.functional.sigmoid(outputs)
            output_queue.put((probabilities.to("cpu"), data["image_name"], data["image_root"]))

            counter += 1
    _ = tagging_is_running.get()
    print("Tagging finished!")


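# tag_writer() is the consumer half. It loads the tag vocabulary from tags_8034.json
# (expected in the working directory), keeps every tag whose probability clears the
# threshold, and writes one .txt file per image next to the original. It keeps
# draining output_queue until the tagger has cleared the tagging_is_running flag and
# the queue is empty.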
def tag_writer(tagging_is_running, output_queue, threshold):
    with open("tags_8034.json", "r") as file:
        tags = json.load(file)
    allowed_tags = sorted(tags)
    del tags
    allowed_tags.extend(["placeholder0"])
    tag_count = len(allowed_tags)
    assert tag_count == 8035, f"The length of tag list is not correct. Correct: 8035, current: {tag_count}"

    # Keep consuming while the tagger is still running or results are still queued.
    while tagging_is_running.qsize() > 0 or output_queue.qsize() > 0:
        try:
            tag_probabilities, image_names, image_roots = output_queue.get(timeout=1)
        except queue.Empty:
            # Nothing queued yet; re-check the flags instead of blocking forever.
            continue
        tag_probabilities = tag_probabilities.tolist()

        for per_image_tag_probabilities, image_name, image_root in zip(tag_probabilities, image_names, image_roots,
                                                                       strict=True):
            this_image_tags = []
            this_image_tag_probabilities = []
            for index, per_tag_probability in enumerate(per_image_tag_probabilities):
                if per_tag_probability > threshold:
                    tag = allowed_tags[index]
                    if "placeholder" not in tag:
                        this_image_tags.append(tag)
                        this_image_tag_probabilities.append(str(round(per_tag_probability * 1000)))
            output_file = os.path.join(image_root, os.path.splitext(image_name)[0] + ".txt")
            with open(output_file, "w", encoding="utf-8") as this_output:
                # set this to True if you want tags separated with commas instead of spaces (will output "tag0, tag1...")
                use_comma_sep = True
                sep = " "
                if use_comma_sep:
                    sep = ", "
                # set this to True if you want to replace underscores with spaces
                remove_underscores = True
                if remove_underscores:
                    this_image_tags = map(lambda e: e.replace('_', ' '), this_image_tags)
                this_output.write(sep.join(this_image_tags))
                # change output_probabilities to True if you want probabilities
                output_probabilities = False
                if output_probabilities:
                    this_output.write("\n")
                    this_output.write(sep.join(this_image_tag_probabilities))


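# main() is the part to edit: point image_folder_path and model_path at your data,
# adjust batch_size and threshold if needed, then it spawns the writer and tagger
# processes and waits for both to finish.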
def main():
    image_folder_path = "/path/to/img/folder"
    # all images should be in this folder and/or its subfolders.
    # I will generate a text file for every image.
    model_path = "/path/to/your/model.pth"
    allowed_extensions = {"jpg", "jpeg", "png", "webp"}
    batch_size = 64
    # if you have a 24GB card, you can try 256
    threshold = 0.3

    multiprocessing.set_start_method('spawn')
    output_queue = multiprocessing.Queue()
    tagging_is_running = multiprocessing.Queue(maxsize=5)
    tagging_is_running.put("Running!")

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available!")

    model = prepare_model(model_path).to("cuda")

    dataset = ImageDataset(image_folder_path, allowed_extensions)

    batched_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=12,  # if you have a big batch size, a good cpu, and enough cpu memory, try 12
        pin_memory=True,
        drop_last=False,
    )
    process_writer = multiprocessing.Process(target=tag_writer,
                                             args=(tagging_is_running, output_queue, threshold))
    process_writer.start()
    process_tagger = multiprocessing.Process(target=train,
                                             args=(tagging_is_running, model, batched_loader, dataset, output_queue,))
    process_tagger.start()
    process_writer.join()
    process_tagger.join()


if __name__ == "__main__":
    main()
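Usage note: after editing image_folder_path and model_path in main(), the script is run directly (python batched_inference.py). As written it expects a CUDA GPU, tags_8034.json in the working directory, a checkpoint whose model outputs 8035 logits (the 8034 tags plus one placeholder, per the assert in tag_writer), and Python 3.10+ for zip(..., strict=True). One .txt tag file is written next to each image.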