#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Simplified JoyCaption - Generates captions for a single image input
"""
import argparse
import logging
import sys
from pathlib import Path

import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from torch import nn
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)
CLIP_PATH = "google/siglip-so400m-patch14-384"
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
CHECKPOINT_PATH = Path(__file__).resolve().parent / "cgrkzexw-599808"
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
# Optional AVIF support
try:
    import pillow_avif  # noqa: F401  # registers the AVIF plugin with Pillow
    IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
except ImportError:
    pass
# JPEG-XL on Linux (jxlpy)
try:
    from jxlpy import JXLImagePlugin  # noqa: F401
    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
    pass
# JPEG-XL on Windows (pillow_jxl)
try:
    import pillow_jxl  # noqa: F401
    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
    pass
class ImageAdapter(nn.Module):
    """Projects the vision tower's features into the LLM's embedding space."""

    def __init__(
        self,
        input_features: int,
        output_features: int,
        ln1: bool,
        pos_emb: bool,
        num_image_tokens: int,
        deep_extract: bool,
    ):
        super().__init__()
        self.deep_extract = deep_extract
        if self.deep_extract:
            input_features = input_features * 5
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
        self.pos_emb = (
            None
            if not pos_emb
            else nn.Parameter(torch.zeros(num_image_tokens, input_features))
        )
        # Learned embeddings placed before and after the image tokens
        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs):
        if self.deep_extract:
            # Concatenate features from several hidden layers along the channel dim
            x = torch.concat(
                (
                    vision_outputs[-2],
                    vision_outputs[3],
                    vision_outputs[7],
                    vision_outputs[13],
                    vision_outputs[20],
                ),
                dim=-1,
            )
        else:
            # Use the second-to-last hidden state
            x = vision_outputs[-2]
        x = self.ln1(x)
        if self.pos_emb is not None:
            x = x + self.pos_emb
        # Two-layer MLP projection into the LLM embedding space
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        # Wrap the projected image tokens with the two learned boundary embeddings
        other_tokens = self.other_tokens(
            torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(
                x.shape[0], -1
            )
        )
        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
        return x
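# Rough shape sketch for ImageAdapter.forward (names are illustrative; the
# actual sizes depend on the SigLIP vision tower and the LLM hidden size):
#   vision_outputs[-2]              -> (B, num_patches, clip_hidden)
#   after linear1 / GELU / linear2  -> (B, num_patches, llm_hidden)
#   after the two boundary tokens   -> (B, num_patches + 2, llm_hidden)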
class SimpleCaptioner:
    def __init__(self):
        self.clip_model = None
        self.text_model = None
        self.image_adapter = None
        self.tokenizer = None

    def load_models(self):
        logging.info("Loading CLIP")
        self.clip_model = AutoModel.from_pretrained(CLIP_PATH)
        self.clip_model = self.clip_model.vision_model
        # Load the fine-tuned vision tower if it ships with the checkpoint
        if (CHECKPOINT_PATH / "clip_model.pt").exists():
            checkpoint = torch.load(
                CHECKPOINT_PATH / "clip_model.pt", map_location="cpu"
            )
            checkpoint = {
                k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()
            }
            self.clip_model.load_state_dict(checkpoint)
        self.clip_model.eval()
        self.clip_model.requires_grad_(False)
        self.clip_model.to("cuda")

        logging.info("Loading tokenizer and LLM")
        self.tokenizer = AutoTokenizer.from_pretrained(
            CHECKPOINT_PATH / "text_model", use_fast=True
        )
        # Prefer the fine-tuned text model from the checkpoint, fall back to the base model
        if (CHECKPOINT_PATH / "text_model").exists():
            self.text_model = AutoModelForCausalLM.from_pretrained(
                CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16
            )
        else:
            self.text_model = AutoModelForCausalLM.from_pretrained(
                MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16
            )
        self.text_model.eval()

        logging.info("Loading image adapter")
        self.image_adapter = ImageAdapter(
            self.clip_model.config.hidden_size,
            self.text_model.config.hidden_size,
            False,  # ln1
            False,  # pos_emb
            38,     # num_image_tokens (unused when pos_emb is False)
            False,  # deep_extract
        )
        self.image_adapter.load_state_dict(
            torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu")
        )
        self.image_adapter.eval()
        self.image_adapter.to("cuda")
    @torch.no_grad()
    def generate_caption(self, image_path: str) -> str:
        # Load and preprocess image: resize to the SigLIP input size and normalize to [-1, 1]
        input_image = Image.open(image_path).convert("RGB")
        image = input_image.resize((384, 384), Image.LANCZOS)
        pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
        pixel_values = TVF.normalize(pixel_values, [0.5], [0.5]).to("cuda")

        # Generate image embeddings
        vision_outputs = self.clip_model(
            pixel_values=pixel_values, output_hidden_states=True
        )
        embedded_images = self.image_adapter(vision_outputs.hidden_states)

        # Prepare prompt
        prompt = "Write a descriptive caption for this image in a formal tone."
        convo = [
            {"role": "system", "content": "You are a helpful image captioner."},
            {"role": "user", "content": prompt},
        ]
        convo_string = self.tokenizer.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=True
        )

        # Tokenize and prepare inputs
        convo_tokens = self.tokenizer.encode(
            convo_string, return_tensors="pt", add_special_tokens=False
        )
        prompt_tokens = self.tokenizer.encode(
            prompt, return_tensors="pt", add_special_tokens=False
        )
        # Locate the <|eot_id|> tokens to find where the user prompt starts; the
        # image embeddings are spliced in right before the prompt tokens.
        # convo_tokens has shape (1, seq_len), so take the column indices ([1]),
        # not the batch indices ([0]).
        eot_id_indices = (
            (convo_tokens == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
            .nonzero(as_tuple=True)[1]
            .tolist()
        )
        preamble_len = eot_id_indices[1] - prompt_tokens.shape[1]

        # Splice the image embeddings into the text embeddings
        convo_embeds = self.text_model.model.embed_tokens(convo_tokens.to("cuda"))
        input_embeds = torch.cat(
            [
                convo_embeds[:, :preamble_len],
                embedded_images.to(dtype=convo_embeds.dtype),
                convo_embeds[:, preamble_len:],
            ],
            dim=1,
        )
        # Zero placeholder ids stand in for the image tokens; concatenate on CPU
        # alongside convo_tokens, then move the result to the GPU.
        input_ids = torch.cat(
            [
                convo_tokens[:, :preamble_len],
                torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
                convo_tokens[:, preamble_len:],
            ],
            dim=1,
        ).to("cuda")
        attention_mask = torch.ones_like(input_ids)

        # Generate caption
        generate_ids = self.text_model.generate(
            input_ids,
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            max_new_tokens=300,
            do_sample=True,
            repetition_penalty=1.2,
        )

        # Decode caption: trim the prompt and any trailing end-of-turn token
        generate_ids = generate_ids[:, input_ids.shape[1]:]
        last_token = generate_ids[0][-1]
        if last_token == self.tokenizer.eos_token_id or last_token == self.tokenizer.convert_tokens_to_ids("<|eot_id|>"):
            generate_ids = generate_ids[:, :-1]
        caption = self.tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]
        return caption.strip()
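# Programmatic use (sketch; the image paths are hypothetical): load the models
# once, then reuse the captioner for several images.
#
#   captioner = SimpleCaptioner()
#   captioner.load_models()
#   for path in ("photo_01.png", "photo_02.jpg"):
#       print(captioner.generate_caption(path))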
def main():
    parser = argparse.ArgumentParser(description="Generate a caption for a single image")
    parser.add_argument("image_path", type=str, help="Path to the input image")
    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)s | %(message)s')

    # Validate image extension
    if not any(args.image_path.lower().endswith(ext.lower()) for ext in IMAGE_EXTENSIONS):
        logging.error(f"Unsupported image extension. Supported extensions are: {IMAGE_EXTENSIONS}")
        sys.exit(1)

    # Initialize and load the captioner
    captioner = SimpleCaptioner()
    captioner.load_models()

    # Generate and print caption
    caption = captioner.generate_caption(args.image_path)
    print(f"\nGenerated caption:\n{caption}")


if __name__ == "__main__":
    main()