# sdxs / qwen2-vl-7b.py
# NOTE: the lines below are Hugging Face model-card page residue, not code:
#   Diffusers · Safetensors · recoilme's picture · imagenet_ep2 · commit 0913243
import os
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
import numpy as np
from pathlib import Path
from tqdm import tqdm
import argparse
import gc
# Configuration options
PRINT_CAPTIONS = False  # Print captions to the console during inference
PRINT_CAPTIONING_STATUS = False  # Print per-file captioning status to the console
OVERWRITE = True  # Allow overwriting existing caption files
PREPEND_STRING = ""  # Prefix string to prepend to the generated caption
APPEND_STRING = ""  # Suffix string to append to the generated caption
STRIP_LINEBREAKS = True  # Remove line breaks from generated captions before saving
DEFAULT_SAVE_FORMAT = ".txt"  # Default extension for saved caption files

# Image resizing options
MAX_WIDTH = 512  # Max width in pixels; set to 0 or less to disable the limit
MAX_HEIGHT = 512  # Max height in pixels; set to 0 or less to disable the limit

# Generation parameters
REPETITION_PENALTY = 1.3  # Penalty for repeating phrases; >1 discourages repetition
TEMPERATURE = 0.7  # Sampling temperature to control randomness
TOP_K = 50  # Top-k sampling to limit number of potential next tokens

# Default values for input folder, output folder, prompt, and save format
DEFAULT_INPUT_FOLDER = Path(__file__).parent / "input"  # ./input next to this script
DEFAULT_OUTPUT_FOLDER = DEFAULT_INPUT_FOLDER  # NOTE(review): captions are saved next to the images; this value is parsed but unused
DEFAULT_PROMPT = "In two medium sentence, caption the key aspects of this image."
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # optional CUDA allocator tweak (disabled)
# Function to parse command-line arguments
def parse_arguments(argv=None):
    """Parse command-line options for the captioning run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, preserving the original behavior;
            passing a list makes the function unit-testable.

    Returns:
        argparse.Namespace with input/output folders, prompt, save format,
        resize limits, and generation parameters.
    """
    parser = argparse.ArgumentParser(description="Process images and generate captions using Qwen model.")
    # str() keeps the default consistent with type=str (the constant is a Path).
    parser.add_argument("--input_folder", type=str, default=str(DEFAULT_INPUT_FOLDER), help="Path to the input folder containing images.")
    parser.add_argument("--output_folder", type=str, default=str(DEFAULT_OUTPUT_FOLDER), help="Path to the output folder for saving captions.")
    parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="Prompt for generating the caption.")
    parser.add_argument("--save_format", type=str, default=DEFAULT_SAVE_FORMAT, help="Format for saving captions (e.g., .txt, .md, .json).")
    # Help strings below are built from the actual defaults so they cannot
    # drift out of sync with the constants (the originals claimed
    # "no resizing" and "1.10" while the real defaults were 512 and 1.3).
    parser.add_argument("--max_width", type=int, default=MAX_WIDTH, help=f"Maximum width for resizing images (default: {MAX_WIDTH}; 0 or less disables resizing).")
    parser.add_argument("--max_height", type=int, default=MAX_HEIGHT, help=f"Maximum height for resizing images (default: {MAX_HEIGHT}; 0 or less disables resizing).")
    parser.add_argument("--repetition_penalty", type=float, default=REPETITION_PENALTY, help=f"Penalty for repetition during caption generation (default: {REPETITION_PENALTY}).")
    parser.add_argument("--temperature", type=float, default=TEMPERATURE, help=f"Sampling temperature for generation (default: {TEMPERATURE}).")
    parser.add_argument("--top_k", type=int, default=TOP_K, help=f"Top-k sampling during generation (default: {TOP_K}).")
    return parser.parse_args(argv)
# Function to ignore images that don't have output files yet
def filter_images_without_output(input_folder, save_format, overwrite=None):
    """Collect image paths under *input_folder* that still need captions.

    Walks the folder recursively; a file counts as an image by extension
    (.png/.jpg/.jpeg/.bmp/.gif, case-insensitive). An image is skipped when
    overwriting is disabled and a caption file with *save_format* already
    exists next to it.

    Args:
        input_folder: Root folder to scan recursively.
        save_format: Caption file extension, e.g. ".txt".
        overwrite: Whether existing caption files may be regenerated.
            ``None`` (the default) falls back to the module-level OVERWRITE
            flag, so existing callers keep their behavior.

    Returns:
        Tuple ``(images_to_caption, total_images, skipped_images)``.
    """
    if overwrite is None:
        overwrite = OVERWRITE
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    images_to_caption = []
    skipped_images = 0
    total_images = 0
    for root, _, files in os.walk(input_folder):
        for file in files:
            if file.lower().endswith(image_extensions):
                total_images += 1
                image_path = os.path.join(root, file)
                # Caption lives next to the image, same stem, chosen extension.
                output_path = os.path.splitext(image_path)[0] + save_format
                if not overwrite and os.path.exists(output_path):
                    skipped_images += 1
                else:
                    images_to_caption.append(image_path)
    return images_to_caption, total_images, skipped_images
