#!/usr/bin/env python3
"""
DeepSeek-OCR Dataset Processing

Minimal adaptation of official run_dpsk_ocr_eval_batch.py for dataset processing
"""

import argparse
import json
import os
import sys
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

import torch

# These environment variables are set before vLLM is imported below; keep
# this ordering when touching the import section.
if torch.version.cuda == '11.8':
    os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0'

from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
from PIL import Image, ImageOps
from tqdm.auto import tqdm
from datasets import load_dataset
from huggingface_hub import login

# Import DeepSeek-OCR modules (unchanged from original)
from deepseek_ocr import DeepseekOCRForCausalLM
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor
from config import MODEL_PATH, PROMPT, CROP_MODE

# Register custom model (unchanged from original)
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)

# Placeholder written to the output column for rows that could not be processed.
_FAILED_MARKER = "[OCR FAILED]"


def check_cuda():
    """Exit with status 1 if CUDA is unavailable; otherwise print the GPU name."""
    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. This script requires a GPU.")
        sys.exit(1)
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")


def process_single_image(image):
    """Preprocess single image (unchanged from official batch script).

    Args:
        image: The image handed to ``DeepseekOCRProcessor.tokenize_with_images``.

    Returns:
        A dict with the keys ``"prompt"`` and ``"multi_modal_data"``, which
        ``main`` passes to ``llm.generate``.
    """
    return {
        "prompt": PROMPT,
        "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(
            images=[image], bos=True, eos=True, cropping=CROP_MODE
        )},
    }


def _to_rgb_image(raw):
    """Turn one dataset cell into an RGB ``PIL.Image`` with EXIF transpose applied.

    Accepts an already-decoded ``PIL.Image`` or a ``str`` file path. Any other
    type raises ``TypeError`` so the caller can record the row as failed.
    """
    if isinstance(raw, Image.Image):
        rgb = raw.convert('RGB')
    elif isinstance(raw, str):
        # Convert inside the context manager so the file handle is released
        # as soon as the pixel data has been read.
        with Image.open(raw) as opened:
            rgb = opened.convert('RGB')
    else:
        raise TypeError(f"unsupported image value of type {type(raw).__name__}")
    return ImageOps.exif_transpose(rgb)


def _load_images(dataset, image_column):
    """Load every row's image; rows that fail to load yield ``None``."""
    images = []
    # Index access keeps the row fetch inside the try block, so an error
    # raised while reading a single row is isolated to that row.
    for idx in range(len(dataset)):
        try:
            images.append(_to_rgb_image(dataset[idx][image_column]))
        except Exception as e:
            print(f"Error loading image {idx}: {e}")
            images.append(None)
    return images


def _preprocess_or_none(idx, image):
    """Preprocess one image, returning ``None`` if it is missing or fails.

    Best-effort on purpose, mirroring the loading stage: one bad image must
    not abort the whole batch.
    """
    if image is None:
        return None
    try:
        return process_single_image(image)
    except Exception as e:
        print(f"Error preprocessing image {idx}: {e}")
        return None


def _read_existing_inference_info(dataset):
    """Return the list of prior ``inference_info`` entries stored in *dataset*.

    The value is read from the first row only. Returns ``[]`` when the column
    is absent or its content cannot be parsed as JSON.
    """
    if "inference_info" not in dataset.column_names:
        return []
    try:
        existing_info = json.loads(dataset[0]["inference_info"])
    except (IndexError, KeyError, TypeError, ValueError) as e:
        # IndexError: empty dataset; TypeError: non-string cell (e.g. None);
        # ValueError: malformed JSON (json.JSONDecodeError subclasses it).
        print(f"Warning: ignoring unreadable existing inference_info: {e}")
        return []
    if not isinstance(existing_info, list):
        existing_info = [existing_info]
    return existing_info


