"""Simplified JoyCaption - generates a caption for a single image input."""
|
import argparse
import logging
import sys
from pathlib import Path

import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from torch import nn
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
|
CLIP_PATH = "google/siglip-so400m-patch14-384"
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
CHECKPOINT_PATH = Path(__file__).resolve().parent / "cgrkzexw-599808"
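# CHECKPOINT_PATH is expected to contain the JoyCaption checkpoint files used in
# load_models(): clip_model.pt (optional), text_model/ (optional), and image_adapter.pt.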

# The extension check in main() is case-insensitive, so lowercase entries suffice.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
|
# Optional Pillow plugins: importing them registers extra image formats with PIL.
try:
    import pillow_avif  # noqa: F401
    IMAGE_EXTENSIONS.append(".avif")
except ImportError:
    pass

try:
    from jxlpy import JXLImagePlugin  # noqa: F401
    IMAGE_EXTENSIONS.append(".jxl")
except ImportError:
    pass

try:
    import pillow_jxl  # noqa: F401
    if ".jxl" not in IMAGE_EXTENSIONS:  # jxlpy may already have added it
        IMAGE_EXTENSIONS.append(".jxl")
except ImportError:
    pass
|
|
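# ImageAdapter bridges the SigLIP vision tower and the Llama text model: it
# projects each vision token into the LLM's embedding space with a small MLP
# and brackets the result with two learned marker embeddings.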
class ImageAdapter(nn.Module):
    def __init__(
        self,
        input_features: int,
        output_features: int,
        ln1: bool,
        pos_emb: bool,
        num_image_tokens: int,
        deep_extract: bool,
    ):
        super().__init__()
        self.deep_extract = deep_extract
        if self.deep_extract:
            # Deep extraction concatenates five hidden layers, so the input is 5x wider.
            input_features = input_features * 5

        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
        self.pos_emb = (
            None
            if not pos_emb
            else nn.Parameter(torch.zeros(num_image_tokens, input_features))
        )
        # Learned marker embeddings; indices 0 and 1 bracket the image tokens in forward().
        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)
|
    def forward(self, vision_outputs):
        if self.deep_extract:
            # Concatenate features from several depths of the vision tower.
            x = torch.concat(
                (
                    vision_outputs[-2],
                    vision_outputs[3],
                    vision_outputs[7],
                    vision_outputs[13],
                    vision_outputs[20],
                ),
                dim=-1,
            )
        else:
            # Default: the second-to-last hidden state of the vision tower.
            x = vision_outputs[-2]

        x = self.ln1(x)
        if self.pos_emb is not None:
            x = x + self.pos_emb

        # Two-layer MLP projection into the text model's hidden size.
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)

        # Surround the projected image tokens with the learned start/end markers.
        other_tokens = self.other_tokens(
            torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(
                x.shape[0], -1
            )
        )
        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
        return x
|
|
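# SimpleCaptioner wires the three pieces together: SigLIP encodes the image,
# ImageAdapter maps the vision tokens into Llama's embedding space, and the
# (optionally fine-tuned) Llama model generates the caption.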
class SimpleCaptioner:
    def __init__(self):
        self.clip_model = None
        self.text_model = None
        self.image_adapter = None
        self.tokenizer = None
|
    def load_models(self):
        logging.info("Loading CLIP")
        self.clip_model = AutoModel.from_pretrained(CLIP_PATH)
        self.clip_model = self.clip_model.vision_model

        # If the checkpoint ships a fine-tuned vision tower, load it over the base weights.
        if (CHECKPOINT_PATH / "clip_model.pt").exists():
            checkpoint = torch.load(
                CHECKPOINT_PATH / "clip_model.pt", map_location="cpu"
            )
            checkpoint = {
                k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()
            }
            self.clip_model.load_state_dict(checkpoint)

        self.clip_model.eval()
        self.clip_model.requires_grad_(False)
        self.clip_model.to("cuda")

        logging.info("Loading tokenizer and LLM")
        # Load the tokenizer and LLM from the same source: the bundled fine-tuned
        # text model if present, otherwise the base model from the Hub.
        if (CHECKPOINT_PATH / "text_model").exists():
            self.tokenizer = AutoTokenizer.from_pretrained(
                CHECKPOINT_PATH / "text_model", use_fast=True
            )
            self.text_model = AutoModelForCausalLM.from_pretrained(
                CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16
            )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True)
            self.text_model = AutoModelForCausalLM.from_pretrained(
                MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16
            )
        self.text_model.eval()

        logging.info("Loading image adapter")
        self.image_adapter = ImageAdapter(
            self.clip_model.config.hidden_size,
            self.text_model.config.hidden_size,
            ln1=False,
            pos_emb=False,
            num_image_tokens=38,
            deep_extract=False,
        )
        self.image_adapter.load_state_dict(
            torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu")
        )
        self.image_adapter.eval()
        self.image_adapter.to("cuda")
|
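    # Caption generation: preprocess the image, encode it, splice the image
    # embeddings into the tokenized chat prompt, and decode the LLM's output.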
    @torch.no_grad()
    def generate_caption(self, image_path: str) -> str:
        # Preprocess: SigLIP so400m expects 384x384 inputs normalized to [-1, 1].
        input_image = Image.open(image_path).convert("RGB")
        image = input_image.resize((384, 384), Image.LANCZOS)
        pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
        pixel_values = TVF.normalize(pixel_values, [0.5], [0.5]).to("cuda")
|
        # Encode the image and project all hidden states through the adapter.
        vision_outputs = self.clip_model(pixel_values=pixel_values, output_hidden_states=True)
        embedded_images = self.image_adapter(vision_outputs.hidden_states)
|
        # Build the conversation and render it with the model's chat template.
        prompt = "Write a descriptive caption for this image in a formal tone."
        convo = [
            {"role": "system", "content": "You are a helpful image captioner."},
            {"role": "user", "content": prompt},
        ]
        convo_string = self.tokenizer.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=True
        )
|
        # Tokenize the full conversation and the bare prompt separately; the
        # prompt's token count is needed below to locate the splice point.
        convo_tokens = self.tokenizer.encode(
            convo_string, return_tensors="pt", add_special_tokens=False
        )
        prompt_tokens = self.tokenizer.encode(
            prompt, return_tensors="pt", add_special_tokens=False
        )
|
        # Find the <|eot_id|> positions. convo_tokens is 1 x N, so take the column
        # indices ([1]); the row indices ([0]) would all be zero. The user prompt
        # ends at the second <|eot_id|>, so the tokens before the prompt itself
        # (system message plus headers) form the preamble.
        eot_id_indices = (
            (convo_tokens == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
            .nonzero(as_tuple=True)[1]
            .tolist()
        )
        preamble_len = eot_id_indices[1] - prompt_tokens.shape[1]
|
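        # Splice layout (one batch row):
        #   [preamble tokens][image embeddings][prompt + generation header tokens]
        # i.e. the image embeddings stand in where an image placeholder would go.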
        # Embed the text tokens on the GPU (convo_tokens must be on the same
        # device as the zeros below, or torch.cat would fail), then splice the
        # image embeddings in after the preamble.
        convo_tokens = convo_tokens.to("cuda")
        convo_embeds = self.text_model.model.embed_tokens(convo_tokens)

        input_embeds = torch.cat(
            [
                convo_embeds[:, :preamble_len],
                embedded_images.to(dtype=convo_embeds.dtype),
                convo_embeds[:, preamble_len:],
            ],
            dim=1,
        )

        # Mirror the splice in input_ids, using zeros as dummy ids for the image slots.
        input_ids = torch.cat(
            [
                convo_tokens[:, :preamble_len],
                torch.zeros((1, embedded_images.shape[1]), dtype=torch.long, device="cuda"),
                convo_tokens[:, preamble_len:],
            ],
            dim=1,
        )
        attention_mask = torch.ones_like(input_ids)
|
        # Passing both input_ids and inputs_embeds makes generate() consume the
        # embeddings for the prompt while returning prompt ids plus new tokens,
        # which is why the prompt is stripped off below.
        generate_ids = self.text_model.generate(
            input_ids,
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            max_new_tokens=300,
            do_sample=True,
            repetition_penalty=1.2,
        )
|
        # Strip the prompt, then drop a trailing EOS / <|eot_id|> if present.
        generate_ids = generate_ids[:, input_ids.shape[1]:]
        if generate_ids[0][-1] in (
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ):
            generate_ids = generate_ids[:, :-1]

        caption = self.tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]
        return caption.strip()
|
|
def main():
    parser = argparse.ArgumentParser(description="Generate a caption for a single image")
    parser.add_argument("image_path", type=str, help="Path to the input image")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")

    # Reject files whose extension PIL (plus any optional plugins) cannot handle.
    if not any(args.image_path.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
        logging.error(f"Unsupported image extension. Supported extensions are: {IMAGE_EXTENSIONS}")
        sys.exit(1)

    captioner = SimpleCaptioner()
    captioner.load_models()

    caption = captioner.generate_caption(args.image_path)
    print(f"\nGenerated caption:\n{caption}")
|
|
if __name__ == "__main__":
    main()
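# Example usage (assuming this file is saved as simple_caption.py, a hypothetical
# name, and the "cgrkzexw-599808" checkpoint directory sits next to it):
#
#   python simple_caption.py path/to/image.jpg
#
# Note that a CUDA GPU is required; the models are moved to "cuda" unconditionally.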