File size: 9,750 Bytes
6c7b048
 
 
 
 
 
 
 
 
 
 
 
 
d400378
6c7b048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d400378
6c7b048
 
d400378
 
 
 
 
 
 
6c7b048
 
 
 
 
 
d400378
 
6c7b048
d400378
 
 
 
 
 
 
 
6c7b048
d400378
 
6c7b048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d400378
6c7b048
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import os
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
from pathlib import Path
from torch import nn
import torchvision.transforms.functional as TVF

# Hugging Face model id used by the script section to load the vision
# processor and vision model (via AutoProcessor / AutoModel).
CLIP_PATH = "google/siglip-so400m-patch14-384"
# Local directory the script reads "clip_model.pt", "image_adapter.pt" and
# the "text_model" tokenizer folder from.
CHECKPOINT_PATH = Path("./checkpoint")
# Hugging Face model id of the causal language model that writes the caption.
LLMA_CHECKPOINT = "John6666/Llama-3.1-8B-Lexi-Uncensored-V2-nf4"
# Word count substituted into PROMPT's {word_count} placeholder when no
# --prompt argument is given.
WORDS=200
# Default captioning prompt template; formatted with word_count=WORDS.
PROMPT = "In one paragraph, write a very descriptive caption for this image, describe all objects, characters and their actions, describe in detail what is happening and their emotions. Include information about lighting, the style of this image and information about camera angle within {word_count} words. Don't create any title for the image."
# Lower-case file suffixes treated as images when scanning the input folder.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
# Optional Hugging Face access token taken from the environment (None if unset).
# NOTE(review): not referenced anywhere else in the visible code.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

class ImageAdapter(nn.Module):
	"""Map vision hidden states to a sequence of text-model-sized embeddings.

	The adapter optionally layer-normalises the input, optionally adds a
	learned positional embedding, runs a Linear -> GELU -> Linear projection
	and finally wraps the projected tokens between two learned marker
	embeddings (image start / image end). A third learned embedding
	(end-of-turn) is exposed through ``get_eot_embedding``.

	Args:
		input_features: Feature size of a single vision hidden state.
		output_features: Feature size of the produced embeddings.
		ln1: Apply ``nn.LayerNorm`` to the input when True, otherwise identity.
		pos_emb: Add a learned ``(num_image_tokens, features)`` offset when True.
		num_image_tokens: Token count used to size the positional embedding.
		deep_extract: Concatenate five hidden states along the feature axis
			instead of using only the second-to-last one.
	"""

	def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
		super().__init__()
		self.deep_extract = deep_extract

		# Deep extraction stacks five layers feature-wise, so the input is 5x wider.
		in_dim = input_features * 5 if self.deep_extract else input_features

		# NOTE: sub-modules are created in a fixed order so that checkpoint keys
		# and random initialisation stay the same.
		self.linear1 = nn.Linear(in_dim, output_features)
		self.activation = nn.GELU()
		self.linear2 = nn.Linear(output_features, output_features)

		if ln1:
			self.ln1 = nn.LayerNorm(in_dim)
		else:
			self.ln1 = nn.Identity()

		if pos_emb:
			self.pos_emb = nn.Parameter(torch.zeros(num_image_tokens, in_dim))
		else:
			self.pos_emb = None

		# Learned embeddings for <|image_start|>, <|image_end|> and <|eot_id|>
		self.other_tokens = nn.Embedding(3, output_features)
		marker_weights = self.other_tokens.weight.data
		marker_weights.normal_(mean=0.0, std=0.02)	 # Matches HF's implementation of llama3

	def forward(self, vision_outputs: torch.Tensor):
		"""Project vision hidden states and surround them with marker tokens.

		Args:
			vision_outputs: Indexable collection of hidden states; index -2 is
				always read, and indices 3, 7, 13 and 20 as well when
				``deep_extract`` is enabled.

		Returns:
			Tensor with two more tokens along dim 1 than the input: start
			marker, projected image tokens, end marker.
		"""
		penultimate = vision_outputs[-2]

		if not self.deep_extract:
			features = penultimate
		else:
			layers = [penultimate] + [vision_outputs[idx] for idx in (3, 7, 13, 20)]
			features = torch.cat(layers, dim=-1)
			assert len(features.shape) == 3, f"Expected 3, got {len(features.shape)}"	# batch, tokens, features
			assert features.shape[-1] == penultimate.shape[-1] * 5, f"Expected {penultimate.shape[-1] * 5}, got {features.shape[-1]}"

		features = self.ln1(features)

		if self.pos_emb is not None:
			assert features.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {features.shape[-2:]}"
			features = features + self.pos_emb

		projected = self.linear2(self.activation(self.linear1(features)))

		# <|image_start|>, IMAGE, <|image_end|>
		batch_size = projected.shape[0]
		marker_ids = torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(batch_size, -1)
		markers = self.other_tokens(marker_ids)
		assert markers.shape == (batch_size, 2, projected.shape[2]), f"Expected {(batch_size, 2, projected.shape[2])}, got {markers.shape}"

		start_marker = markers[:, 0:1]
		end_marker = markers[:, 1:2]
		return torch.cat((start_marker, projected, end_marker), dim=1)

	def get_eot_embedding(self):
		"""Return the learned end-of-turn embedding as a 1-D tensor."""
		eot_index = torch.tensor([2], device=self.other_tokens.weight.device)
		return self.other_tokens(eot_index).squeeze(0)
			

