Completion stuck in a loop

#1
Opened by ParkerBurchett

Has anybody checked how often the model gets stuck in loops?


Hi @ParkerBurchett. The model doesn't seem to work very well in the inference widget. I didn't test it there because it is slow. Can you test it in Colab? I will add the recommended generation parameters.

Sure, I got it running in Colab. It still gets stuck in a loop:
https://colab.research.google.com/drive/1YcH40MhJfYiFd39Q361SC6cauq2YFx9v?usp=sharing

prompt = "McDonald's hamburger promotion on a red billboard, white lettering, "
inputs = tokenizer(prompt, return_tensors="pt").to('cuda')  # BatchEncoding holding input_ids and attention_mask
sample = model.generate(**inputs, max_new_tokens=50)  # no sampling or repetition_penalty passed
tokenizer.decode(sample[0])


McDonald's hamburger promotion on a red billboard, white lettering,  digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital billboard, digital bill
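This kind of repetition is typical of greedy decoding, which always takes the single highest-scoring next token (assuming generate() is running with its default, non-sampling settings here). The toy sketch below is not BLOOM: it uses a hypothetical score table, NEXT_SCORES, that depends only on the previous token, and a hypothetical greedy_decode helper. A real model conditions on the whole context, so this only illustrates the mechanism: once the argmax path revisits a token, it cycles.

```python
# Toy illustration of greedy decoding falling into a loop.
# NEXT_SCORES is a hypothetical bigram score table, not BLOOM's real distribution.
NEXT_SCORES = {
    "lettering,": {"digital": 2.0, "neon": 1.5},
    "digital": {"billboard,": 3.0, "art": 1.0},
    "billboard,": {"digital": 2.5, "trending": 2.0, "</s>": 0.5},
}


def greedy_decode(start: str, max_new_tokens: int) -> list[str]:
    """Pick the highest-scoring next token at every step (no sampling, no penalty)."""
    tokens = [start]
    for _ in range(max_new_tokens):
        candidates = NEXT_SCORES.get(tokens[-1])
        if not candidates:
            break
        best = max(candidates, key=candidates.get)
        tokens.append(best)
        if best == "</s>":  # stop if the end-of-sequence token wins
            break
    return tokens


print(" ".join(greedy_decode("lettering,", 8)))
# prints: lettering, digital billboard, digital billboard, digital billboard, digital billboard,
```

Because "</s>" never has the top score after "billboard,", the toy decoder only stops when it runs out of new tokens, much like the output above stopping mid-word at the token limit.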

@mrm8488, which generation parameters are you using?

I had the same issue with the default parameters, but setting a high repetition_penalty fixed it.
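For context, transformers' repetition_penalty follows the CTRL-style rule: for every token already present in the sequence (prompt included), a positive score is divided by the penalty and a negative score is multiplied by it, so any value above 1.0 makes repeats less likely and 1.0 means no penalty. The pure-Python sketch below mimics that rule on a plain dict; apply_repetition_penalty and the scores are hypothetical, not library code.

```python
def apply_repetition_penalty(
    scores: dict[str, float], seen: set[str], penalty: float
) -> dict[str, float]:
    """Return a copy of `scores` with already-seen tokens penalized (CTRL-style rule)."""
    penalized = dict(scores)
    for token in seen:
        if token in penalized:
            s = penalized[token]
            # With penalty > 1, multiplying a negative score or dividing a positive one both lower it.
            penalized[token] = s * penalty if s < 0 else s / penalty
    return penalized


# Hypothetical next-token scores after "... digital billboard, digital billboard,"
scores = {"digital": 2.5, "trending": 2.45, "billboard,": -1.0}
seen = {"digital", "billboard,"}
penalized = apply_repetition_penalty(scores, seen, penalty=1.05)
print(max(scores, key=scores.get), "->", max(penalized, key=penalized.get))
# prints: digital -> trending
```

The penalty is applied once per distinct token, not once per occurrence, so a mild value can already flip a close decision as in this toy case; a very large value can also suppress useful words that legitimately appear in the prompt.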

In Colab it still gets stuck in a loop:

import torch
from transformers import BloomTokenizerFast, BloomForCausalLM

device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckpt = 'mrm8488/bloom-560m-finetuned-sd-prompts' 

tokenizer = BloomTokenizerFast.from_pretrained(ckpt)
model = BloomForCausalLM.from_pretrained(ckpt).to(device)

def generate_prompt(text):
    torch.cuda.empty_cache()
    inputs = tokenizer(text, return_tensors='pt')
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
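    # No repetition_penalty is passed, so the default of 1.0 (no penalty) applies.
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=512,  # total length, prompt tokens included
        eos_token_id=tokenizer.eos_token_id,
    )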

    return tokenizer.decode(output[0], skip_special_tokens=False)

text = "<s>Prompt: pikachu dining in the eiffel tower"
text2 = "<s>Prompt: McDonald's hamburger promotion on a red billboard, white lettering, "

generate_prompt(text2)

This returns:

<s>Prompt: McDonald's hamburger promotion on a red billboard, white lettering,  digital billboard, digital billboard, digital billboard, digital billboard, [... "digital billboard," repeats until max_length=512 is reached ...] digital bil

@undefined2 Changing repetition_penalty to 1.05 worked for me too:

def generate_prompt(text):
    inputs = tokenizer(text, return_tensors='pt')
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        repetition_penalty=1.05,  # > 1.0 lowers the scores of tokens already in the sequence
        max_length=512,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=False)

text2 = "<s>Prompt: McDonald's hamburger promotion on a red billboard, white lettering,"
generate_prompt(text2)

This returns:

<s>Prompt: McDonald's hamburger promotion on a red billboard, white lettering, advertisement with posters and flyrets in the style of artgerm.</s>

@mrm8488 Maybe set repetition_penalty slightly above 1 in the example code?

I have a problem when running the code: it generates exactly the same sentence every time. What should I do?
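One likely explanation, assuming the example code is used unchanged and the model's generation config does not enable sampling: without do_sample=True, generate() decodes greedily (or with beam search), and both are deterministic, so the same prompt yields the same text on every run. Passing do_sample=True, optionally with top_p or temperature, makes generate() sample instead. The toy sketch below contrasts the two on a hypothetical score table; SCORES, greedy_pick and sample_pick are illustrative names and do not call transformers.

```python
import math
import random

# Hypothetical next-token scores (logits) for a single decoding step.
SCORES = {"advertisement": 2.0, "poster": 1.8, "neon": 1.2}


def softmax(scores: dict[str, float], temperature: float = 1.0) -> dict[str, float]:
    """Convert scores to probabilities; a lower temperature sharpens the distribution."""
    m = max(scores.values())  # subtract the max for numerical stability
    exps = {t: math.exp((s - m) / temperature) for t, s in scores.items()}
    total = sum(exps.values())
    return {t: e / total for t, e in exps.items()}


def greedy_pick(scores: dict[str, float]) -> str:
    """Deterministic: always returns the argmax token."""
    return max(scores, key=scores.get)


def sample_pick(scores: dict[str, float], rng: random.Random, temperature: float = 1.0) -> str:
    """Stochastic: draws a token in proportion to its softmax probability."""
    probs = softmax(scores, temperature)
    tokens = list(probs)
    return rng.choices(tokens, weights=[probs[t] for t in tokens], k=1)[0]


rng = random.Random(0)
print([greedy_pick(SCORES) for _ in range(5)])  # five identical tokens
print([sample_pick(SCORES, rng) for _ in range(5)])  # generally differs between draws
```

In transformers the corresponding switch would be something like model.generate(..., do_sample=True, top_p=0.9); the top_p value here is only a starting guess to tune, not a recommendation from this thread.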
