Can't get it to generate the EOS token and beam search is not supported

#3
by miguelcarv - opened

This is how I'm using the model:

import time

import torch
import transformers

# phi-2-orange uses custom modeling code (phi-msft), so trust_remote_code is required
model = transformers.AutoModelForCausalLM.from_pretrained(
    "rhysjones/phi-2-orange",
    trust_remote_code=True
)
tokenizer = transformers.AutoTokenizer.from_pretrained("rhysjones/phi-2-orange")

# ChatML-style prompt; the assistant turn is left open for the model to complete
input_text = """<|im_start|>system
You are a helpful assistant that gives short answers.<|im_end|>
<|im_start|>user
Give me the first 3 prime numbers.<|im_end|>
<|im_start|>assistant
"""

t1 = time.perf_counter()
with torch.no_grad():
    outputs = model.generate(
        tokenizer(input_text, return_tensors="pt")["input_ids"],
        max_new_tokens=200,
        num_beams=1,  # greedy search; num_beams > 1 fails with this model's code
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print(f"{time.perf_counter() - t1:.2f} s")

I think the beam search part is relatively easy to change; you just update modeling_phi.py, right? The EOS token is harder.

Hey, there have been a few recent updates to Phi-2 that base it on the Hugging Face transformers implementation, and they post-date the Phi-2 version this model was built from. One of those updates appears to fix beam search support. I'll look at updating this model with those changes in the future, which should also let it run without needing trust_remote_code.

In the meantime, the other recent Phi-2 update makes the EOS token explicit. If you add this to your code, generation should reliably stop on the EOS token:

import time

import torch
import transformers

# phi-2-orange uses custom modeling code (phi-msft), so trust_remote_code is required
model = transformers.AutoModelForCausalLM.from_pretrained(
    "rhysjones/phi-2-orange",
    trust_remote_code=True
)
tokenizer = transformers.AutoTokenizer.from_pretrained("rhysjones/phi-2-orange")

# ChatML-style prompt; the assistant turn is left open for the model to complete
input_text = """<|im_start|>system
You are a helpful assistant that gives short answers.<|im_end|>
<|im_start|>user
Give me the first 3 prime numbers.<|im_end|>
<|im_start|>assistant
"""

# Make the end-of-sequence token explicit: 50256 is <|endoftext|>,
# Phi-2's end-of-text token
generation_config = transformers.GenerationConfig(
    eos_token_id=50256
)

t1 = time.perf_counter()
with torch.no_grad():
    outputs = model.generate(
        tokenizer(input_text, return_tensors="pt")["input_ids"],
        max_new_tokens=200,
        generation_config=generation_config,
        num_beams=1,  # greedy search; beam search is not supported yet
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print(f"{time.perf_counter() - t1:.2f} s")
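To check whether a run actually ended on EOS (rather than running out of the max_new_tokens budget), you can look at the generated IDs directly. The sketch below is plain Python so the logic can be tested on its own; `split_generation` is a hypothetical helper, not part of transformers, and it assumes batch size 1 with no padding, as in the code above.

```python
from typing import List, Sequence, Tuple

EOS_TOKEN_ID = 50256  # same ID as in the GenerationConfig above


def split_generation(
    output_ids: Sequence[int], prompt_len: int, eos_token_id: int = EOS_TOKEN_ID
) -> Tuple[List[int], bool]:
    """Return (new token IDs without EOS and anything after it, whether EOS was hit).

    generate() returns the prompt followed by the new tokens, so the first
    prompt_len IDs are skipped before searching for EOS.
    """
    new_ids = list(output_ids[prompt_len:])
    if eos_token_id in new_ids:
        return new_ids[: new_ids.index(eos_token_id)], True
    return new_ids, False
```

Assuming you keep the tokenized prompt in a variable `input_ids`, you would call it as `split_generation(outputs[0].tolist(), input_ids.shape[1])`. A `False` flag together with 200 new tokens means generation was cut off by the length limit instead of stopping on EOS.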

The newer version also never stops. It is better to add a stopping criterion, for example one that stops on the <|im_end|> token.
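A stopping criterion for <|im_end|> comes down to checking, after each step, whether the generated IDs end with the ID sequence that <|im_end|> encodes to. The sketch below keeps that logic in plain Python so it can be checked without loading the model. `StopOnTokenSequence` and `strip_after_stop` are hypothetical names; the class only mimics the `(input_ids, scores, **kwargs)` call signature of a transformers `StoppingCriteria`. It also assumes you get the stop IDs with `tokenizer.encode("<|im_end|>", add_special_tokens=False)`; whether phi-2-orange's tokenizer maps <|im_end|> to a single token is not verified here, so the class matches a full ID sequence either way.

```python
from typing import List, Sequence


class StopOnTokenSequence:
    """Signal a stop once every sequence in the batch ends with stop_ids.

    input_ids may be a list of lists or a 2-D tensor (rows are converted
    with .tolist() when available), mirroring how generate() calls
    stopping criteria with (input_ids, scores).
    """

    def __init__(self, stop_ids: Sequence[int]):
        if not stop_ids:
            raise ValueError("stop_ids must contain at least one token ID")
        self.stop_ids = list(stop_ids)

    def _ends_with_stop(self, row) -> bool:
        ids = row.tolist() if hasattr(row, "tolist") else list(row)
        n = len(self.stop_ids)
        return len(ids) >= n and ids[-n:] == self.stop_ids

    def __call__(self, input_ids, scores=None, **kwargs) -> bool:
        # A single bool for the whole batch; fine for batch size 1
        return all(self._ends_with_stop(row) for row in input_ids)


def strip_after_stop(token_ids: Sequence[int], stop_ids: Sequence[int]) -> List[int]:
    """Return token_ids cut just before the first occurrence of stop_ids."""
    ids, stop = list(token_ids), list(stop_ids)
    if not stop:
        return ids
    n = len(stop)
    for i in range(len(ids) - n + 1):
        if ids[i:i + n] == stop:
            return ids[:i]
    return ids
```

In real use you would subclass `transformers.StoppingCriteria` with the same `__call__` body and pass it as `stopping_criteria=transformers.StoppingCriteriaList([...])` to `model.generate`. If <|im_end|> is split into several BPE pieces, the model may emit it with a different tokenization, and ID matching can miss it; decoding the tail of the output and checking the string is more robust in that case. Recent transformers releases also accept `stop_strings=["<|im_end|>"]` together with `tokenizer=tokenizer` in `generate`, which may make a custom class unnecessary; check whether your installed version supports it.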
