Cannot generate with mode: Fill-in-the-middle

#8
by miraclezst - opened

My code is below:

```bash
pip install -q transformers
```

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/starcoder"
device = "cuda"  # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_auth_token=True)
# load_in_8bit=True also needs the bitsandbytes and accelerate packages installed
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, trust_remote_code=True, load_in_8bit=True, device_map={"": 0}
)

input_text = "def print_hello_world():\n \n print('Hello world!')"
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))
```


Output: (screenshot, not preserved)

Does anyone know why this happens?

Could you leave the part of the code you want filled in blank? As written, your snippet is already complete and marks no gap, so there is nothing left for the model to fill in.
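Leaving a gap means marking it with the FIM special tokens. The bigcode/starcoder model card lays the prompt out in prefix-suffix-middle (PSM) order: `<fim_prefix>`, the code before the gap, `<fim_suffix>`, the code after the gap, and finally `<fim_middle>`, after which the model generates the missing part. A minimal sketch of assembling such a prompt, assuming that ordering; `build_fim_prompt` is a hypothetical helper, not part of transformers.

```python
# Special tokens StarCoder uses to mark the parts of a fill-in-the-middle prompt.
FIM_PREFIX = "<fim_prefix>"
FIM_SUFFIX = "<fim_suffix>"
FIM_MIDDLE = "<fim_middle>"


def build_fim_prompt(prefix: str, suffix: str) -> str:
    """Hypothetical helper: build a prefix-suffix-middle (PSM) prompt.

    The code before the gap follows <fim_prefix>, the code after the gap
    follows <fim_suffix>, and the prompt ends with <fim_middle> so that the
    model's continuation is the missing middle.
    """
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


if __name__ == "__main__":
    prompt = build_fim_prompt(
        "def print_hello_world():\n    ",  # code before the gap
        "\n    print('Hello world!')",     # code after the gap
    )
    print(prompt)
```

The resulting string would take the place of `input_text` before calling `tokenizer.encode(...)`.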

The warning can be eliminated by passing an additional argument to model.generate:

```python
outputs = model.generate(inputs, pad_token_id=tokenizer.eos_token_id)
```

Also, I strongly suspect that the example has <fim_suffix> and <fim_middle> swapped. When I do this:

```python
input_text = "<fim_prefix>def print_hello_world():\n    <fim_middle>\n    print('Hello world!')<fim_suffix>"
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
outputs = model.generate(inputs, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(outputs[0]))
```

I get this output:

```
<fim_prefix>def print_hello_world():
    <fim_middle>
    print('Hello world!')<fim_suffix>

if
```

That trailing "if" is a bit odd, but these models do seem to throw in a stray token at the end now and then; I think I've seen another model do the same. Apart from that, I take it this output is as intended.
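Decoding `outputs[0]` in full echoes the whole prompt back. With a prompt in the model card's PSM layout, the newly generated code is whatever follows `<fim_middle>`, up to the end-of-text token. A minimal sketch of extracting that part and splicing it between the prefix and suffix, assuming the PSM layout and that StarCoder's end-of-text token is `<|endoftext|>`; `extract_middle` and `splice` are hypothetical helpers.

```python
FIM_MIDDLE = "<fim_middle>"
EOS = "<|endoftext|>"  # assumed StarCoder end-of-text token


def extract_middle(decoded: str) -> str:
    """Hypothetical helper: return only the text generated after <fim_middle>.

    Assumes a PSM prompt containing exactly one <fim_middle>, so everything
    after its first occurrence is model output. The text is cut at the first
    end-of-text token, if there is one.
    """
    _, sep, generated = decoded.partition(FIM_MIDDLE)
    if not sep:
        raise ValueError("decoded text contains no <fim_middle> token")
    return generated.split(EOS, 1)[0]


def splice(prefix: str, decoded: str, suffix: str) -> str:
    """Hypothetical helper: rebuild plain code as prefix + middle + suffix."""
    return prefix + extract_middle(decoded) + suffix


if __name__ == "__main__":
    decoded = "<fim_prefix>def f():\n    <fim_suffix>\n    return x<fim_middle>x = 1<|endoftext|>"
    print(splice("def f():\n    ", decoded, "\n    return x"))
```

Another way to avoid the echoed prompt is to slice off the prompt tokens before decoding, e.g. `tokenizer.decode(outputs[0][inputs.shape[-1]:])`; the string-based version above has the advantage of also handling the splice back into the surrounding code.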
