Batch generation with starcoder
#65 opened by YoungEver
I want to use batch generation for better performance. I used the toy code below, but the results were wrong: there are a lot of <|endoftext|> tokens in the newly generated output. So how do I use batch generation with StarCoder? Thanks~
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = "/shared/models/huggingface/starcoder/"
device = "cuda" # for GPU usage or "cpu" for CPU usage
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" # padding side = left for a decoder-only architecture
model = AutoModelForCausalLM.from_pretrained(checkpoint).half().to(device)
#prompt = "def print_hello_world():"
prompt = [
    "write a python function that caculate the max element in a list",
    "write a c++ function that caculate the max element in a list",
]
inputs = tokenizer(prompt, padding=True, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=30)
generated_txts = tokenizer.batch_decode(outputs)
print(generated_txts)
The output:
>>> generated_txts
['<|endoftext|>write a python function that caculate the max element in a list\n\ndef max_element(list):\n max = list[0]\n for i in list:\n if i > max:\n max', 'write a c++ function that caculate the max element in a list\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#']
Have you solved this problem? I encountered the same issue.
If it helps, I think the generations are accurate; the second one is just of poor quality. The first seems correct and hits your max_new_tokens
limit. The second is bad, but I don't think those are special tokens, and I get the same result when generating with only the second prompt:
o1 = model.generate(inputs['input_ids'][1][None, ...], max_new_tokens=30)
tokenizer.decode(o1[0])
Out[23]: 'write a c++ function that caculate the max element in a list\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#\n#'
In[26]: tokenizer.decode(o1[0], skip_special_tokens=True) == tokenizer.decode(o1[0], skip_special_tokens=False)
Out[26]: True
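For what it's worth, the leading <|endoftext|> on the first sequence is just the left-padding token (the snippet sets pad_token = eos_token), so it is expected for the shorter prompt in the batch. Here is a minimal sketch of how I'd drop it at decode time, assuming the same tokenizer, model and inputs objects as in the original snippet:
# Passing pad_token_id explicitly avoids the "Setting pad_token_id to
# eos_token_id" warning during batched generation (setup as above).
outputs = model.generate(
    **inputs,
    max_new_tokens=30,
    pad_token_id=tokenizer.eos_token_id,
)

# skip_special_tokens=True strips the left-padding <|endoftext|> tokens
# from the decoded strings.
generated_txts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(generated_txts)
With that, the batched output for the first prompt should match what you get from generating on it alone; the second prompt's output is just low quality, not corrupted by padding.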