Batch Decoding

Discussion #3, opened by vody-am

Hi! Thanks for this charmingly compact model. I couldn't find an example of batch decoding, but I worked out something roughly like this:

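```python
import torch

# Assumes `model`, `processor`, and a PIL `image` were loaded beforehand.

# Instruct the model to create a caption in Spanish
prompt = "caption es"
prompts = [prompt] * 4  # one prompt per image in the batch
images = [image] * 4
model_inputs = processor(text=prompts, images=images, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]  # padded prompt length, shared by all rows

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    # Drop the prompt tokens so only the newly generated text remains
    generation = generation[:, input_len:]
    decoded_batch = processor.batch_decode(generation, skip_special_tokens=True)
    for decoded in decoded_batch:
        print(decoded)
```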

Is that correct? An official batch inference example would be helpful!

Thank you.

It looks good; it should work!
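Why the single slice `generation[:, input_len:]` strips the prompt from every row can be illustrated without loading the model. This is a hypothetical sketch: plain Python lists stand in for tensors, `fake_generate` stands in for `model.generate`, and the token ids and `PAD` value are made up. It assumes left padding (the usual setup for batched decoder-only generation; whether the processor pads this way by default is worth checking), and that `generate` returns each prompt followed by its new tokens, with rows that stop early filled with padding.

```python
# Hypothetical token ids, for illustration only.
PAD = 0


def left_pad(batch, pad_id=PAD):
    """Left-pad variable-length prompts so every row has the same length."""
    width = max(len(seq) for seq in batch)
    return [[pad_id] * (width - len(seq)) + seq for seq in batch]


def fake_generate(padded, new_tokens):
    """Stand-in for model.generate: each prompt row followed by its new tokens."""
    return [row + toks for row, toks in zip(padded, new_tokens)]


def strip_prompt(generation, input_len):
    """List equivalent of generation[:, input_len:]."""
    return [row[input_len:] for row in generation]


def drop_special(rows, special=frozenset({PAD})):
    """Rough analogue of skip_special_tokens=True: drop padding ids."""
    return [[t for t in row if t not in special] for row in rows]


prompts = [[5, 6], [7, 8, 9]]  # two prompts of different lengths
padded = left_pad(prompts)  # [[0, 5, 6], [7, 8, 9]]
input_len = len(padded[0])  # plays the role of input_ids.shape[-1]
# The first row "finishes" early, so its last slot is padding.
generation = fake_generate(padded, [[11, 12, PAD], [13, 14, 15]])
new_only = strip_prompt(generation, input_len)
print(new_only)  # [[11, 12, 0], [13, 14, 15]]
print(drop_special(new_only))  # [[11, 12], [13, 14, 15]]
```

Because every row shares the same padded prompt length, one column index removes the prompt from the whole batch, and `skip_special_tokens=True` in `batch_decode` then discards the trailing padding.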

merve changed discussion status to closed
