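"""Two-stage, multi-GPU image-captioning pipeline for an anime dataset.

The INIT environment variable selects the stage: INIT=0 captions the first
half of the source dataset and pushes it to an intermediate hub repo;
INIT=1 captions the second half, merges both halves, and opens a pull
request against the source dataset. Each stage preprocesses its slice once,
then shards it across all visible GPUs with one worker process per GPU,
each running a 4-bit quantized datalab-to/chandra vision-language model.
"""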
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
import datasets
from datasets import Dataset
from typing import cast
import os
import shutil
import multiprocessing as mp
from PIL import Image


def load_model(model_name, device_id=0):
    """Load a 4-bit NF4-quantized model and its processor onto a single GPU."""
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    processor = AutoProcessor.from_pretrained(model_name)
    # Left-padding is required for correct batched generation with decoder-only models.
    processor.tokenizer.padding_side = "left"
    model = AutoModelForImageTextToText.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        dtype=torch.bfloat16,
        device_map={"": device_id},
        # flash_attention_2 needs the flash-attn package and an Ampere or newer GPU.
        attn_implementation="flash_attention_2",
    )
    return processor, model


def get_template(processor):
    """Build the captioning prompt via the processor's chat template."""
    msg = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {
                    "type": "text",
                    "text": "Describe the image concisely, and skip mentioning that it's illustrated or from anime.",
                },
            ],
        }
    ]
    return processor.apply_chat_template(
        msg, add_generation_prompt=True, tokenize=False
    )
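
# Note: apply_chat_template() returns the prompt rendered in the model's own
# chat markup. The caption cleanup in caption_batch() below assumes Qwen-style
# "<|im_start|>assistant" markers appear in the decoded output.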


def preprocess_example_batch(examples, text):
    """Convert a batch of images to RGB and attach the shared prompt text."""
    processed_images = []
    for image in examples["image"]:
        if isinstance(image, Image.Image):
            if image.mode != "RGB":
                image = image.convert("RGB")
            processed_images.append(image)
        else:
            raise ValueError("Image must be a PIL Image")
    return {
        "image": processed_images,
        "text": [text] * len(processed_images),
    }


def run_preprocessing(input_dataset, output_dir, num_proc=32, batch_size=100, start_idx=0, end_idx=None):
    """Select a slice of the source dataset, normalize images, and save to disk."""
    print("Loading dataset for preprocessing...")
    ds = datasets.load_dataset(input_dataset, split="train")
    if end_idx is None:
        end_idx = len(ds)
    print(f"Selecting range [{start_idx}:{end_idx}]...")
    ds = ds.select(range(start_idx, end_idx))
    print("Loading processor...")
    processor = AutoProcessor.from_pretrained("datalab-to/chandra")
    text = get_template(processor)
    print("Running preprocessing...")
    processed_ds = ds.map(
        lambda ex: preprocess_example_batch(ex, text),
        remove_columns=[col for col in ds.column_names if col not in ["image", "text"]],
        num_proc=num_proc,
        batched=True,
        batch_size=batch_size,
    )
    print(f"Saving preprocessed dataset to {output_dir}...")
    processed_ds.save_to_disk(output_dir)
    print("Preprocessing done.")


def caption_batch(batch, processor, model):
    """Generate captions for a batch; returns a replacement "text" column."""
    images = batch["image"]
    texts = batch["text"]
    inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
    # Pinned host memory allows the host-to-GPU copy to run asynchronously.
    inputs = {
        k: v.pin_memory().to(model.device, non_blocking=True) for k, v in inputs.items()
    }
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        generated = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False,
        )
    decoded = processor.batch_decode(generated, skip_special_tokens=False)
    captions = []
    special_tokens = set(processor.tokenizer.all_special_tokens)
    for d in decoded:
        # Keep only the assistant's reply, then strip remaining special tokens.
        if "<|im_start|>assistant" in d:
            d = d.split("<|im_start|>assistant")[-1]
        for token in special_tokens:
            d = d.replace(token, "")
        d = d.strip()
        captions.append(d)
    return {
        "text": captions,
    }


def process_shard(
    gpu_id, start, end, model_name, batch_size, input_dataset, output_file
):
    """Worker entry point: caption one contiguous shard on a single GPU."""
    try:
        torch.cuda.set_device(gpu_id)
        print(f"[GPU {gpu_id}] Loading model...", flush=True)
        processor, model = load_model(model_name, gpu_id)
        print(f"[GPU {gpu_id}] Loading data shard [{start}:{end}]...", flush=True)
        loaded = datasets.load_from_disk(input_dataset).select(range(start, end))
        shard = cast(Dataset, loaded)
        print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
        result = shard.map(
            lambda batch: caption_batch(batch, processor, model),
            batched=True,
            batch_size=batch_size,
            # The prompt column is replaced by the generated captions.
            remove_columns=["text"],
        )
        print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)
        result.save_to_disk(output_file)
        print(f"[GPU {gpu_id}] Done!", flush=True)
        return output_file
    except Exception as e:
        print(f"[GPU {gpu_id}] Error: {e}", flush=True)
        raise


def main():
    mp.set_start_method("spawn", force=True)
    init_stage = os.environ.get("INIT", "0")
    input_dataset = "none-yet/anime-captions"
    output_dataset = "nroggendorff/anime-captions"
    model_name = "datalab-to/chandra"
    batch_size = 20
    print(f"Running stage INIT={init_stage}")
    full_ds = datasets.load_dataset(input_dataset, split="train")
    total_dataset_size = len(full_ds)
    midpoint = total_dataset_size // 2
    if init_stage == "0":
        print(f"Stage 0: Processing first half [0:{midpoint}]")
        preprocessed_dataset = "temp_preprocessed_0"
        start_idx = 0
        end_idx = midpoint
        final_output = f"{output_dataset}_part0"
    else:
        print(f"Stage 1: Processing second half [{midpoint}:{total_dataset_size}]")
        preprocessed_dataset = "temp_preprocessed_1"
        start_idx = midpoint
        end_idx = total_dataset_size
        # Stage 1 pushes the merged result back to the source dataset as a PR.
        final_output = input_dataset
    if not os.path.exists(preprocessed_dataset):
        run_preprocessing(input_dataset, preprocessed_dataset, start_idx=start_idx, end_idx=end_idx)
    print("Loading preprocessed dataset...")
    ds = datasets.load_from_disk(preprocessed_dataset)
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError("No CUDA devices available")
    total_size = len(ds)
    shard_size = total_size // num_gpus
    print(f"Dataset size: {total_size}")
    print(f"Using {num_gpus} GPUs")
    print(f"Shard size: {shard_size}")
    processes = []
    temp_files = []
    for i in range(num_gpus):
        start = i * shard_size
        # The last shard absorbs the division remainder.
        end = start + shard_size if i < num_gpus - 1 else total_size
        output_file = f"temp_shard_{init_stage}_{i}"
        temp_files.append(output_file)
        p = mp.Process(
            target=process_shard,
            args=(
                i,
                start,
                end,
                model_name,
                batch_size,
                preprocessed_dataset,
                output_file,
            ),
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
        if p.exitcode != 0:
            print(f"\nProcess failed with exit code {p.exitcode}", flush=True)
            print("Terminating all processes...", flush=True)
            for proc in processes:
                if proc.is_alive():
                    proc.terminate()
            for proc in processes:
                proc.join()
            raise RuntimeError("At least one process failed")
    print("\nAll processes completed. Loading and concatenating results...")
    shards = [cast(Dataset, datasets.load_from_disk(f)) for f in temp_files]
    final_ds = datasets.concatenate_datasets(shards)
    print(f"Final dataset size: {len(final_ds)}")
    if init_stage == "0":
        print(f"Pushing first half to {final_output}...")
        final_ds.push_to_hub(final_output, create_pr=False)
    else:
        print("Loading first half from hub...")
        first_half = datasets.load_dataset(f"{output_dataset}_part0", split="train")
        print("Concatenating both halves...")
        complete_ds = datasets.concatenate_datasets([first_half, final_ds])
        print(f"Complete dataset size: {len(complete_ds)}")
        print(f"Pushing complete dataset to {final_output} with PR...")
        complete_ds.push_to_hub(final_output, create_pr=True)
    print("Cleaning up temporary files...")
    for f in temp_files:
        if os.path.exists(f):
            shutil.rmtree(f)
    if os.path.exists(preprocessed_dataset):
        shutil.rmtree(preprocessed_dataset)
    print("Done!")


if __name__ == "__main__":
    main()
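
# Typical two-stage run (the filename "caption.py" is illustrative):
#   INIT=0 python caption.py   # first half -> nroggendorff/anime-captions_part0
#   INIT=1 python caption.py   # second half, merge, PR to none-yet/anime-captions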