@torch.no_grad()
def proc_img(input_image):
	"""Generate a caption for a single PIL image.

	This function reads several module-level names that the ``__main__``
	section of this file binds before calling it: ``device``, ``model``
	(vision model), ``image_adapter``, ``tokenizer``, ``text_model`` and
	``prompt_str``.

	The whole function runs under ``torch.no_grad()``: it is inference only,
	so recording an autograd graph for the vision model, the adapter and the
	token embedding would only cost memory.

	Args:
		input_image: PIL image to caption. It is resized to 384x384 here.

	Returns:
		The generated caption with leading/trailing double quotes removed.
	"""
	# Preprocess image
	# NOTE: I found the default processor for so400M to have worse results than just using PIL directly
	image = input_image.resize((384, 384), Image.LANCZOS)
	pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
	pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
	pixel_values = pixel_values.to(device)

	# Embed image
	# This results in Batch x Image Tokens x Features
	with torch.amp.autocast_mode.autocast(device, enabled=True):
		vision_outputs = model(pixel_values=pixel_values, output_hidden_states=True)
		embedded_images = image_adapter(vision_outputs.hidden_states)
		embedded_images = embedded_images.to(device)

	# Build the conversation
	convo = [
		{
			"role": "system",
			"content": "You are a helpful image captioner.",
		},
		{
			"role": "user",
			"content": prompt_str,
		},
	]

	# Format the conversation
	convo_string = tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
	assert isinstance(convo_string, str)

	# Tokenize the conversation
	# prompt_str is tokenized separately so we can do the calculations below
	convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
	prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
	assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
	convo_tokens = convo_tokens.squeeze(0)	 # Squeeze just to make the following easier
	prompt_tokens = prompt_tokens.squeeze(0)

	# Calculate where to inject the image: the second <|eot_id|> closes the
	# user turn, so the prompt is assumed to occupy the tokens right before it.
	eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
	eot_id_indices = (convo_tokens == eot_id).nonzero(as_tuple=True)[0].tolist()
	assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"

	preamble_len = eot_id_indices[1] - prompt_tokens.shape[0]	 # Number of tokens before the prompt

	# Embed the tokens
	convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))

	# Construct the input
	input_embeds = torch.cat([
		convo_embeds[:, :preamble_len],	 # Part before the prompt
		embedded_images.to(dtype=convo_embeds.dtype),	 # Image
		convo_embeds[:, preamble_len:],	 # The prompt and anything after it
	], dim=1).to(device)

	input_ids = torch.cat([
		convo_tokens[:preamble_len].unsqueeze(0),
		torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),	 # Dummy tokens for the image (TODO: Should probably use a special token here so as not to confuse any generation algorithms that might be inspecting the input)
		convo_tokens[preamble_len:].unsqueeze(0),
	], dim=1).to(device)
	attention_mask = torch.ones_like(input_ids)

	# Sampling parameters are left at the model's generation defaults
	# (original author's note: temp=0.6, top_p=0.9).
	generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)

	# Trim off the prompt
	generate_ids = generate_ids[:, input_ids.shape[1]:]

	# Drop a trailing end-of-sequence / end-of-turn token. The length check
	# avoids an IndexError if nothing was generated.
	if generate_ids.shape[1] > 0:
		last_token = generate_ids[0][-1].item()
		if last_token == tokenizer.eos_token_id or last_token == eot_id:
			generate_ids = generate_ids[:, :-1]

	caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]

	return caption.strip('\"')

