m3hrdadfi commited on
Commit
1ec57a1
•
1 Parent(s): e0da258

Finalize model

Browse files
README.md CHANGED
@@ -7,17 +7,13 @@ tags:
7
  - recipe-generation
8
  pipeline_tag: text2text-generation
9
  widget:
10
- - text: "gold potatoes, olive oil, kosher salt, truffle salt"
11
- - text: "custard, sage, sugar, milk, heavy cream, eggs, egg yolks, salt, vanilla, maple bourbon caramel sauce, maple syrup, bourbon, sugar, water, light corn syrup, cream of tartar"
12
- - text: "bulgar wheat, olive oil, cucumber, tomato, lemon, garlic, green onions, fresh mint, salt"
13
- - text: "active dry yeast, milk, sugar, unsalted butter, chocolate, espresso powder, egg, vanilla, cocoa, salt, bread flour"
14
- - text: "penne pasta, ground beef, pasta sauce, ricotta cheese, mozzarella cheese, parmesan cheese, egg"
15
- - text: "fresh strawberries, sugar, vinegar"
16
- - text: "soy sauce, lime juice, apricot preserves, water, garlic, ground ginger, carrots, green onions, olive oil, shrimp, sweet red pepper, rice, lettuce leaves"
17
- - text: "cucumbers, lemon juice, parsley, water, onion, sour cream, black pepper, salt"
18
  - text: "chicken breasts, onion, garlic, great northern beans, black beans, green chilies, broccoli, garlic oil, butter, cajun seasoning, salt, oregano, thyme, black pepper, basil, worcestershire sauce, chicken broth, sour cream, chardonnay wine"
19
  - text: "serrano peppers, garlic, celery, oregano, canola oil, vinegar, water, kosher salt, salt, black pepper"
20
-
21
  ---
22
 
23
  ![avatar](chef-transformer.png)
@@ -26,7 +22,6 @@ widget:
26
  > This is part of the