def main(args):
    """Run OCR over a dataset's image column and push the result to the Hub.

    Args:
        args: Parsed command-line namespace as built in the ``__main__`` block
            (``input_dataset``, ``output_dataset``, ``image_column``, ``split``,
            ``max_samples``, ``shuffle``, ``seed``, ``max_model_len``,
            ``max_tokens``, ``gpu_memory_utilization``, ``max_num_seqs``,
            ``num_workers``, ``hf_token``, ``private``).

    Exits with status 1 when CUDA is unavailable or the image column is
    missing from the dataset.
    """
    check_cuda()

    # Enable HF_TRANSFER
    # NOTE(review): huggingface_hub is already imported at this point and
    # appears to read this flag at import time, so this assignment may have
    # no effect - confirm before relying on it.
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

    # Login to HF if token provided
    hf_token = args.hf_token or os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    # Load dataset
    print(f"Loading dataset: {args.input_dataset}")
    dataset = load_dataset(args.input_dataset, split=args.split)

    if args.image_column not in dataset.column_names:
        print(f"ERROR: Column '{args.image_column}' not found")
        print(f"Available columns: {dataset.column_names}")
        sys.exit(1)

    # Shuffle if requested
    if args.shuffle:
        print(f"Shuffling with seed {args.seed}")
        dataset = dataset.shuffle(seed=args.seed)

    # Limit samples if requested (a value of 0 is treated as "no limit")
    if args.max_samples:
        dataset = dataset.select(range(min(args.max_samples, len(dataset))))

    print(f"Processing {len(dataset)} samples")

    # Initialize vLLM engine (UNCHANGED from official batch script)
    print("Initializing vLLM engine...")
    llm = LLM(
        model=MODEL_PATH,
        hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
        block_size=256,
        enforce_eager=False,
        trust_remote_code=True,
        max_model_len=args.max_model_len,
        swap_space=0,
        max_num_seqs=args.max_num_seqs,
        tensor_parallel_size=1,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )

    # Sampling params (UNCHANGED from official batch script)
    logits_processors = [NoRepeatNGramLogitsProcessor(
        ngram_size=40, window_size=90, whitelist_token_ids={128821, 128822}
    )]
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=args.max_tokens,
        logits_processors=logits_processors,
        skip_special_tokens=False,
    )

    # Load and preprocess images
    print("Loading images from dataset...")
    images = _load_images(dataset, args.image_column)

    # Preprocess images in parallel
    print("Preprocessing images...")
    with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
        batch_inputs = list(tqdm(
            executor.map(_preprocess_or_none, range(len(images)), images),
            total=len(images),
            desc="Pre-processing images"
        ))

    # Filter out None entries and track their indices
    valid_indices = [i for i, inp in enumerate(batch_inputs) if inp is not None]
    valid_batch_inputs = [batch_inputs[i] for i in valid_indices]

    # Batch inference (skipped when nothing survived loading/preprocessing)
    print(f"Running batch inference on {len(valid_batch_inputs)} images...")
    if valid_batch_inputs:
        outputs_list = llm.generate(
            valid_batch_inputs,
            sampling_params=sampling_params
        )
    else:
        outputs_list = []

    # Extract results; rows without an output keep the failure marker
    all_markdown = [_FAILED_MARKER] * len(dataset)
    for idx, output in zip(valid_indices, outputs_list):
        all_markdown[idx] = output.outputs[0].text.strip()

    # Add markdown column
    print("Adding markdown column...")
    dataset = dataset.add_column("markdown", all_markdown)

    # Handle inference_info: append this run to any history already present
    existing_info = _read_existing_inference_info(dataset)
    if "inference_info" in dataset.column_names:
        dataset = dataset.remove_columns(["inference_info"])

    new_info = {
        "column_name": "markdown",
        "model_id": MODEL_PATH,
        "processing_date": datetime.now().isoformat(),
        "prompt": PROMPT,
        "max_tokens": args.max_tokens,
        "max_model_len": args.max_model_len,
        "gpu_memory_utilization": args.gpu_memory_utilization,
        "max_num_seqs": args.max_num_seqs,
        "script": "process_dataset.py",
        "implementation": "vllm-batch (official deepseek batch code)",
    }
    existing_info.append(new_info)

    # The same JSON string is stored on every row.
    info_json = json.dumps(existing_info, ensure_ascii=False)
    dataset = dataset.add_column("inference_info", [info_json] * len(dataset))

    # Push to hub
    print(f"Pushing to {args.output_dataset}")
    dataset.push_to_hub(args.output_dataset, private=args.private, token=hf_token)

    print("✅ Complete!")
    print(f"Dataset: https://huggingface.co/datasets/{args.output_dataset}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Process images through DeepSeek-OCR"
    )
    parser.add_argument("input_dataset", help="Input dataset ID")
    parser.add_argument("output_dataset", help="Output dataset ID")
    parser.add_argument("--image-column", default="image", help="Image column name")
    parser.add_argument("--split", default="train", help="Dataset split")
    parser.add_argument("--max-samples", type=int, help="Limit number of samples")
    parser.add_argument("--shuffle", action="store_true", help="Shuffle dataset")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--max-model-len", type=int, default=8192)
    parser.add_argument("--max-tokens", type=int, default=8192)
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.75)
    parser.add_argument("--max-num-seqs", type=int, default=100, help="Max concurrent sequences")
    parser.add_argument("--num-workers", type=int, default=64, help="Image preprocessing workers")
    parser.add_argument("--hf-token", help="HF API token")
    parser.add_argument("--private", action="store_true", help="Make output private")

    args = parser.parse_args()
    main(args)