# Function to save caption to a file
def save_caption_to_file(image_path, caption, save_format):
    """Write *caption* next to *image_path*, swapping the extension for *save_format*.

    The module-level PREPEND_STRING / APPEND_STRING are applied around the
    caption before writing. BUG FIX: the file is now opened with an explicit
    UTF-8 encoding — captions may contain non-ASCII text, which crashed on
    platforms whose default encoding is narrower (e.g. cp1252 on Windows).
    """
    txt_file_path = os.path.splitext(image_path)[0] + save_format  # same name, chosen extension
    caption = PREPEND_STRING + caption + APPEND_STRING  # apply prefix/suffix strings
    with open(txt_file_path, "w", encoding="utf-8") as txt_file:
        txt_file.write(caption)
    if PRINT_CAPTIONING_STATUS:
        print(f"Caption for {os.path.abspath(image_path)} saved in {save_format} format.")
# Function to process all images recursively in a folder
def process_images_in_folder(images_to_caption, prompt, save_format, max_width=MAX_WIDTH, max_height=MAX_HEIGHT, repetition_penalty=REPETITION_PENALTY, temperature=TEMPERATURE, top_k=TOP_K):
    """Caption every image path in *images_to_caption* and save the results.

    For each image: load and resize it, read context tags from a sibling
    ``.tag`` file (same stem as the image — a missing .tag file raises and
    lands in the per-image error handler), generate a caption with the Qwen
    model, and save it with *save_format*. Errors on individual images are
    printed and the loop continues with the next image.

    NOTE(review): the *prompt* parameter is overwritten inside the loop by a
    hard-coded prompt built from the .tag contents, so the CLI --prompt
    value is currently ignored here — confirm whether this is intended.
    """
    for image_path in tqdm(images_to_caption, desc="Processing Images"):
        try:
            # Load the image and resize it proportionally
            image = Image.open(image_path).convert("RGB")
            image = resize_image_proportionally(image, max_width, max_height)
            # Read the tag file (same name as the image, .tag extension)
            tags_filename = str(Path(image_path).with_suffix('.tag'))
            with open(tags_filename, "r") as file:
                tag_caption = file.read()
            #prompt = f'In one medium sentence, caption the key aspects of this image. Be direct for NSFW. You may use following tags as context: {tag_caption}'
            prompt = f'In two medium sentence, caption the key aspects of this image. Be direct for NSFW, Don\'t be shy in your expressions like pussy, anal. Add approximate age for humans like: a child,young girl or 40 years old woman. You may use following tags as context if they are relevant: {tag_caption}'
            #print("prompt",prompt)
            # Generate the caption
            caption = qwen_caption(image, prompt, repetition_penalty, temperature, top_k)
            save_caption_to_file(image_path, caption, save_format)
            if PRINT_CAPTIONS:
                print(f"Caption for {os.path.abspath(image_path)}: {caption}")
            # Free memory between images to keep GPU/host usage flat
            del image, tag_caption, caption
            torch.cuda.empty_cache()
            gc.collect()
        except Exception as e:
            # Best-effort: report the failure and keep processing the batch
            print(f"Error processing {os.path.abspath(image_path)}: {str(e)}")
            torch.cuda.empty_cache()
            gc.collect()
