# create-caption / train.py
# nroggendorff's picture
# Update train.py
# 73418cf verified
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
import datasets
from datasets import Dataset
from typing import cast
import os
import shutil
import multiprocessing as mp
from PIL import Image
def load_model(model_name, device_id=0):
    """Load the processor and a 4-bit NF4-quantized image-text-to-text model.

    Args:
        model_name: Checkpoint identifier passed to both ``from_pretrained``
            calls (processor and model).
        device_id: Device index used as the single ``device_map`` target, so
            the whole model is placed on that one device. Defaults to 0.

    Returns:
        A ``(processor, model)`` tuple.
    """
    quant_settings = {
        "load_in_4bit": True,
        "bnb_4bit_compute_dtype": torch.bfloat16,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_use_double_quant": True,
    }
    quantization = BitsAndBytesConfig(**quant_settings)

    processor = AutoProcessor.from_pretrained(model_name)
    # Pad on the left -- presumably so batched generation continues directly
    # after each prompt; confirm against the model's requirements.
    processor.tokenizer.padding_side = "left"

    # The empty-string key maps every module of the model to ``device_id``.
    placement = {"": device_id}
    model = AutoModelForImageTextToText.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",
        device_map=placement,
        dtype=torch.bfloat16,
        quantization_config=quantization,
    )
    return processor, model
def getTemplate(processor):
    """Render the fixed single-turn captioning prompt through the chat template.

    The conversation is one user turn holding an image placeholder followed by
    the captioning instruction.

    Args:
        processor: Object exposing ``apply_chat_template``.

    Returns:
        Whatever ``processor.apply_chat_template`` returns when called with
        ``add_generation_prompt=True`` and ``tokenize=False`` (presumably the
        untokenized prompt string).
    """
    instruction = "Describe the image concisely, and skip mentioning that it's illustrated or from anime."
    image_part = {"type": "image"}
    text_part = {"type": "text", "text": instruction}
    conversation = [{"role": "user", "content": [image_part, text_part]}]
    return processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=False,
    )
def preprocess_example_batch(examples, text):
    """Convert a batch of images to RGB and pair each with the same prompt.

    Args:
        examples: Batch mapping whose ``"image"`` entry is an iterable of
            PIL images.
        text: Prompt string repeated once per image.

    Returns:
        A dict with ``"image"`` (the RGB images, in input order) and
        ``"text"`` (``text`` repeated to the same length).

    Raises:
        ValueError: If any entry of ``examples["image"]`` is not a PIL image.
    """

    def _as_rgb(candidate):
        # Reject anything that is not a PIL image before touching ``.mode``.
        if not isinstance(candidate, Image.Image):
            raise ValueError("Image must be a PIL Image")
        if candidate.mode == "RGB":
            return candidate
        return candidate.convert("RGB")

    rgb_images = [_as_rgb(item) for item in examples["image"]]
    prompts = [text for _ in rgb_images]
    return {"image": rgb_images, "text": prompts}
def run_preprocessing(
    input_dataset,
    output_dir,
    num_proc=32,
    batch_size=100,
    start_idx=0,
    end_idx=None,
    model_name="datalab-to/chandra",
):
    """Build and save the RGB-image + prompt dataset for a slice of the input.

    Loads the ``train`` split of ``input_dataset``, keeps rows
    ``[start_idx:end_idx]``, converts every image to RGB, attaches the chat
    prompt rendered by ``getTemplate`` and writes the result to ``output_dir``.

    Args:
        input_dataset: Dataset identifier passed to ``datasets.load_dataset``.
        output_dir: Directory the processed dataset is saved to.
        num_proc: Number of worker processes used by ``Dataset.map``.
        batch_size: Rows per batch handed to ``preprocess_example_batch``.
        start_idx: First row (inclusive) of the slice to process.
        end_idx: End row (exclusive); ``None`` means the end of the split.
        model_name: Checkpoint whose processor renders the prompt template.
            Defaults to the value that used to be hard-coded here, so existing
            callers are unaffected.
    """
    print("Loading dataset for preprocessing...")
    ds = datasets.load_dataset(input_dataset, split="train")
    if end_idx is None:
        end_idx = len(ds)
    print(f"Selecting range [{start_idx}:{end_idx}]...")
    ds = ds.select(range(start_idx, end_idx))

    print("Loading processor...")
    # The prompt must come from the same checkpoint that later generates the
    # captions, hence the parameter instead of a second hard-coded name.
    processor = AutoProcessor.from_pretrained(model_name)
    text = getTemplate(processor)

    print("Running preprocessing...")
    kept_columns = ("image", "text")
    processed_ds = ds.map(
        lambda ex: preprocess_example_batch(ex, text),
        # Drop every source column except the two the captioning stage reads.
        remove_columns=[col for col in ds.column_names if col not in kept_columns],
        num_proc=num_proc,
        batched=True,
        batch_size=batch_size,
    )

    print(f"Saving preprocessed dataset to {output_dir}...")
    processed_ds.save_to_disk(output_dir)
    print("Preprocessing done.")
