Spaces:
Runtime error
Runtime error
| import transformers | |
| import string | |
# Hugging Face model ids; the tokenizer and model lists below are index-aligned
# with this list (entry i of each belongs to model_names[i]).
model_names = [
    'microsoft/GODEL-v1_1-large-seq2seq',
    'facebook/blenderbot-1B-distill',
    'facebook/blenderbot_small-90M',
]

# Load every tokenizer first, pairing each model id with its tokenizer class.
tokenizers = [
    tokenizer_cls.from_pretrained(name)
    for tokenizer_cls, name in zip(
        (
            transformers.AutoTokenizer,
            transformers.BlenderbotTokenizer,
            transformers.BlenderbotSmallTokenizer,
        ),
        model_names,
    )
]

# Then load the models, again in model_names order.
# NOTE(review): this module-level list is named `model` (singular) and is
# shadowed by the `model` parameter of generate_text; the name is kept because
# code outside this view may reference it.
model = [
    model_cls.from_pretrained(name)
    for model_cls, name in zip(
        (
            transformers.AutoModelForSeq2SeqLM,
            transformers.BlenderbotForConditionalGeneration,
            transformers.BlenderbotSmallForConditionalGeneration,
        ),
        model_names,
    )
]
def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
    """Generate a sampled reply to ``text`` with a seq2seq dialogue model.

    ``context`` is prepended to ``text``. Tab characters in the combined string
    are treated as turn separators and rewritten for the selected model family.

    Args:
        text: The newest message to respond to.
        context: Earlier conversation to prepend (tab-separated turns).
        model_name: Model identifier. If it contains ``'GODEL'``, the GODEL
            instruction prompt is used; otherwise turns are newline-separated.
        model: Model object exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        minimum: Minimum number of new tokens to generate.
        maximum: Maximum number of new tokens to generate.

    Returns:
        The decoded reply (special tokens removed), passed through
        ``capitalization``.
    """
    text = f'{context} {text}'
    if 'GODEL' in model_name:
        # NOTE(review): "response" in the instruction is probably meant to be
        # "respond"; the prompt text is left unchanged to keep outputs stable.
        text = f'Instruction: you need to response discreetly. [CONTEXT] {text}'
        # str.replace returns a new string. Previously the result was
        # discarded, so raw tabs reached GODEL instead of ' EOS ' separators.
        text = text.replace('\t', ' EOS ')
    else:
        text = text.replace('\t', '\n')
    inputs = tokenizer(text, return_tensors="pt")
    # Pass the attention mask explicitly rather than letting generate() infer it.
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=maximum,
        min_new_tokens=minimum,
        top_p=0.9,
        do_sample=True,
    )
    output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return capitalization(output)
def capitalization(line):
    """Tidy the capitalisation of a generated reply.

    The function:

    - collapses runs of whitespace to single spaces and trims both ends;
    - upper-cases the first word character of the text and of every sentence
      that follows ``.``, ``?`` or ``!`` plus a space (opening quotes or
      brackets are skipped);
    - upper-cases the pronoun "i", including contractions such as "i'm".

    Other punctuation is left as is, so ellipses, "?!" and decimals such as
    "3.5" are preserved.

    Args:
        line: Raw decoded model output; may be empty.

    Returns:
        The cleaned string ("" for empty or whitespace-only input).
    """
    import re

    text = ' '.join(line.split())
    if not text:
        return text
    # Sentence start: beginning of the text, or a terminal mark followed by a
    # space. Leading non-word, non-space characters (quotes, brackets) are
    # kept, and the first word character after them is upper-cased.
    text = re.sub(
        r'(^|[.?!] )([^\w\s]*)(\w)',
        lambda m: m.group(1) + m.group(2) + m.group(3).upper(),
        text,
    )
    # Standalone pronoun "i" and "i'..." contractions. Skip abbreviations such
    # as "i.e." and hyphenated forms such as "i-th".
    return re.sub(r"(?<![-'’])\bi\b(?![.-]\w)", 'I', text)