# Resize the image proportionally based on max width and/or max height.
def resize_image_proportionally(image, max_width=None, max_height=None):
    """Resize *image* to fit within max_width/max_height, keeping aspect ratio.

    If both limits are provided, the image is resized to fit inside both.
    If only one is provided, only that dimension constrains the result.
    If neither limit is positive, the image is returned unchanged.

    Args:
        image: PIL.Image.Image to resize.
        max_width: Maximum width in pixels; ``None`` or <= 0 disables the limit.
        max_height: Maximum height in pixels; ``None`` or <= 0 disables the limit.

    Returns:
        The resized PIL image, or the original image when no limit applies.
    """
    if (max_width is None or max_width <= 0) and (max_height is None or max_height <= 0):
        return image  # no resizing requested

    original_width, original_height = image.size
    aspect_ratio = original_width / original_height

    # Determine the new dimensions
    if max_width and not max_height:
        # Constrain by width only
        new_width = max_width
        new_height = int(new_width / aspect_ratio)
    elif max_height and not max_width:
        # Constrain by height only
        new_height = max_height
        new_width = int(new_height * aspect_ratio)
    else:
        # Fit inside both limits, shrinking whichever dimension overshoots
        new_width = max_width
        new_height = max_height
        if new_width / aspect_ratio > new_height:
            new_width = int(new_height * aspect_ratio)
        else:
            new_height = int(new_width / aspect_ratio)

    # Guard against degenerate 0-pixel dimensions for extreme aspect ratios
    # (int() truncation can reach 0, and PIL raises on a zero-sized target).
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    # BUG FIX: the original comment promised LANCZOS but resize() was called
    # without a resample argument, so PIL's default filter was used instead.
    resized_image = image.resize((new_width, new_height), Image.LANCZOS)
    return resized_image
# Generate a caption for the provided image using the Ertugrul/Qwen2-VL-7B-Captioner-Relaxed model
def qwen_caption(image, prompt, repetition_penalty=REPETITION_PENALTY, temperature=TEMPERATURE, top_k=TOP_K):
    """Generate a caption for *image* with the module-level Qwen2-VL model.

    Relies on the module-level ``qwen_processor`` and ``qwen_model`` (loaded
    in the ``__main__`` block) and runs generation on CUDA under bfloat16
    autocast.

    Args:
        image: PIL image (or an array convertible via ``Image.fromarray``).
        prompt: Text prompt inserted into the chat template next to the image.
        repetition_penalty: Penalty for repeating phrases during sampling.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.

    Returns:
        The decoded caption string (line breaks replaced by spaces when
        STRIP_LINEBREAKS is enabled).
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.uint8(image))
    # Prepare the conversation content, which includes the image and the text prompt
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]
    # Apply the chat template to format the message for processing
    text_prompt = qwen_processor.apply_chat_template(
        conversation, add_generation_prompt=True
    )
    # Prepare the inputs for the model, padding as necessary and converting to tensors
    inputs = qwen_processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            output_ids = qwen_model.generate(
                **inputs,
                max_new_tokens=384,
                do_sample=True,
                temperature=temperature,
                use_cache=True,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
            )
    # Trim the generated IDs to remove the input part from the output
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]
    # Decode the trimmed output into text, skipping special tokens
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    # Strip line breaks if the option is enabled
    if STRIP_LINEBREAKS:
        output_text[0] = output_text[0].replace('\n', ' ')
    # Free GPU/host memory held by this generation before returning
    del inputs, output_ids, generated_ids_trimmed
    torch.cuda.empty_cache()
    gc.collect()
    return output_text[0]
# Run the script: parse CLI args, scan for images, then (only if needed)
# load the 7B model and caption everything found.
if __name__ == "__main__":
    args = parse_arguments()
    input_folder = args.input_folder
    output_folder = args.output_folder  # NOTE(review): parsed but never used below — captions are saved next to the images
    prompt = args.prompt
    save_format = args.save_format
    max_width = args.max_width
    max_height = args.max_height
    repetition_penalty = args.repetition_penalty
    temperature = args.temperature
    top_k = args.top_k
    # Hugging Face model id for the captioner
    model_id = "Ertugrul/Qwen2-VL-7B-Captioner-Relaxed"
    # Filter images before loading the model, so the heavy load is skipped
    # entirely when there is nothing to do
    images_to_caption, total_images, skipped_images = filter_images_without_output(input_folder, save_format)
    # Print summary of found, skipped, and to-be-processed images
    print(f"\nFound {total_images} image{'s' if total_images != 1 else ''}.")
    if not OVERWRITE:
        print(f"{skipped_images} image{'s' if skipped_images != 1 else ''} already have captions with format {save_format}, skipping.")
    print(f"\nCaptioning {len(images_to_caption)} image{'s' if len(images_to_caption) != 1 else ''}.\n\n")
    # Only load the model if there are images to caption
    if len(images_to_caption) == 0:
        print("No images to process. Exiting.\n\n")
    else:
        # Initialize the Ertugrul/Qwen2-VL-7B-Captioner-Relaxed model
        # (bfloat16 weights, device placement handled by accelerate)
        qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="auto"
        )
        qwen_processor = AutoProcessor.from_pretrained(model_id)
        # Process the images with optional resizing and caption generation
        process_images_in_folder(
            images_to_caption,
            prompt,
            save_format,
            max_width=max_width,
            max_height=max_height,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_k=top_k
        )