Finalize model

- README.md +144 -17
- config.json +3 -3
- {notes → data}/test.csv +0 -0
- data/test_generated.csv +3 -0
- src/evaluation.py +124 -0
- {notes → src}/flax_to_pytorch.py +0 -0
- {notes → src}/flax_to_tf.py +0 -0
- {notes → src}/generation.py +0 -0
- src/prediction.py +102 -0

README.md
CHANGED
@@ -7,17 +7,13 @@ tags:
 - recipe-generation
 pipeline_tag: text2text-generation
 widget:
-- text: "
-- text: "
-- text: "
-- text: "
-- text: "
-- text: "fresh strawberries, sugar, vinegar"
-- text: "soy sauce, lime juice, apricot preserves, water, garlic, ground ginger, carrots, green onions, olive oil, shrimp, sweet red pepper, rice, lettuce leaves"
-- text: "cucumbers, lemon juice, parsley, water, onion, sour cream, black pepper, salt"
+- text: "provolone cheese, bacon, bread, ginger"
+- text: "sugar, crunchy jif peanut butter, cornflakes"
+- text: "sweet butter, confectioners sugar, flaked coconut, condensed milk, nuts, vanilla, dipping chocolate"
+- text: "macaroni, butter, salt, bacon, milk, flour, pepper, cream corn"
+- text: "hamburger, sausage, onion, regular, american cheese, colby cheese"
 - text: "chicken breasts, onion, garlic, great northern beans, black beans, green chilies, broccoli, garlic oil, butter, cajun seasoning, salt, oregano, thyme, black pepper, basil, worcestershire sauce, chicken broth, sour cream, chardonnay wine"
 - text: "serrano peppers, garlic, celery, oregano, canola oil, vinegar, water, kosher salt, salt, black pepper"
-
 ---
 
 ![avatar](chef-transformer.png)
@@ -26,7 +22,6 @@ widget:
 > This is part of the
 [Flax/Jax Community Week](https://discuss.huggingface.co/t/recipe-generation-model/7475), organized by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google.
 
-... SOON
 
 ## Team Members
 - Mehrdad Farahani ([m3hrdadfi](https://huggingface.co/m3hrdadfi))
@@ -73,19 +68,151 @@ widget:
 
 ## How To Use
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+### Installing requirements
+
+```bash
+pip install transformers
+```
+
+```python
+from transformers import FlaxAutoModelForSeq2SeqLM
+from transformers import AutoTokenizer
+
+MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
+model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
+
+prefix = "items: "
+# generation_kwargs = {
+#     "max_length": 1024,
+#     "min_length": 128,
+#     "no_repeat_ngram_size": 3,
+#     "do_sample": True,
+#     "top_k": 60,
+#     "top_p": 0.95
+# }
+generation_kwargs = {
+    "max_length": 512,
+    "min_length": 64,
+    "no_repeat_ngram_size": 3,
+    "early_stopping": True,
+    "num_beams": 5,
+    "length_penalty": 1.5,
+}
+
+special_tokens = tokenizer.all_special_tokens
+tokens_map = {
+    "<sep>": "--",
+    "<section>": "\n"
+}
+def skip_special_tokens(text, special_tokens):
+    for token in special_tokens:
+        text = text.replace(token, '')
+
+    return text
+
+def target_postprocessing(texts, special_tokens):
+    if not isinstance(texts, list):
+        texts = [texts]
+
+    new_texts = []
+    for text in texts:
+        text = skip_special_tokens(text, special_tokens)
+
+        for k, v in tokens_map.items():
+            text = text.replace(k, v)
+
+        new_texts.append(text)
+
+    return new_texts
+
+def generation_function(texts):
+    _inputs = texts if isinstance(texts, list) else [texts]
+    inputs = [prefix + inp for inp in _inputs]
+    inputs = tokenizer(
+        inputs,
+        max_length=256,
+        padding="max_length",
+        truncation=True,
+        return_tensors='jax'
+    )
+
+    input_ids = inputs.input_ids
+    attention_mask = inputs.attention_mask
+
+    output_ids = model.generate(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        **generation_kwargs
+    )
+    generated = output_ids.sequences
+    generated_recipe = target_postprocessing(
+        tokenizer.batch_decode(generated, skip_special_tokens=False),
+        special_tokens
+    )
+    return generated_recipe
+```
+
+```python
+items = [
+    "macaroni, butter, salt, bacon, milk, flour, pepper, cream corn"
+]
+generated = generation_function(items)
+for text in generated:
+    sections = text.split("\n")
+    for section in sections:
+        section = section.strip()
+        if section.startswith("title:"):
+            section = section.replace("title:", "")
+            headline = "TITLE"
+        elif section.startswith("ingredients:"):
+            section = section.replace("ingredients:", "")
+            headline = "INGREDIENTS"
+        elif section.startswith("directions:"):
+            section = section.replace("directions:", "")
+            headline = "DIRECTIONS"
+
+        section_info = [f"  - {i+1}: {info.strip().capitalize()}" for i, info in enumerate(section.split("--"))]
+        print(f"[{headline}]:")
+        print("\n".join(section_info))
+
+    print("-" * 130)
+```
+
+Output:
+```text
+[TITLE]:
+  - 1: Macaroni and corn
+[INGREDIENTS]:
+  - 1: 2 c. macaroni
+  - 2: 2 tbsp. butter
+  - 3: 1 tsp. salt
+  - 4: 4 slices bacon
+  - 5: 2 c. milk
+  - 6: 2 tbsp. flour
+  - 7: 1/4 tsp. pepper
+  - 8: 1 can cream corn
+[DIRECTIONS]:
+  - 1: Cook macaroni in boiling salted water until tender.
+  - 2: Drain.
+  - 3: Melt butter in saucepan.
+  - 4: Blend in flour, salt and pepper.
+  - 5: Add milk all at once.
+  - 6: Cook and stir until thickened and bubbly.
+  - 7: Stir in corn and bacon.
+  - 8: Pour over macaroni and mix well.
+----------------------------------------------------------------------------------------------------------------------------------
+```
+
+## Evaluation
+
+The following table summarizes the scores obtained by the **Chef Transformer**. Rows marked with (*) are the baseline models.
+
+| Model | BLEU | WER | COSIM | ROUGE-2 |
+|:---------------:|:-----:|:-----:|:-----:|:-------:|
+| Recipe1M+ * | 0.844 | 0.786 | 0.589 | - |
+| RecipeNLG * | 0.866 | 0.751 | 0.666 | - |
+| ChefTransformer | 0.203 | 0.709 | 0.714 | 0.290 |
 
 
 ## Copyright

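Note: the usage example above is Flax-only. Since this commit also moves the conversion scripts into `src/` (see `flax_to_pytorch.py` / `flax_to_tf.py` below), a PyTorch variant should be a near one-line swap. A minimal sketch, assuming converted PyTorch weights are available in the same repo:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Sketch only: assumes PyTorch weights have been pushed to the repo.
tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-recipe-generation")
model = AutoModelForSeq2SeqLM.from_pretrained("flax-community/t5-recipe-generation")

inputs = tokenizer(
    "items: macaroni, butter, salt, bacon, milk, flour, pepper, cream corn",
    return_tensors="pt",
)
output_ids = model.generate(
    **inputs,
    max_length=512,
    min_length=64,
    no_repeat_ngram_size=3,
    early_stopping=True,
    num_beams=5,
    length_penalty=1.5,
)
# Keep special tokens so the <sep>/<section> structure survives decoding,
# then apply the same mapping as target_postprocessing in the card.
text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
for tok in tokenizer.all_special_tokens:  # drop <pad>, </s>, ...
    text = text.replace(tok, "")
print(text.replace("<section>", "\n").replace("<sep>", "--"))
```
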
config.json
CHANGED
@@ -51,9 +51,9 @@
   },
   "text2text-generation": {
     "early_stopping": true,
-    "max_length":
-    "repetition_penalty": 1.
-    "length_penalty": 1.
+    "max_length": 512,
+    "repetition_penalty": 1.2,
+    "length_penalty": 1.2,
     "num_beams": 5,
     "prefix": "items: "
   }

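These `task_specific_params` are what the hosted widget applies on top of the base config. As a hedged sketch, assuming the standard `transformers` pipeline behavior of reading a task's params and the `"items: "` prefix from `config.json`:

```python
from transformers import pipeline

# Sketch: the text2text-generation pipeline should pick up max_length,
# length_penalty, num_beams and the "items: " prefix shown above,
# so a bare ingredient list is a valid input.
chef = pipeline("text2text-generation", model="flax-community/t5-recipe-generation")
print(chef("provolone cheese, bacon, bread, ginger"))
```
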
{notes → data}/test.csv
RENAMED
File without changes

data/test_generated.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:479d8789478d92caafa5f31c7c6dbe0fbe55deb34fb0bda72fe410b77a914427
+size 145654259

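The three lines above are a Git LFS pointer, not the data itself (the generated test set is about 145 MB). A sketch for materializing the real file locally, assuming the `huggingface_hub` download helper:

```python
from huggingface_hub import hf_hub_download

# Sketch: resolves the LFS pointer and downloads the actual CSV.
csv_path = hf_hub_download(
    repo_id="flax-community/t5-recipe-generation",
    filename="data/test_generated.csv",
)
print(csv_path)
```
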
src/evaluation.py
ADDED
@@ -0,0 +1,124 @@
+import numpy as np
+import pandas as pd
+from sklearn.metrics.pairwise import cosine_similarity
+from sklearn.feature_extraction.text import CountVectorizer
+
+from datasets import load_metric
+
+import nltk
+from nltk.tokenize import wordpunct_tokenize
+from nltk.corpus import stopwords
+import nltk.translate.bleu_score as bleu
+from nltk.translate.bleu_score import SmoothingFunction
+import nltk.translate.gleu_score as gleu
+import nltk.translate.meteor_score as meteor
+
+from jiwer import wer, mer
+
+import re
+import math
+from collections import Counter
+import string
+from tqdm import tqdm
+
+
+nltk.download('stopwords')
+stopwords = stopwords.words("english")
+
+
+df = pd.read_csv("./test_generated.csv", sep="\t")
+true_recipes = df["true_recipe"].values.tolist()
+generated_recipes = df["generated_recipe"].values.tolist()
+
+def cleaning(text, rm_sep=True, rm_nl=True, rm_punk_stopwords=True):
+    if rm_sep:
+        text = text.replace("--", " ")
+
+    if rm_nl:
+        text = text.replace("\n", " ")
+
+    if rm_punk_stopwords:
+        text = " ".join([word.strip() for word in wordpunct_tokenize(text) if word not in string.punctuation and word not in stopwords and word])
+    else:
+        text = " ".join([word.strip() for word in wordpunct_tokenize(text) if word.strip()])
+
+    text = text.lower()
+    return text
+
+X, Y = [], []
+for x, y in tqdm(zip(true_recipes, generated_recipes), total=len(df)):
+    x, y = cleaning(x, True, True, True), cleaning(y, True, True, True)
+
+    if len(x) > 16 and len(y) > 16:
+        X.append(x)
+        Y.append(y)
+
+
+print(f"Sample X: {X[0]}")
+print(f"Sample Y: {Y[0]}")
+
+def get_cosine(vec1, vec2):
+    intersection = set(vec1.keys()) & set(vec2.keys())
+    numerator = sum([vec1[x] * vec2[x] for x in intersection])
+
+    sum1 = sum([vec1[x]**2 for x in vec1.keys()])
+    sum2 = sum([vec2[x]**2 for x in vec2.keys()])
+    denominator = math.sqrt(sum1) * math.sqrt(sum2)
+
+    if not denominator:
+        return 0.0
+    else:
+        return float(numerator) / denominator
+
+def text_to_vector(text):
+    word = re.compile(r'\w+')
+    words = word.findall(text)
+    return Counter(words)
+
+def get_result(content_a, content_b):
+    text1 = content_a
+    text2 = content_b
+
+    vector1 = text_to_vector(text1)
+    vector2 = text_to_vector(text2)
+
+    cosine_result = get_cosine(vector1, vector2)
+    return cosine_result
+
+
+cosim_scores = []
+for i in tqdm(range(len(X))):
+    cosim_scores.append(get_result(X[i], Y[i]))
+
+cosim_score = np.array(cosim_scores).mean()
+print(f"Cosine similarity score: {cosim_score}")  # 0.714542
+
+X, Y = [], []
+for x, y in tqdm(zip(true_recipes, generated_recipes), total=len(df)):
+    x, y = cleaning(x, True, True, False), cleaning(y, True, True, False)
+
+    if len(x) > 16 and len(y) > 16:
+        X.append(x)
+        Y.append(y)
+
+
+wer = load_metric("wer")
+wer_score = wer.compute(predictions=Y, references=X)
+print(f"WER score: {wer_score}")  # 0.70938
+
+
+rouge = load_metric("rouge")
+rouge_score = rouge.compute(predictions=Y, references=X, use_stemmer=True)
+rouge_score = {key: value.mid.fmeasure * 100 for key, value in rouge_score.items()}
+print(f"Rouge score: {rouge_score}")  # {'rouge1': 56.30779082900833, 'rouge2': 29.07704230163075, 'rougeL': 45.812165960365924, 'rougeLsum': 45.813971137090654}
+
+bleu = load_metric("bleu")
+def postprocess_text(preds, labels):
+    preds = [wordpunct_tokenize(pred) for pred in preds]
+    labels = [[wordpunct_tokenize(label)] for label in labels]
+
+    return preds, labels
+
+Y, X = postprocess_text(Y, X)
+bleu_score = bleu.compute(predictions=Y, references=X)["bleu"]
+print(f"BLEU score: {bleu_score}")  # 0.203867

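Note that the script reads `./test_generated.csv` while the file lives under `data/`, so run it from that directory or adjust the path. A hypothetical two-row illustration of the tab-separated schema it expects (columns `true_recipe` and `generated_recipe`, with section markers already mapped to `--` and newlines):

```python
import pandas as pd

# Hypothetical rows, only to illustrate the expected schema.
pd.DataFrame({
    "true_recipe": ["title: baked apples\ningredients: 2 apples -- 1 tbsp. sugar\ndirections: core apples -- bake"],
    "generated_recipe": ["title: baked apples\ningredients: 2 apples -- sugar\ndirections: core the apples -- bake until soft"],
}).to_csv("test_generated.csv", sep="\t", index=False)
```
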
{notes → src}/flax_to_pytorch.py
RENAMED
File without changes

{notes → src}/flax_to_tf.py
RENAMED
File without changes

{notes → src}/generation.py
RENAMED
File without changes

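The conversion scripts' contents are not shown in this diff. As a rough sketch of what such scripts typically amount to with cross-framework loading in `transformers` (the `from_flax=True` / `from_pt=True` route is an assumption about the approach, not the scripts' actual code):

```python
from transformers import T5ForConditionalGeneration, TFT5ForConditionalGeneration

# Sketch: load the Flax checkpoint into PyTorch and save it back out.
pt_model = T5ForConditionalGeneration.from_pretrained(
    "flax-community/t5-recipe-generation", from_flax=True
)
pt_model.save_pretrained("./t5-recipe-generation")

# The TF path is analogous, going through the PyTorch weights.
tf_model = TFT5ForConditionalGeneration.from_pretrained(
    "./t5-recipe-generation", from_pt=True
)
tf_model.save_pretrained("./t5-recipe-generation")
```
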
src/prediction.py
ADDED
@@ -0,0 +1,102 @@
+from transformers import FlaxAutoModelForSeq2SeqLM
+from transformers import AutoTokenizer
+import textwrap
+
+MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
+model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
+
+prefix = "items: "
+# generation_kwargs = {
+#     "max_length": 1024,
+#     "min_length": 128,
+#     "no_repeat_ngram_size": 3,
+#     "do_sample": True,
+#     "top_k": 60,
+#     "top_p": 0.95
+# }
+generation_kwargs = {
+    "max_length": 512,
+    "min_length": 64,
+    "no_repeat_ngram_size": 3,
+    "early_stopping": True,
+    "num_beams": 5,
+    "length_penalty": 1.5,
+}
+
+special_tokens = tokenizer.all_special_tokens
+tokens_map = {
+    "<sep>": "--",
+    "<section>": "\n"
+}
+def skip_special_tokens(text, special_tokens):
+    for token in special_tokens:
+        text = text.replace(token, '')
+
+    return text
+
+def target_postprocessing(texts, special_tokens):
+    if not isinstance(texts, list):
+        texts = [texts]
+
+    new_texts = []
+    for text in texts:
+        text = skip_special_tokens(text, special_tokens)
+
+        for k, v in tokens_map.items():
+            text = text.replace(k, v)
+
+        new_texts.append(text)
+
+    return new_texts
+
+def generation_function(texts):
+    _inputs = texts if isinstance(texts, list) else [texts]
+    inputs = [prefix + inp for inp in _inputs]
+    inputs = tokenizer(
+        inputs,
+        max_length=256,
+        padding="max_length",
+        truncation=True,
+        return_tensors='jax'
+    )
+
+    input_ids = inputs.input_ids
+    attention_mask = inputs.attention_mask
+
+    output_ids = model.generate(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        **generation_kwargs
+    )
+    generated = output_ids.sequences
+    generated_recipe = target_postprocessing(
+        tokenizer.batch_decode(generated, skip_special_tokens=False),
+        special_tokens
+    )
+    return generated_recipe
+
+
+items = [
+    "macaroni, butter, salt, bacon, milk, flour, pepper, cream corn"
+]
+generated = generation_function(items)
+for text in generated:
+    sections = text.split("\n")
+    for section in sections:
+        section = section.strip()
+        if section.startswith("title:"):
+            section = section.replace("title:", "")
+            headline = "TITLE"
+        elif section.startswith("ingredients:"):
+            section = section.replace("ingredients:", "")
+            headline = "INGREDIENTS"
+        elif section.startswith("directions:"):
+            section = section.replace("directions:", "")
+            headline = "DIRECTIONS"
+
+        section_info = [f"  - {i+1}: {info.strip().capitalize()}" for i, info in enumerate(section.split("--"))]
+        print(f"[{headline}]:")
+        print("\n".join(section_info))
+
+    print("-" * 130)