|
import os |
|
import torch |
|
from PIL import Image |
|
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration |
|
import numpy as np |
|
from pathlib import Path |
|
from tqdm import tqdm |
|
import argparse |
|
import gc |
|
|
|
|
|
# --- Output / logging toggles ---------------------------------------------

# Echo each generated caption to stdout as it is produced.
PRINT_CAPTIONS = False

# Print a status line every time a caption file is written.
PRINT_CAPTIONING_STATUS = False

# When False, images that already have a caption file are skipped.
OVERWRITE = True

# Text added verbatim before/after every caption when saving.
PREPEND_STRING = ""

APPEND_STRING = ""

# Replace newlines in the model output with spaces before saving.
STRIP_LINEBREAKS = True

# File extension used for saved caption files.
DEFAULT_SAVE_FORMAT = ".txt"


# --- Image preprocessing ---------------------------------------------------

# Bounding box the image is proportionally resized to fit before being fed
# to the model (see resize_image_proportionally; smaller images may be
# upscaled when both limits are set).
MAX_WIDTH = 512

MAX_HEIGHT = 512


# --- Generation / sampling parameters --------------------------------------

REPETITION_PENALTY = 1.3

TEMPERATURE = 0.7

TOP_K = 50


# --- CLI defaults -----------------------------------------------------------

# By default both input and output live in ./input next to this script;
# captions are written beside their images.
DEFAULT_INPUT_FOLDER = Path(__file__).parent / "input"

DEFAULT_OUTPUT_FOLDER = DEFAULT_INPUT_FOLDER

# NOTE(review): "two medium sentence" is likely meant to be "sentences";
# left as-is because the prompt text is runtime behavior, not a comment.
DEFAULT_PROMPT = "In two medium sentence, caption the key aspects of this image."
|
|
|
|
|
|
|
|
|
def parse_arguments():
    """Parse command-line arguments for the captioning script.

    All options default to the module-level constants, so the script can be
    run with no arguments at all.

    Returns:
        argparse.Namespace with input_folder, output_folder, prompt,
        save_format, max_width, max_height, repetition_penalty, temperature
        and top_k.
    """
    parser = argparse.ArgumentParser(description="Process images and generate captions using Qwen model.")
    parser.add_argument("--input_folder", type=str, default=DEFAULT_INPUT_FOLDER, help="Path to the input folder containing images.")
    parser.add_argument("--output_folder", type=str, default=DEFAULT_OUTPUT_FOLDER, help="Path to the output folder for saving captions.")
    parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="Prompt for generating the caption.")
    parser.add_argument("--save_format", type=str, default=DEFAULT_SAVE_FORMAT, help="Format for saving captions (e.g., .txt, .md, .json).")
    # BUGFIX: the help strings previously advertised stale defaults
    # ("no resizing" while MAX_WIDTH/MAX_HEIGHT are 512, and "1.10" while
    # REPETITION_PENALTY is 1.3). Use argparse's %(default)s interpolation
    # so the help text can never drift from the actual defaults again.
    parser.add_argument("--max_width", type=int, default=MAX_WIDTH, help="Maximum width for resizing images (default: %(default)s).")
    parser.add_argument("--max_height", type=int, default=MAX_HEIGHT, help="Maximum height for resizing images (default: %(default)s).")
    parser.add_argument("--repetition_penalty", type=float, default=REPETITION_PENALTY, help="Penalty for repetition during caption generation (default: %(default)s).")
    parser.add_argument("--temperature", type=float, default=TEMPERATURE, help="Sampling temperature for generation (default: %(default)s).")
    parser.add_argument("--top_k", type=int, default=TOP_K, help="Top-k sampling during generation (default: %(default)s).")
    return parser.parse_args()