def caption_batch(batch, processor, model):
    """Generate one caption per image in ``batch``.

    Args:
        batch: Mapping with ``"image"`` (images) and ``"text"`` (prompts).
        processor: Processor used to encode the inputs and decode the output;
            its tokenizer supplies the special-token list.
        model: Model whose ``generate`` is called; inputs are moved to
            ``model.device``.

    Returns:
        ``{"text": captions}`` with one cleaned caption per decoded sequence.
    """
    pictures = batch["image"]
    prompts = batch["text"]
    encoded = processor(
        text=prompts, images=pictures, return_tensors="pt", padding=True
    )

    # Pin each tensor, then issue a non-blocking copy to the model's device.
    on_device = {}
    for name, tensor in encoded.items():
        on_device[name] = tensor.pin_memory().to(model.device, non_blocking=True)

    with torch.no_grad():
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            # Greedy decoding, capped at 128 newly generated tokens.
            output_ids = model.generate(
                **on_device,
                max_new_tokens=128,
                do_sample=False,
            )

    # Special tokens are kept while decoding so the assistant marker survives
    # and can be used to cut off the prompt part.
    raw_texts = processor.batch_decode(output_ids, skip_special_tokens=False)
    specials = set(processor.tokenizer.all_special_tokens)
    assistant_marker = "<|im_start|>assistant"

    def _clean(raw):
        # Keep only what follows the last marker; rpartition leaves the text
        # untouched when the marker is absent.
        answer = raw.rpartition(assistant_marker)[2]
        for special in specials:
            answer = answer.replace(special, "")
        return answer.strip()

    return {"text": [_clean(raw) for raw in raw_texts]}
def process_shard(
    gpu_id, start, end, model_name, batch_size, input_dataset, output_file
):
    """Caption rows ``[start:end]`` of a saved dataset on one GPU.

    Args:
        gpu_id: CUDA device index used for this worker.
        start: First row (inclusive) of the shard.
        end: End row (exclusive) of the shard.
        model_name: Checkpoint passed to ``load_model``.
        batch_size: Rows per ``caption_batch`` call.
        input_dataset: Path given to ``datasets.load_from_disk``.
        output_file: Directory the captioned shard is saved to.

    Returns:
        ``output_file`` on success.

    Raises:
        Exception: Any error is printed with the GPU prefix and re-raised.
    """

    def _log(message):
        # Flush immediately so output from concurrent workers shows up promptly.
        print(message, flush=True)

    try:
        torch.cuda.set_device(gpu_id)
        _log(f"[GPU {gpu_id}] Loading model...")
        processor, model = load_model(model_name, gpu_id)

        _log(f"[GPU {gpu_id}] Loading data shard [{start}:{end}]...")
        on_disk = datasets.load_from_disk(input_dataset)
        shard = cast(Dataset, on_disk.select(range(start, end)))

        _log(f"[GPU {gpu_id}] Processing {len(shard)} examples...")

        def _caption(rows):
            return caption_batch(rows, processor, model)

        # The prompt column is removed; the mapper returns the captions under
        # the same "text" name.
        captioned = shard.map(
            _caption,
            batched=True,
            batch_size=batch_size,
            remove_columns=["text"],
        )

        _log(f"[GPU {gpu_id}] Saving results to {output_file}...")
        captioned.save_to_disk(output_file)
        _log(f"[GPU {gpu_id}] Done!")
        return output_file
    except Exception as err:
        _log(f"[GPU {gpu_id}] Error: {err}")
        raise
def _wait_for_workers(processes):
    """Wait for all workers; stop the rest as soon as any one of them fails.

    Waiting on the process sentinels (instead of joining in list order) means
    a worker that dies early is noticed immediately, even while workers
    earlier in the list are still running.

    Raises:
        RuntimeError: If any worker exits with a non-zero exit code.
    """
    import multiprocessing.connection as mp_connection

    pending = {proc.sentinel: proc for proc in processes}
    while pending:
        for sentinel in mp_connection.wait(list(pending)):
            finished = pending.pop(sentinel)
            finished.join()
            if finished.exitcode == 0:
                continue
            print(f"\nProcess failed with exit code {finished.exitcode}", flush=True)
            print("Terminating all processes...", flush=True)
            for proc in processes:
                if proc.is_alive():
                    proc.terminate()
            for proc in processes:
                proc.join()
            raise RuntimeError("At least one process failed")