def describe_image(image_path, output_dir=None):
	"""Caption one image file and write the caption to a ``.txt`` file.

	Args:
		image_path: Path of the image to caption.
		output_dir: Directory for the caption file. When None (the default)
			the caption is written next to the image, i.e. the image path with
			its extension replaced by ``.txt``.

	Returns:
		None. Prints a message and returns early if ``image_path`` is missing.
	"""
	if not os.path.exists(image_path):
		print(f"File not found: {image_path}")
		return

	# Use a context manager so the underlying file handle is released promptly.
	with Image.open(image_path) as opened:
		image = opened.convert("RGB")

	description = proc_img(image)

	# Output filename
	if output_dir is None:
		output_path = os.path.splitext(image_path)[0] + ".txt"
	else:
		# Same "<stem>.txt" naming the script uses to detect existing captions.
		output_path = os.path.join(output_dir, Path(image_path).stem + ".txt")

	# Save caption file
	with open(output_path, "w", encoding="utf-8") as f:
		f.write(description)

	print(f"Description save in: {output_path}")

if __name__ == "__main__":
	import argparse
	import sys

	parser = argparse.ArgumentParser(description="Caption all supported image files in a folder")
	parser.add_argument("folder_path", type=str, help="Folder containing images.")
	parser.add_argument("--prompt", type=str, help="Prompt to ask a caption.", default=None, required=False)
	parser.add_argument("--output_dir", type=str, help="Output dir.", default=None, required=False)
	args = parser.parse_args()

	# Prompt (read as a module-level name by proc_img)
	if args.prompt is None:
		prompt_str = PROMPT.format(word_count=WORDS)
	else:
		prompt_str = args.prompt

	# Process all images in the folder
	folder_path = Path(args.folder_path)
	if not folder_path.is_dir():
		print(f"Error: {folder_path} is not a valid directory.")
		sys.exit(1)

	# Output directory: defaults to the image folder itself
	if args.output_dir is None:
		output_dir = folder_path
	else:
		output_dir = Path(args.output_dir)

	# Keep only images that do not yet have a caption in the output directory.
	# Sorted so the processing order is deterministic.
	img_files = sorted(
		f for f in folder_path.iterdir()
		if f.suffix.lower() in IMAGE_EXTENSIONS
		and not Path(output_dir, f"{f.stem}.txt").exists()
	)

	if not img_files:
		print(f"No image files without caption found in the directory: {folder_path}")
		sys.exit(1)

	total = len(img_files)
	print(f"Found {total} IMAGE files without caption. Processing...")

	# Make sure captions can actually be written where they are looked up.
	output_dir.mkdir(parents=True, exist_ok=True)

	device = "cuda" if torch.cuda.is_available() else "cpu"

	# Load CLIP
	print("Loading CLIP")
	processor = AutoProcessor.from_pretrained(CLIP_PATH)
	model = AutoModel.from_pretrained(CLIP_PATH).to(device)
	model = model.vision_model

	assert (CHECKPOINT_PATH / "clip_model.pt").exists()
	print("Loading VLM's custom vision model")
	checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu',weights_only=True)
	# Strip the key prefix present in the saved checkpoint so the keys match the vision model.
	checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
	model.load_state_dict(checkpoint)
	del checkpoint

	# Tokenizer
	print("Loading tokenizer")
	tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH / "text_model", use_fast=True)
	assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"

	# LLM
	# NOTE(review): device_map=0 pins the text model to GPU 0 even when
	# `device` above falls back to "cpu" - confirm CPU-only runs are not expected.
	print("Loading VLM's custom text model")
	text_model = AutoModelForCausalLM.from_pretrained(LLMA_CHECKPOINT , device_map=0, trust_remote_code=True,torch_dtype=torch.bfloat16)
	text_model.eval()

	# Image Adapter
	print("Loading image adapter")
	image_adapter = ImageAdapter(model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False)
	image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu",weights_only=True))
	image_adapter.eval()
	image_adapter.to(device)

	for curr, image_path in enumerate(img_files, start=1):
		print(f"Processing image {curr} of {total}: {image_path}")
		describe_image(str(image_path), output_dir)