|
|
|
|
|
def filter_images_without_output(input_folder, save_format):
    """Recursively scan *input_folder* and decide which images need captions.

    An image is skipped when a caption file with extension *save_format*
    already exists next to it and the module-level OVERWRITE flag is False.

    Returns:
        (images_to_caption, total_images, skipped_images) where the first
        element is a list of image file paths.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    pending = []
    skipped = 0
    total = 0

    for dirpath, _, filenames in os.walk(input_folder):
        for filename in filenames:
            # Extension check is case-insensitive.
            if not filename.lower().endswith(image_extensions):
                continue
            total += 1
            image_path = os.path.join(dirpath, filename)
            caption_path = os.path.splitext(image_path)[0] + save_format
            if not OVERWRITE and os.path.exists(caption_path):
                skipped += 1
            else:
                pending.append(image_path)

    return pending, total, skipped
|
|
|
|
|
def save_caption_to_file(image_path, caption, save_format):
    """Write *caption* next to *image_path*, swapping the image extension
    for *save_format* (e.g. foo.jpg -> foo.txt).

    The module-level PREPEND_STRING / APPEND_STRING are added verbatim
    around the caption before writing.
    """
    txt_file_path = os.path.splitext(image_path)[0] + save_format
    caption = PREPEND_STRING + caption + APPEND_STRING

    # BUGFIX: pin the encoding — model output may contain non-ASCII
    # characters, which crashes open()'s write on platforms whose default
    # locale encoding is not UTF-8 (e.g. cp1252 on Windows).
    with open(txt_file_path, "w", encoding="utf-8") as txt_file:
        txt_file.write(caption)

    if PRINT_CAPTIONING_STATUS:
        print(f"Caption for {os.path.abspath(image_path)} saved in {save_format} format.")
|
|
|
|
|
|
|
def process_images_in_folder(images_to_caption, prompt, save_format, max_width=MAX_WIDTH, max_height=MAX_HEIGHT, repetition_penalty=REPETITION_PENALTY, temperature=TEMPERATURE, top_k=TOP_K):
    """Caption every image in *images_to_caption* and save the results.

    For each image: load + resize, optionally read a sidecar ``.tag`` file
    (same basename) for extra context, generate a caption with qwen_caption,
    and write it via save_caption_to_file. Failures on one image are logged
    and do not stop the batch.
    """
    for image_path in tqdm(images_to_caption, desc="Processing Images"):
        try:
            image = Image.open(image_path).convert("RGB")
            image = resize_image_proportionally(image, max_width, max_height)

            # Optional sidecar tag file; a missing file is tolerated and
            # simply means "no extra context" instead of failing the image.
            tag_path = Path(image_path).with_suffix('.tag')
            try:
                tag_caption = tag_path.read_text(encoding="utf-8").strip()
            except FileNotFoundError:
                tag_caption = ""

            # BUGFIX: the previous version overwrote the `prompt` parameter
            # with a hard-coded string each iteration, silently ignoring the
            # --prompt CLI argument. Build a per-image prompt locally from
            # the caller-supplied prompt instead.
            if tag_caption:
                image_prompt = f"{prompt} You may use following tags as context if they are relevant: {tag_caption}"
            else:
                image_prompt = prompt

            caption = qwen_caption(image, image_prompt, repetition_penalty, temperature, top_k)
            save_caption_to_file(image_path, caption, save_format)

            if PRINT_CAPTIONS:
                print(f"Caption for {os.path.abspath(image_path)}: {caption}")

            # Release per-image tensors/objects eagerly to keep VRAM/RAM flat
            # over long batches.
            del image, tag_caption, caption
            torch.cuda.empty_cache()
            gc.collect()

        except Exception as e:
            # Best-effort batch processing: report and continue.
            print(f"Error processing {os.path.abspath(image_path)}: {str(e)}")
            torch.cuda.empty_cache()
            gc.collect()
|
|
|
|
|
def resize_image_proportionally(image, max_width=None, max_height=None):
    """Resize *image* to fit within max_width x max_height, keeping aspect ratio.

    If both limits are provided the image is scaled (up or down) to the
    largest size that fits the box; if only one limit is provided the other
    dimension follows from the aspect ratio. A limit that is None or <= 0 is
    treated as unconstrained; with no usable limits the image is returned
    unchanged.

    Args:
        image: any object exposing PIL's ``size`` attribute and ``resize()``.
        max_width: maximum output width in pixels, or None.
        max_height: maximum output height in pixels, or None.

    Returns:
        The resized image (or the original object when no limits apply).
    """
    if (max_width is None or max_width <= 0) and (max_height is None or max_height <= 0):
        return image

    original_width, original_height = image.size
    aspect_ratio = original_width / original_height

    if max_width and not max_height:
        # Width-constrained: derive the height.
        new_width = max_width
        new_height = int(new_width / aspect_ratio)
    elif max_height and not max_width:
        # Height-constrained: derive the width.
        new_height = max_height
        new_width = int(new_height * aspect_ratio)
    else:
        # Both limits set: start from the full box, then fix the aspect
        # ratio below.
        new_width = max_width
        new_height = max_height

    # Shrink whichever dimension would break the aspect ratio so the result
    # fits inside (new_width, new_height).
    if new_width / aspect_ratio > new_height:
        new_width = int(new_height * aspect_ratio)
    else:
        new_height = int(new_width / aspect_ratio)

    # BUGFIX: extreme aspect ratios could truncate a dimension to 0 via
    # int(), and Image.resize((w, 0)) raises. Clamp to at least 1 pixel.
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    resized_image = image.resize((new_width, new_height))
    return resized_image
|
|
|
|
|
def qwen_caption(image, prompt, repetition_penalty=REPETITION_PENALTY, temperature=TEMPERATURE, top_k=TOP_K):
    """Generate a caption for *image* with the module-global Qwen2-VL model.

    Uses the globals ``qwen_processor`` and ``qwen_model`` (loaded in
    __main__) and samples with the given temperature / top-k / repetition
    penalty. Newlines are flattened to spaces when STRIP_LINEBREAKS is set.

    Args:
        image: a PIL image, or an array convertible via np.uint8.
        prompt: the text prompt accompanying the image.

    Returns:
        The decoded caption string.
    """
    # Accept raw arrays as well as PIL images.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.uint8(image))

    # Single-turn chat: one image placeholder followed by the text prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    chat_text = qwen_processor.apply_chat_template(messages, add_generation_prompt=True)

    model_inputs = qwen_processor(
        text=[chat_text],
        images=[image],
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        generated = qwen_model.generate(
            **model_inputs,
            max_new_tokens=384,
            do_sample=True,
            temperature=temperature,
            use_cache=True,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )

    # Drop the echoed prompt tokens; keep only the newly generated ones.
    trimmed = [full[len(inp):] for inp, full in zip(model_inputs.input_ids, generated)]

    decoded = qwen_processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    caption = decoded[0]
    if STRIP_LINEBREAKS:
        caption = caption.replace('\n', ' ')

    # Free GPU memory eagerly between images.
    del model_inputs, generated, trimmed
    torch.cuda.empty_cache()
    gc.collect()

    return caption
|
|
|
|
|
if __name__ == "__main__":
    args = parse_arguments()

    # Unpack CLI arguments (each falls back to a module-level default).
    input_folder = args.input_folder
    # NOTE(review): output_folder is parsed but never used below — captions
    # are always written next to their images (see save_caption_to_file).
    # Confirm whether --output_folder should be honored or removed.
    output_folder = args.output_folder
    prompt = args.prompt
    save_format = args.save_format
    max_width = args.max_width
    max_height = args.max_height
    repetition_penalty = args.repetition_penalty
    temperature = args.temperature
    top_k = args.top_k

    # Hugging Face hub id of the captioning checkpoint.
    model_id = "Ertugrul/Qwen2-VL-7B-Captioner-Relaxed"

    # Determine which images still need captions (respects OVERWRITE).
    images_to_caption, total_images, skipped_images = filter_images_without_output(input_folder, save_format)

    print(f"\nFound {total_images} image{'s' if total_images != 1 else ''}.")
    if not OVERWRITE:
        print(f"{skipped_images} image{'s' if skipped_images != 1 else ''} already have captions with format {save_format}, skipping.")
    print(f"\nCaptioning {len(images_to_caption)} image{'s' if len(images_to_caption) != 1 else ''}.\n\n")

    if len(images_to_caption) == 0:
        print("No images to process. Exiting.\n\n")
    else:
        # Load model + processor only when there is actual work to do.
        # bfloat16 weights; device placement is delegated via device_map="auto".
        qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="auto"
        )
        qwen_processor = AutoProcessor.from_pretrained(model_id)

        # NOTE: qwen_caption reads qwen_model / qwen_processor as globals.
        process_images_in_folder(
            images_to_caption,
            prompt,
            save_format,
            max_width=max_width,
            max_height=max_height,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_k=top_k
        )
|