prevent repetitions

#7
by vinnitu - opened

How can I prevent repetitions like "It's working" in the generated translation?

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch

device = torch.device('cuda')

max_new_tokens = 200
model_name = "facebook/mbart-large-50-many-to-one-mmt"
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
tokenizer.src_lang = 'ko_KR'
input = 'μ„œλΉ„μŠ€ 쀑지가 계속 λœ¨λŠ”λ° 잘 된거 λ§žλ‚˜μš”?'  # Google Translate gives: 'The service stop keeps popping up, is it okay?'
encoded = tokenizer(input, return_tensors="pt").to(device)
generated_tokens = model.generate(**encoded)
result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
print(result[0])

Output:

And it's working, right? It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working. It's working

Hello @vinnitu

You can use no_repeat_ngram_size (see the generation docs) to prevent such repetition: setting it to N forbids any N-gram of tokens from appearing more than once in the generated sequence.

Code:

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch

device = torch.device('cuda')

max_new_tokens = 200
model_name = "facebook/mbart-large-50-many-to-one-mmt"
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
tokenizer.src_lang = 'ko_KR'

input = 'μ„œλΉ„μŠ€ 쀑지가 계속 λœ¨λŠ”λ° 잘 된거 λ§žλ‚˜μš”?' 
encoded = tokenizer(input, return_tensors="pt").to(device)

# Adjust the num_beams and no_repeat_ngram_size parameters
generated_tokens = model.generate(
    **encoded,
    num_beams=5,                # beam search instead of greedy decoding
    no_repeat_ngram_size=2,     # forbid any 2-gram from being generated twice
    max_length=max_new_tokens,  # cap the output length
)

result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
print(result[0])

Output:

And it's working, right?
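If you prefer not to pass these arguments on every call, the same defaults can also be set once on the model's generation config. This is just a minimal sketch, assuming a recent transformers version that exposes model.generation_config, and it reuses model, tokenizer, and encoded from the snippet above:

# Set decoding defaults once; later model.generate() calls pick them up.
# (Assumes a transformers version with the generation_config attribute.)
model.generation_config.num_beams = 5
model.generation_config.no_repeat_ngram_size = 2
model.generation_config.max_length = 200

generated_tokens = model.generate(**encoded)  # uses the defaults set above
result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
print(result[0])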
vinnitu changed discussion status to closed