27
  [Flax/Jax Community Week](https://discuss.huggingface.co/t/recipe-generation-model/7475), organized by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google.
28
 
29
- ... SOON
30
 
31
  ## Team Members
32
  - Mehrdad Farahani ([m3hrdadfi](https://huggingface.co/m3hrdadfi))
@@ -73,19 +68,151 @@ widget:
73
 
74
  ## How To Use
75
 
76
- ... SOON
77
 
78
- ## Evaluation
 
 
79
 
80
- ... SOON
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- ### Baseline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- ... SOON
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- ### Our Results
87
 
88
- ... SOON
 
 
 
 
89
 
90
 
91
  ## Copyright
7
  - recipe-generation
8
  pipeline_tag: text2text-generation
9
  widget:
10
+ - text: "provolone cheese, bacon, bread, ginger"
11
+ - text: "sugar, crunchy jif peanut butter, cornflakes"
12
+ - text: "sweet butter, confectioners sugar, flaked coconut, condensed milk, nuts, vanilla, dipping chocolate"
13
+ - text: "macaroni, butter, salt, bacon, milk, flour, pepper, cream corn"
14
+ - text: "hamburger, sausage, onion, regular, american cheese, colby cheese"
 
 
 
15
  - text: "chicken breasts, onion, garlic, great northern beans, black beans, green chilies, broccoli, garlic oil, butter, cajun seasoning, salt, oregano, thyme, black pepper, basil, worcestershire sauce, chicken broth, sour cream, chardonnay wine"
16
  - text: "serrano peppers, garlic, celery, oregano, canola oil, vinegar, water, kosher salt, salt, black pepper"
 
17
  ---
18
 
19
  ![avatar](chef-transformer.png)
22
  > This is part of the
23
  [Flax/Jax Community Week](https://discuss.huggingface.co/t/recipe-generation-model/7475), organized by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google.
24
 
 
25
 
26
  ## Team Members
27
  - Mehrdad Farahani ([m3hrdadfi](https://huggingface.co/m3hrdadfi))
68
 
69
  ## How To Use
70
 
71
+ ### Installing requirements
72
 
73
+ ```bash
74
+ pip install transformers
75
+ ```
76
 
77
+ ```python
78
+ from transformers import FlaxAutoModelForSeq2SeqLM
79
+ from transformers import AutoTokenizer
80
+
81
+ MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
82
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
83
+ model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
84
+
85
+ prefix = "items: "
86
+ # generation_kwargs = {
87
+ # "max_length": 1024,
88
+ # "min_length": 128,
89
+ # "no_repeat_ngram_size": 3,
90
+ # "do_sample": True,
91
+ # "top_k": 60,
92
+ # "top_p": 0.95
93
+ # }
94
+ generation_kwargs = {
95
+ "max_length": 512,
96
+ "min_length": 64,
97
+ "no_repeat_ngram_size": 3,
98
+ "early_stopping": True,
99
+ "num_beams": 5,
100
+ "length_penalty": 1.5,
101
+ }
102
 
103
+ special_tokens = tokenizer.all_special_tokens
104
+ tokens_map = {
105
+ "<sep>": "--",
106
+ "<section>": "\n"
107
+ }
108
+ def skip_special_tokens(text, special_tokens):
109
+ for token in special_tokens:
110
+ text = text.replace(token, '')
111
+
112
+ return text
113
+
114
+ def target_postprocessing(texts, special_tokens):
115
+ if not isinstance(texts, list):
116
+ texts = [texts]
117
+
118
+ new_texts = []
119
+ for text in texts:
120
+ text = skip_special_tokens(text, special_tokens)
121
+
122
+ for k, v in tokens_map.items():
123
+ text = text.replace(k, v)
124
+
125
+ new_texts.append(text)
126
+
127
+ return new_texts
128
+
129
+ def generation_function(texts):
130
+ _inputs = texts if isinstance(texts, list) else [texts]
131
+ inputs = [prefix + inp for inp in _inputs]
132
+ inputs = tokenizer(
133
+ inputs,
134
+ max_length=256,
135
+ padding="max_length",
136
+ truncation=True,
137
+ return_tensors='jax'
138
+ )
139
+
140
+ input_ids = inputs.input_ids
141
+ attention_mask = inputs.attention_mask
142
+
143
+ output_ids = model.generate(
144
+ input_ids=input_ids,
145
+ attention_mask=attention_mask,
146
+ **generation_kwargs
147
+ )
148
+ generated = output_ids.sequences
149
+ generated_recipe = target_postprocessing(
150
+ tokenizer.batch_decode(generated, skip_special_tokens=False),
151
+ special_tokens
152
+ )
153
+ return generated_recipe
154
+ ```
155
 
156
+ ```python
157
+ items = [
158
+ "macaroni, butter, salt, bacon, milk, flour, pepper, cream corn"
159
+ ]
160
+ generated = generation_function(items)
161
+ for text in generated:
162
+ sections = text.split("\n")
163
+ for section in sections:
164
+ section = section.strip()
165
+ if section.startswith("title:"):
166
+ section = section.replace("title:", "")
167
+ headline = "TITLE"
168
+ elif section.startswith("ingredients:"):
169
+ section = section.replace("ingredients:", "")
170
+ headline = "INGREDIENTS"
171
+ elif section.startswith("directions:"):
172
+ section = section.replace("directions:", "")
173
+ headline = "DIRECTIONS"
174
+
175
+ section_info = [f" - {i+1}: {info.strip().capitalize()}" for i, info in enumerate(section.split("--"))]
176
+ print(f"[{headline}]:")
177
+ print("\n".join(section_info))
178
+
179
+ print("-" * 130)
180
+ ```
181
+
182
+ Output:
183
+ ```text
184
+ [TITLE]:
185
+ - 1: Macaroni and corn
186
+ [INGREDIENTS]:
187
+ - 1: 2 c. macaroni
188
+ - 2: 2 tbsp. butter
189
+ - 3: 1 tsp. salt
190
+ - 4: 4 slices bacon
191
+ - 5: 2 c. milk
192
+ - 6: 2 tbsp. flour
193
+ - 7: 1/4 tsp. pepper
194
+ - 8: 1 can cream corn
195
+ [DIRECTIONS]:
196
+ - 1: Cook macaroni in boiling salted water until tender.
197
+ - 2: Drain.
198
+ - 3: Melt butter in saucepan.
199
+ - 4: Blend in flour, salt and pepper.
200
+ - 5: Add milk all at once.
201
+ - 6: Cook and stir until thickened and bubbly.
202
+ - 7: Stir in corn and bacon.
203
+ - 8: Pour over macaroni and mix well.
204
+ ----------------------------------------------------------------------------------------------------------------------------------
205
+ ```
206
+
207
+ ## Evaluation
208
 
209
+ The following tables summarize the scores obtained by the **Chef Transformer**. Those marked as (*) are the baseline models.
210
 
211
+ | Model | BLEU | WER | COSIM | ROUGE-2 |
212
+ |:---------------:|:-----:|:-----:|:-----:|:-------:|
213
+ | Recipe1M+ * | 0.844 | 0.786 | 0.589 | - |
214
+ | RecipeNLG * | 0.866 | 0.751 | 0.666 | - |
215
+ | ChefTransformer | 0.203 | 0.709 | 0.714 | 0.290 |
216
 
217
 
218
  ## Copyright
config.json CHANGED
@@ -51,9 +51,9 @@
51
  },
52
  "text2text-generation": {
53
  "early_stopping": true,
54
- "max_length": 1024,
55
- "repetition_penalty": 1.0,
56
- "length_penalty": 1.0,
57
  "num_beams": 5,
58
  "prefix": "items: "
59
  }
51
  },
52
  "text2text-generation": {
53
  "early_stopping": true,
54
+ "max_length": 512,
55
+ "repetition_penalty": 1.2,
56
+ "length_penalty": 1.2,
57
  "num_beams": 5,
58
  "prefix": "items: "
59
  }
{notes → data}/test.csv RENAMED
File without changes
data/test_generated.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:479d8789478d92caafa5f31c7c6dbe0fbe55deb34fb0bda72fe410b77a914427
3
+ size 145654259
src/evaluation.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+ from sklearn.feature_extraction.text import CountVectorizer
5
+
6
+ from datasets import load_metric
7
+
8
+ import nltk
9
+ from nltk.tokenize import wordpunct_tokenize
10
+ from nltk.corpus import stopwords
11
+ import nltk.translate.bleu_score as bleu
12
+ from nltk.translate.bleu_score import SmoothingFunction
13
+ import nltk.translate.gleu_score as gleu
14
+ import nltk.translate.meteor_score as meteor
15
+
16
+ from jiwer import wer, mer
17
+
18
+ import re
19
+ import math
20
+ from collections import Counter
21
+ import string
22
+ from tqdm import tqdm
23
+
24
+
25
+ nltk.download('stopwords')
26
+ stopwords = stopwords.words("english")
27
+
28
+
29
+ df = pd.read_csv("./test_generated.csv", sep="\t")
30
+ true_recipes = df["true_recipe"].values.tolist()
31
+ generated_recipes = df["generated_recipe"].values.tolist()
32
+
33
+ def cleaning(text, rm_sep=True, rm_nl=True, rm_punk_stopwords=True):
34
+ if rm_sep:
35
+ text = text.replace("--", " ")
36
+
37
+ if rm_nl:
38
+ text = text.replace("\n", " ")
39
+
40
+ if rm_punk_stopwords:
41
+ text = " ".join([word.strip() for word in wordpunct_tokenize(text) if word not in string.punctuation and word not in stopwords and word])
42
+ else:
43
+ text = " ".join([word.strip() for word in wordpunct_tokenize(text) if word.strip()])
44
+
45
+ text = text.lower()
46
+ return text
47
+
48
+ X, Y = [], []
49
+ for x, y in tqdm(zip(true_recipes, generated_recipes), total=len(df)):
50
+ x, y = cleaning(x, True, True, True), cleaning(y, True, True, True)
51
+
52
+ if len(x) > 16 and len(y) > 16:
53
+ X.append(x)
54
+ Y.append(y)
55
+
56
+
57
+ print(f"Sample X: {X[0]}")
58
+ print(f"Sample Y: {Y[0]}")
59
+
60
+ def get_cosine(vec1, vec2):
61
+ intersection = set(vec1.keys()) & set(vec2.keys())
62
+ numerator = sum([vec1[x] * vec2[x] for x in intersection])
63
+
64
+ sum1 = sum([vec1[x]**2 for x in vec1.keys()])
65
+ sum2 = sum([vec2[x]**2 for x in vec2.keys()])
66
+ denominator = math.sqrt(sum1) * math.sqrt(sum2)
67
+
68
+ if not denominator:
69
+ return 0.0
70
+ else:
71
+ return float(numerator) / denominator
72
+
73
+ def text_to_vector(text):
74
+ word = re.compile(r'\w+')
75
+ words = word.findall(text)
76
+ return Counter(words)
77
+
78
+ def get_result(content_a, content_b):
79
+ text1 = content_a
80
+ text2 = content_b
81
+
82
+ vector1 = text_to_vector(text1)
83
+ vector2 = text_to_vector(text2)
84
+
85
+ cosine_result = get_cosine(vector1, vector2)
86
+ return cosine_result
87
+
88
+
89
+ cosim_scores = []
90
+ for i in tqdm(range(len(X))):
91
+ cosim_scores.append(get_result(X[i], Y[i]))
92
+
93
+ cosim_score = np.array(cosim_scores).mean()
94
+ print(f"Cosine similarity score: {cosim_score}") # 0.714542
95
+
96
+ X, Y = [], []
97
+ for x, y in tqdm(zip(true_recipes, generated_recipes), total=len(df)):
98
+ x, y = cleaning(x, True, True, False), cleaning(y, True, True, False)
99
+
100
+ if len(x) > 16 and len(y) > 16:
101
+ X.append(x)
102
+ Y.append(y)
103
+
104
+
105
+ wer = load_metric("wer")
106
+ wer_score = wer.compute(predictions=Y, references=X)
107
+ print(f"WER score: {wer_score}") # 0.70938
108
+
109
+
110
+ rouge = load_metric("rouge")
111
+ rouge_score = rouge.compute(predictions=Y, references=X, use_stemmer=True)
112
+ rouge_score = {key: value.mid.fmeasure * 100 for key, value in rouge_score.items()}
113
+ print(f"Rouge score: {rouge_score}") # {'rouge1': 56.30779082900833, 'rouge2': 29.07704230163075, 'rougeL': 45.812165960365924, 'rougeLsum': 45.813971137090654}
114
+
115
+ bleu = load_metric("bleu")
116
+ def postprocess_text(preds, labels):
117
+ preds = [wordpunct_tokenize(pred) for pred in preds]
118
+ labels = [[wordpunct_tokenize(label)] for label in labels]
119
+
120
+ return preds, labels
121
+
122
+ Y, X = postprocess_text(Y, X)
123
+ bleu_score = bleu.compute(predictions=Y, references=X)["bleu"]
124
+ print(f"BLEU score: {bleu_score}") # 0.203867
{notes → src}/flax_to_pytorch.py RENAMED
File without changes
{notes → src}/flax_to_tf.py RENAMED
File without changes
{notes → src}/generation.py RENAMED
File without changes
src/prediction.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import FlaxAutoModelForSeq2SeqLM
2
+ from transformers import AutoTokenizer
3
+ import textwrap
4
+
5
+ MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
6
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
7
+ model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
8
+
9
+ prefix = "items: "
10
+ # generation_kwargs = {
11
+ # "max_length": 1024,
12
+ # "min_length": 128,
13
+ # "no_repeat_ngram_size": 3,
14
+ # "do_sample": True,
15
+ # "top_k": 60,
16
+ # "top_p": 0.95
17
+ # }
18
+ generation_kwargs = {
19
+ "max_length": 512,
20
+ "min_length": 64,
21
+ "no_repeat_ngram_size": 3,
22
+ "early_stopping": True,
23
+ "num_beams": 5,
24
+ "length_penalty": 1.5,
25
+ }
26
+
27
+ special_tokens = tokenizer.all_special_tokens
28
+ tokens_map = {
29
+ "<sep>": "--",
30
+ "<section>": "\n"
31
+ }
32
+ def skip_special_tokens(text, special_tokens):
33
+ for token in special_tokens:
34
+ text = text.replace(token, '')
35
+
36
+ return text
37
+
38
+ def target_postprocessing(texts, special_tokens):
39
+ if not isinstance(texts, list):
40
+ texts = [texts]
41
+
42
+ new_texts = []
43
+ for text in texts:
44
+ text = skip_special_tokens(text, special_tokens)
45
+
46
+ for k, v in tokens_map.items():
47
+ text = text.replace(k, v)
48
+
49
+ new_texts.append(text)
50
+
51
+ return new_texts
52
+
53
+ def generation_function(texts):
54
+ _inputs = texts if isinstance(texts, list) else [texts]
55
+ inputs = [prefix + inp for inp in _inputs]
56
+ inputs = tokenizer(
57
+ inputs,
58
+ max_length=256,
59
+ padding="max_length",
60
+ truncation=True,
61
+ return_tensors='jax'
62
+ )
63
+
64
+ input_ids = inputs.input_ids
65
+ attention_mask = inputs.attention_mask
66
+
67
+ output_ids = model.generate(
68
+ input_ids=input_ids,
69
+ attention_mask=attention_mask,
70
+ **generation_kwargs
71
+ )
72
+ generated = output_ids.sequences
73
+ generated_recipe = target_postprocessing(
74
+ tokenizer.batch_decode(generated, skip_special_tokens=False),
75
+ special_tokens
76
+ )
77
+ return generated_recipe
78
+
79
+
80
+ items = [
81
+ "macaroni, butter, salt, bacon, milk, flour, pepper, cream corn"
82
+ ]
83
+ generated = generation_function(items)
84
+ for text in generated:
85
+ sections = text.split("\n")
86
+ for section in sections:
87
+ section = section.strip()
88
+ if section.startswith("title:"):
89
+ section = section.replace("title:", "")
90
+ headline = "TITLE"
91
+ elif section.startswith("ingredients:"):
92
+ section = section.replace("ingredients:", "")
93
+ headline = "INGREDIENTS"
94
+ elif section.startswith("directions:"):
95
+ section = section.replace("directions:", "")
96
+ headline = "DIRECTIONS"
97
+
98
+ section_info = [f" - {i+1}: {info.strip().capitalize()}" for i, info in enumerate(section.split("--"))]
99
+ print(f"[{headline}]:")
100
+ print("\n".join(section_info))
101
+
102
+ print("-" * 130)