def main():
    """Run one stage of the two-stage captioning pipeline selected by ``INIT``.

    The ``INIT`` environment variable (default ``"0"``) picks the stage:

    * ``"0"``: caption the first half of the input dataset and push it to
      ``<output_dataset>_part0``.
    * anything else: caption the second half, load the first half back from
      the hub, concatenate both and push the result as a pull request.

    Each stage preprocesses its half (unless the preprocessed directory
    already exists), starts one ``process_shard`` worker per visible GPU,
    concatenates the shard outputs and finally removes the temporary
    directories.

    Raises:
        RuntimeError: If no CUDA device is available or a worker fails.
    """
    mp.set_start_method("spawn", force=True)
    init_stage = os.environ.get("INIT", "0")
    input_dataset = "none-yet/anime-captions"
    output_dataset = "nroggendorff/anime-captions"
    model_name = "datalab-to/chandra"
    batch_size = 20
    print(f"Running stage INIT={init_stage}")

    # Check for GPUs before any download or preprocessing: with zero devices
    # the shard-size division below would otherwise fail with an unhelpful
    # ZeroDivisionError after all the expensive work is done.
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError("No CUDA devices found; at least one GPU is required")

    full_ds = datasets.load_dataset(input_dataset, split="train")
    total_dataset_size = len(full_ds)
    midpoint = total_dataset_size // 2

    if init_stage == "0":
        print(f"Stage 0: Processing first half [0:{midpoint}]")
        preprocessed_dataset = "temp_preprocessed_0"
        start_idx = 0
        end_idx = midpoint
        final_output = f"{output_dataset}_part0"
    else:
        print(f"Stage 1: Processing second half [{midpoint}:{total_dataset_size}]")
        preprocessed_dataset = "temp_preprocessed_1"
        start_idx = midpoint
        end_idx = total_dataset_size
        # NOTE(review): the complete dataset targets the *input* repo rather
        # than output_dataset -- presumably intentional because it is pushed
        # as a pull request below; confirm.
        final_output = input_dataset

    # An existing directory is reused, so a rerun after a failure skips
    # preprocessing.
    if not os.path.exists(preprocessed_dataset):
        run_preprocessing(input_dataset, preprocessed_dataset, start_idx=start_idx, end_idx=end_idx)

    print("Loading preprocessed dataset...")
    ds = datasets.load_from_disk(preprocessed_dataset)
    total_size = len(ds)
    shard_size = total_size // num_gpus
    print(f"Dataset size: {total_size}")
    print(f"Using {num_gpus} GPUs")
    print(f"Shard size: {shard_size}")

    processes = []
    temp_files = []
    for i in range(num_gpus):
        start = i * shard_size
        # The last worker also takes the remainder of the integer division.
        end = start + shard_size if i < num_gpus - 1 else total_size
        output_file = f"temp_shard_{init_stage}_{i}"
        temp_files.append(output_file)
        p = mp.Process(
            target=process_shard,
            args=(
                i,
                start,
                end,
                model_name,
                batch_size,
                preprocessed_dataset,
                output_file,
            ),
        )
        p.start()
        processes.append(p)

    _wait_for_workers(processes)

    print("\nAll processes completed. Loading and concatenating results...")
    shards = [cast(Dataset, datasets.load_from_disk(f)) for f in temp_files]
    final_ds = datasets.concatenate_datasets(shards)
    print(f"Final dataset size: {len(final_ds)}")

    if init_stage == "0":
        print(f"Pushing first half to {final_output}...")
        final_ds.push_to_hub(final_output, create_pr=False)
    else:
        print("Loading first half from hub...")
        first_half = datasets.load_dataset(f"{output_dataset}_part0", split="train")
        print("Concatenating both halves...")
        complete_ds = datasets.concatenate_datasets([first_half, final_ds])
        print(f"Complete dataset size: {len(complete_ds)}")
        print(f"Pushing complete dataset to {final_output} with PR...")
        complete_ds.push_to_hub(final_output, create_pr=True)

    print("Cleaning up temporary files...")
    for f in temp_files:
        if os.path.exists(f):
            shutil.rmtree(f)
    if os.path.exists(preprocessed_dataset):
        shutil.rmtree(preprocessed_dataset)
    print("Done!")
# Script entry point: run the pipeline only when this file is executed
# directly. main() selects the "spawn" start method, under which worker
# processes import this module again, so the guard keeps them from
# re-running main().
if __name__ == "__main__":
    main()