Batch generation with StarCoder

#65 opened by YoungEver

I want to use batch generation for better performance. I used the toy code below, but the results are wrong: there are a lot of <|endoftext|> tokens in the newly generated output. How do I use batch generation with StarCoder correctly? Thanks~

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "/shared/models/huggingface/starcoder/"
device = "cuda" # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" # padding side = left for a decoder-only architecture
model = AutoModelForCausalLM.from_pretrained(checkpoint).half().to(device)

#prompt = "def print_hello_world():"
prompt = ["write a python function that caculate the max element in a list",
          "write a c++ function that caculate the max element in a list",
         ]
inputs = tokenizer(prompt, padding=True, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=30)
generated_txts = tokenizer.batch_decode(outputs)
print(generated_txts)

The output:

>>> generated_txts
['<|endoftext|>write a python function that caculate the max element in a list\n\ndef max_element(list):\n    max = list[0]\n    for i in list:\n        if i > max:\n            max', 'write a c++ function that caculate the max element in a list\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#']

Have you solved this problem? I'm encountering the same issue.

If it helps, I think the generations are accurate; the second one is just of poor quality. The first seems correct and hits your max_new_tokens limit. The second is bad, but I don't think those are special tokens, and I get the same result when generating from only the second prompt:

o1 = model.generate(inputs['input_ids'][1][None, ...], max_new_tokens=30)
tokenizer.decode(o1[0])

Out[23]: 'write a c++ function that caculate the max element in a list\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#'

In[26]: tokenizer.decode(o1[0], skip_special_tokens=True) == tokenizer.decode(o1[0], skip_special_tokens=False)
Out[26]: True
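
The leading <|endoftext|> in the first batched output is just the left padding that was added so both prompts have the same length. Below is a minimal sketch (not from the thread, reusing the tokenizer, model, and inputs variables from the first post) of two standard ways to drop it when decoding; skip_special_tokens is a regular argument of batch_decode, and the slicing relies on generate returning the prompt tokens followed by the new tokens.

# Decode the batched generations without padding / special tokens.
outputs = model.generate(**inputs, max_new_tokens=30)

# Option 1: skip special tokens (including the <|endoftext|> padding) while decoding.
generated_txts = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Option 2: slice off the (left-padded) prompt and decode only the new tokens.
prompt_len = inputs["input_ids"].shape[1]
new_tokens = outputs[:, prompt_len:]
completions = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

print(generated_txts)
print(completions)

With left padding, every row of inputs["input_ids"] has the same length, so a single slice at prompt_len separates the prompts from the generated continuations for the whole batch.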
