"""Batch image captioner: encodes each image with a SigLIP vision tower, projects the
visual features into a Llama 3.1 text model through a trained image adapter, and writes
one .txt caption file per image."""

import os
from pathlib import Path

import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from torch import nn
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

# Model locations and generation defaults.
CLIP_PATH = "google/siglip-so400m-patch14-384"
CHECKPOINT_PATH = Path("./checkpoint")
LLAMA_CHECKPOINT = "John6666/Llama-3.1-8B-Lexi-Uncensored-V2-nf4"
WORDS = 200
PROMPT = "In one paragraph, write a very descriptive caption for this image, describe all objects, characters and their actions, describe in detail what is happening and their emotions. Include information about lighting, the style of this image and information about camera angle within {word_count} words. Don't create any title for the image."
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
HF_TOKEN = os.environ.get("HF_TOKEN", None)

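# CHECKPOINT_PATH is expected to contain the fine-tuned weights loaded further down:
#   checkpoint/
#     clip_model.pt      - state dict for the SigLIP vision tower
#     image_adapter.pt   - state dict for the ImageAdapter projection
#     text_model/        - tokenizer files for the text model
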
class ImageAdapter(nn.Module):
    """Projects SigLIP hidden states into the text model's embedding space and wraps
    the projected sequence with learned begin/end image embeddings."""

    def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
        super().__init__()
        self.deep_extract = deep_extract

        # Deep extraction concatenates five hidden layers, so the input width grows 5x.
        if self.deep_extract:
            input_features = input_features * 5

        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
        self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))

        # Three learned embeddings: indices 0 and 1 wrap the image sequence (see forward);
        # index 2 is an end-of-turn embedding (see get_eot_embedding).
        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs):
        # `vision_outputs` is the tuple of per-layer hidden states from the vision tower.
        if self.deep_extract:
            x = torch.concat((
                vision_outputs[-2],
                vision_outputs[3],
                vision_outputs[7],
                vision_outputs[13],
                vision_outputs[20],
            ), dim=-1)
            assert len(x.shape) == 3, f"Expected 3 dims, got {len(x.shape)}"
            assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
        else:
            x = vision_outputs[-2]

        x = self.ln1(x)

        if self.pos_emb is not None:
            assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
            x = x + self.pos_emb

        # Two-layer MLP projection into the text embedding space.
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)

        # Wrap the image sequence with the learned start/end embeddings.
        other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
        assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)

        return x

    def get_eot_embedding(self):
        return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)


@torch.no_grad()  # inference only; no autograd graphs are needed
def proc_img(input_image):
    # Preprocess the image the way SigLIP expects: 384x384, scaled to [-1, 1].
    image = input_image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    pixel_values = pixel_values.to(device)

    # Encode the image and project it into the text model's embedding space.
    with torch.amp.autocast(device, enabled=True):
        vision_outputs = model(pixel_values=pixel_values, output_hidden_states=True)
        embedded_images = image_adapter(vision_outputs.hidden_states)
        embedded_images = embedded_images.to(device)

    # Build the chat prompt.
    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner.",
        },
        {
            "role": "user",
            "content": prompt_str,
        },
    ]

    convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    assert isinstance(convo_string, str)

    # Tokenize the full conversation and the user prompt separately so the prompt's
    # position inside the conversation can be located.
    convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
    prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
    assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
    convo_tokens = convo_tokens.squeeze(0)
    prompt_tokens = prompt_tokens.squeeze(0)

    # The chat template produces exactly two <|eot_id|> tokens: one after the system
    # message and one after the user message.
    eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
    assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"

    # Number of tokens that precede the user prompt; the image embeddings are spliced in here.
    preamble_len = eot_id_indices[1] - prompt_tokens.shape[0]

    convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))

    # Insert the image embeddings between the preamble and the user prompt.
    input_embeds = torch.cat([
        convo_embeds[:, :preamble_len],
        embedded_images.to(dtype=convo_embeds.dtype),
        convo_embeds[:, preamble_len:],
    ], dim=1).to(device)

    # Matching token ids, with zeros as placeholders at the image positions.
    input_ids = torch.cat([
        convo_tokens[:preamble_len].unsqueeze(0),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
        convo_tokens[preamble_len:].unsqueeze(0),
    ], dim=1).to(device)
    attention_mask = torch.ones_like(input_ids)

    generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)

    # Keep only the newly generated tokens and drop a trailing end-of-turn token, if any.
    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
        generate_ids = generate_ids[:, :-1]

    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]

    return caption.strip('"')


def describe_image(image_path, output_dir=None):
    if not os.path.exists(image_path):
        print(f"File not found: {image_path}")
        return

    image = Image.open(image_path).convert("RGB")

    description = proc_img(image)

    # Write the caption next to the image, or into output_dir when one is given,
    # so it matches the "already captioned" check done in __main__.
    if output_dir is None:
        output_path = os.path.splitext(image_path)[0] + ".txt"
    else:
        stem = os.path.splitext(os.path.basename(image_path))[0]
        output_path = os.path.join(str(output_dir), stem + ".txt")

    with open(output_path, "w", encoding="utf-8") as f:
        f.write(description)

    print(f"Description saved to: {output_path}")


if __name__ == "__main__": |
|
import argparse |
|
|
|
parser = argparse.ArgumentParser(description="Caption all PNG image files in a folder") |
|
parser.add_argument("folder_path", type=str, help="Folder containing images.") |
|
parser.add_argument("--prompt", type=str, help="Prompt to ask a caption.", default=None, required=False) |
|
parser.add_argument("--output_dir", type=str, help="Output dir.", default=None, required=False) |
|
args = parser.parse_args() |
|
|
|
|
|
if args.prompt is None: |
|
prompt_str = PROMPT.format(word_count=WORDS) |
|
else: |
|
prompt_str = args.prompt |
|
|
|
|
|
    folder_path = Path(args.folder_path)
    if not folder_path.is_dir():
        print(f"Error: {folder_path} is not a valid directory.")
        exit(1)

    if args.output_dir is None:
        output_dir = folder_path
    else:
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Collect supported images that do not already have a caption file.
    img_files = [f for f in folder_path.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS]
    img_files = [f for f in img_files if not Path(output_dir, f"{f.stem}.txt").exists()]

    if not img_files:
        print(f"No uncaptioned image files found in the directory: {folder_path}")
        exit(1)

    total = len(img_files)
    print(f"Found {total} image files without captions. Processing...")

device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
print("Loading CLIP") |
|
processor = AutoProcessor.from_pretrained(CLIP_PATH) |
|
model = AutoModel.from_pretrained(CLIP_PATH).to(device) |
|
model = model.vision_model |
|
|
|
assert (CHECKPOINT_PATH / "clip_model.pt").exists() |
|
print("Loading VLM's custom vision model") |
|
checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu',weights_only=True) |
|
checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()} |
|
model.load_state_dict(checkpoint) |
|
del checkpoint |
|
|
|
|
|
print("Loading tokenizer") |
|
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH / "text_model", use_fast=True) |
|
assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}" |
|
|
|
|
|
print("Loading VLM's custom text model") |
|
text_model = AutoModelForCausalLM.from_pretrained(LLMA_CHECKPOINT , device_map=0, trust_remote_code=True,torch_dtype=torch.bfloat16) |
|
text_model.eval() |
|
|
|
|
|
print("Loading image adapter") |
|
image_adapter = ImageAdapter(model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False) |
|
image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu",weights_only=True)) |
|
image_adapter.eval() |
|
image_adapter.to(device) |
|
|
|
curr = 1 |
|
for image_path in img_files: |
|
print(f"Processing image {curr} of {total}: {image_path}") |
|
curr += 1 |
|
describe_image(str(image_path)) |
|
|
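
# Example invocation (the file name caption_folder.py is only illustrative; use whatever
# name this script is saved under, with ./checkpoint holding the weights listed above):
#
#   python caption_folder.py ./images --output_dir ./captions
#   python caption_folder.py ./images --prompt "Describe this image in one sentence."
#
# Captions are written as <image name>.txt, and images that already have a caption file
# in the output directory are skipped.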