stop sequence

#2
by LavGadewar - opened

GPT-3 has a "stop sequence" parameter. Is there a similar parameter for the GPT-Neo 2.7B model that stops token generation when that sequence appears, so that the output does not contain the sequence?

LavGadewar changed discussion status to closed
LavGadewar changed discussion status to open

Yes. You can implement it with a custom `StoppingCriteria` passed to `generate()`; adapt the code below as needed:

import torch; device = torch.device("cuda")
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList

class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once the first sequence in the batch ends with a keyword.

    Each entry of ``keywords_ids`` is either a single token id (an int, or
    anything usable as an index such as a 0-d integer tensor) or a sequence
    of token ids.  A keyword matches when it equals the trailing tokens of
    ``input_ids[0]``.  Passing plain ints behaves as before: generation stops
    when the last token is one of them.  Multi-token entries make it possible
    to stop on strings that the tokenizer splits into several tokens.
    """

    def __init__(self, keywords_ids: list):
        import operator

        # Original argument kept as-is for any caller that reads it back.
        self.keywords = keywords_ids
        sequences = []
        for kw in keywords_ids:
            try:
                # A single token id.
                seq = (operator.index(kw),)
            except TypeError:
                # A multi-token stop sequence.
                seq = tuple(operator.index(t) for t in kw)
            if seq:  # an empty sequence would match every step, so skip it
                sequences.append(seq)
        self._sequences = sequences
        self._max_len = max((len(s) for s in sequences), default=0)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when ``input_ids[0]`` ends with any configured keyword.

        Only the first row of the batch is inspected.  NOTE(review): input_ids
        presumably holds prompt + generated tokens, so a multi-token keyword
        could match across the prompt/generation boundary — confirm if that
        matters for your prompts.
        """
        if not self._max_len:
            return False
        # Pull only as many trailing ids as the longest keyword needs, as plain
        # ints, so each check is a cheap tuple comparison instead of a tensor op.
        tail = tuple(input_ids[0, -self._max_len:].tolist())
        return any(
            tail[-len(seq):] == seq
            for seq in self._sequences
            if len(seq) <= len(tail)
        )

# Strings that should end generation.
# NOTE(review): the stopping criterion below is given only the FIRST token id
# of each string.  With a BPE tokenizer several of these (e.g. '.\n', '.  ',
# '. \n') likely start with the same '.' token, which would make generation
# stop at any period — check with tokenizer.tokenize(w) for each entry.
sequence = ['\n', '\n\n', '.\n', '.  ', '. \n', '?', '!']

# First token id of each stop string, built once with duplicates removed.
stop_ids = sorted({tokenizer.encode(w)[0] for w in sequence})

inputs = tokenizer(prompt, return_tensors='pt').to(device)
generated = model.generate(
    **inputs,
    top_p=1,
    top_k=0,
    temperature=0.2,
    max_new_tokens=18,
    # Presumably the <|endoftext|> id of the GPT-2/GPT-Neo vocabulary;
    # tokenizer.eos_token_id is the portable spelling — confirm if the
    # model changes.
    pad_token_id=50256,
    no_repeat_ngram_size=2,
    stopping_criteria=StoppingCriteriaList([KeywordsStoppingCriteria(stop_ids)]),
    # `early_stopping` was dropped: it only applies to beam search and is
    # ignored when sampling.
    do_sample=True,
)[0]

# The criterion fires after the stop token has been appended, so drop that
# trailing token to keep the stop sequence out of the returned text.
if len(generated) and generated[-1].item() in stop_ids:
    generated = generated[:-1]

output = tokenizer.decode(generated, skip_special_tokens=True)

Sign up or log in to comment