Batch pred slower than single image inference on 1x4090?

#5
by 04RR - opened

import torch
from PIL import Image
from tqdm import tqdm


def batch_generate(  # name is a placeholder; the original def line was cut off in the paste
    image_paths,
    model,
    processor,
    prompt="<md>",
    batch_size=1,
):
    device = "cuda"
    dtype = torch.bfloat16

    outputs = []
    num_batches = (len(image_paths) + batch_size - 1) // batch_size  # ceil division

    for i in tqdm(range(num_batches)):
        batch_paths = image_paths[i * batch_size : (i + 1) * batch_size]
        images = [Image.open(path) for path in batch_paths]

        inputs = processor(
            text=[prompt] * len(images),
            images=images,
            return_tensors="pt",
            padding=True,
        )

        # move tensors to the GPU and cast the patch tensor to bfloat16
        inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
        inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)

        # drop width/height, which the processor returns but generate() does not accept
        inputs.pop("width", None)
        inputs.pop("height", None)

        generated_ids = model.generate(
            **inputs,
            max_new_tokens=4096,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        outputs.extend(generated_text)

    return outputs
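
For reference, I call it roughly like this (the batch_generate name is just the placeholder used above, and the paths are made up):

results = batch_generate(
    image_paths=["page_001.png", "page_002.png", "page_003.png"],
    model=model,
    processor=processor,
    prompt="<md>",
    batch_size=4,
)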

This is the code I used, and for some reason batch predictions are way slower than running one image at a time. The images are 1700x2000, but I assume they get resized by the image processor.
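
In case it helps, here is a rough sketch of how I'd compare the two modes and check the processed sizes (it reuses the batch_generate placeholder from above and assumes model, processor and image_paths are already in scope; nothing here is model-specific):

import time

def time_run(paths, bs):
    torch.cuda.synchronize()
    start = time.perf_counter()
    batch_generate(paths, model, processor, prompt="<md>", batch_size=bs)
    torch.cuda.synchronize()
    return time.perf_counter() - start

test_paths = image_paths[:8]  # same images for both runs
print("batch_size=1:", time_run(test_paths, 1))
print("batch_size=4:", time_run(test_paths, 4))

# sanity check: what the processor actually feeds the model for one image
sample = processor(text=["<md>"], images=[Image.open(test_paths[0])], return_tensors="pt")
print(sample["flattened_patches"].shape)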

Any fixes?

Thank you for this model, it's amazing for its size!

