|
import torch |
|
from transformers import AutoModelForSeq2SeqLM |
|
from transformers import AutoTokenizer |
|
from transformers import pipeline |
|
|
|
from pprint import pprint |
|
import re |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def skip_special_tokens_and_prettify(text, tokenizer):
    """Strip tokenizer special tokens from *text* and parse it into a recipe dict.

    After special tokens are removed, ``<section>`` markers are turned into
    newlines (one line per section) and ``<sep>`` markers into ``--`` (item
    separators within a section). Lines starting with ``title:``,
    ``ingredients:`` or ``directions:`` fill the corresponding field.

    Args:
        text: One decoded model output, decoded with special tokens kept.
        tokenizer: Any object exposing ``all_special_tokens`` (an iterable of
            strings), e.g. a Hugging Face tokenizer.

    Returns:
        ``{"title": str, "ingredients": list[str], "directions": list[str]}``.
        Sections absent from *text* keep their empty defaults; lines without
        a known prefix are ignored.
    """
    recipe_maps = {"<sep>": "--", "<section>": "\n"}
    recipe_map_pattern = "|".join(map(re.escape, recipe_maps))

    # Special tokens such as "[PAD]" or "<|endoftext|>" contain regex
    # metacharacters, so each must be escaped before joining into an
    # alternation (unescaped "[PAD]" would delete every P, A and D).
    # Longest-first order keeps a short token from matching only the
    # prefix of a longer one; empty tokens would match everywhere, so skip them.
    special_tokens = sorted(
        (tok for tok in tokenizer.all_special_tokens if tok), key=len, reverse=True
    )
    if special_tokens:
        text = re.sub("|".join(map(re.escape, special_tokens)), "", text)
    text = re.sub(recipe_map_pattern, lambda m: recipe_maps[m.group()], text)

    def split_items(body):
        # "--" is what "<sep>" was mapped to above; drop the empty fragments
        # produced by leading, trailing or doubled separators.
        return [item.strip() for item in body.split("--") if item.strip()]

    data = {"title": "", "ingredients": [], "directions": []}
    for section in text.split("\n"):
        section = section.strip()
        # Slice off only the leading prefix: str.replace would also delete a
        # later occurrence of the label inside the section's own content.
        if section.startswith("title:"):
            data["title"] = section[len("title:"):].strip()
        elif section.startswith("ingredients:"):
            data["ingredients"] = split_items(section[len("ingredients:"):])
        elif section.startswith("directions:"):
            data["directions"] = split_items(section[len("directions:"):])

    return data
|
|
|
|
|
def post_generator(output_tensors, tokenizer):
    """Decode generated token ids and parse each sequence into a recipe dict.

    Args:
        output_tensors: Sequence of mappings, each holding a
            ``"generated_token_ids"`` entry. Presumably these are the items a
            transformers ``pipeline`` returns with ``return_tensors=True`` —
            confirm against the caller.
        tokenizer: Tokenizer providing ``batch_decode`` and
            ``all_special_tokens``.

    Returns:
        A list with one recipe dict (see ``skip_special_tokens_and_prettify``)
        per input item, in input order.
    """
    token_ids = [output["generated_token_ids"] for output in output_tensors]
    # Decode with special tokens kept; skip_special_tokens_and_prettify
    # removes them itself before mapping "<sep>"/"<section>".
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=False)
    return [skip_special_tokens_and_prettify(text, tokenizer) for text in texts]
|
|
|
|
|
|
|
# Beam-search generation settings. Presumably passed as generation kwargs to
# the model/pipeline further down the file (usage not visible in this chunk).
generate_kwargs = dict(
    max_length=512,
    min_length=64,
    no_repeat_ngram_size=3,
    early_stopping=True,
    num_beams=5,
    length_penalty=1.5,
    num_return_sequences=2,  # must not exceed num_beams
)

# Comma-separated ingredient list; presumably the generation prompt.
items = ", ".join(("potato", "cheese"))
|
|
|
|
|
|
|
|