jucendrero committed
Commit 20066dd
1 Parent(s): 9e46f6d
Second functional version
app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import re
 from resources import banner, error_html_response
+import logging
+logging.basicConfig(format='%(asctime)s: [%(levelname)s]: %(message)s', level=logging.INFO)
 
 model_checkpoint = 'gastronomia-para-to2/gastronomia_para_to2'
 tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
@@ -71,64 +73,59 @@ def rerun_model_output(pre_output):
     if pre_output is None:
         return True
     elif not '<RECIPE_END>' in pre_output:
-
+        logging.info('<RECIPE_END> not in pre_output')
         return True
     pre_output_trimmed = pre_output[:pre_output.find('<RECIPE_END>')]
     if not all(special_token in pre_output_trimmed for special_token in special_tokens):
-
+        logging.info('Not all special tokens are in preoutput')
         return True
     elif not check_special_tokens_order(pre_output_trimmed):
-
+        logging.info('Special tokens are unordered in preoutput')
         return True
     elif len(pre_output_trimmed.split())<75:
-
+        logging.info('Length of the recipe is <75')
         return True
     else:
         return False
 
-
-def generate_output(tokenized_input):
-    pre_output = None
-    while rerun_model_output(pre_output):
-        output = model.generate(**tokenized_input,
-                                max_length=600,
-                                do_sample=True,
-                                top_p=0.92,
-                                top_k=50,
-                                # no_repeat_ngram_size=2,
-                                num_return_sequences=3)
-        pre_output = tokenizer.decode(output[0], skip_special_tokens=False)
-    pre_output_trimmed = pre_output[:pre_output.find('<RECIPE_END>')]
-    return pre_output_trimmed
-
-
 def check_wrong_ingredients(ingredients):
-
-
-
-
-
+    new_ingredients = []
+    for ingredient in ingredients:
+        if ingredient.startswith('De '):
+            new_ingredients.append(ingredient.strip('De ').capitalize())
+        else:
+            new_ingredients.append(ingredient)
+    return new_ingredients
 
 
 def make_recipe(input_ingredients):
+    logging.info(f'Received inputs: {input_ingredients}')
     input_ingredients = re.sub(' y ', ', ', input_ingredients)
     input = '<RECIPE_START> '
     input += '<INPUT_START> ' + ' <NEXT_INPUT> '.join(input_ingredients.split(', ')) + ' <INPUT_END> '
     input += '<INGR_START> '
     tokenized_input = tokenizer(input, return_tensors='pt')
 
-
+    pre_output = None
     i = 0
-    while
-        if i ==
+    while rerun_model_output(pre_output):
+        if i == 3:
             return frame_html_response(error_html_response)
-
-
-
-
-
+        output = model.generate(**tokenized_input,
+                                max_length=600,
+                                do_sample=True,
+                                top_p=0.92,
+                                top_k=50,
+                                # no_repeat_ngram_size=3,
+                                num_return_sequences=3)
+        pre_output = tokenizer.decode(output[0], skip_special_tokens=False)
         i += 1
-
+    pre_output_trimmed = pre_output[:pre_output.find('<RECIPE_END>')]
+    output_ingredients = re.search('<INGR_START> (.*) <INGR_END>', pre_output_trimmed).group(1)
+    output_ingredients = output_ingredients.split(' <NEXT_INGR> ')
+    output_ingredients = list(set([output_ingredient.strip() for output_ingredient in output_ingredients]))
+    output_ingredients = [output_ing.capitalize() for output_ing in output_ingredients]
+    output_ingredients = check_wrong_ingredients(output_ingredients)
     output_title = re.search('<TITLE_START> (.*) <TITLE_END>', pre_output_trimmed).group(1).strip().capitalize()
     output_instructions = re.search('<INSTR_START> (.*) <INSTR_END>', pre_output_trimmed).group(1)
     output_instructions = output_instructions.split(' <NEXT_INSTR> ')
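
Note on the inlined generation loop: model.generate is called with num_return_sequences=3, yet only output[0] is ever decoded, so two sampled candidates are discarded on every attempt. A minimal sketch of an alternative that screens all three candidates before retrying (generate_with_retries is a hypothetical helper, not part of this commit; it reuses the model, tokenizer and rerun_model_output defined in app.py):

# Sketch only: model, tokenizer and rerun_model_output are the objects from app.py.
def generate_with_retries(tokenized_input, max_tries=3):
    for _ in range(max_tries):
        output = model.generate(**tokenized_input,
                                max_length=600,
                                do_sample=True,
                                top_p=0.92,
                                top_k=50,
                                num_return_sequences=3)
        for candidate in output:
            pre_output = tokenizer.decode(candidate, skip_special_tokens=False)
            # Keep the first candidate that passes all the validity checks.
            if not rerun_model_output(pre_output):
                return pre_output[:pre_output.find('<RECIPE_END>')]
    return None  # caller can fall back to frame_html_response(error_html_response)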
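
For reference, the prompt/parse round trip that make_recipe implements can be exercised standalone. The sketch below builds the prompt exactly as the diff does and runs the same regexes over a fabricated completion (the ingredients, title and steps are invented for illustration):

import re

# Build the prompt the same way make_recipe does (example ingredients are made up).
input_ingredients = 'tomate, queso'
prompt = '<RECIPE_START> '
prompt += '<INPUT_START> ' + ' <NEXT_INPUT> '.join(input_ingredients.split(', ')) + ' <INPUT_END> '
prompt += '<INGR_START> '

# A fabricated completion laid out the way the parsing regexes expect.
pre_output_trimmed = (prompt
    + 'Tomate maduro <NEXT_INGR> Queso rallado <INGR_END> '
    + '<INSTR_START> Corta el tomate <NEXT_INSTR> Añade el queso <INSTR_END> '
    + '<TITLE_START> tomate con queso <TITLE_END>')

output_ingredients = re.search('<INGR_START> (.*) <INGR_END>', pre_output_trimmed).group(1).split(' <NEXT_INGR> ')
output_title = re.search('<TITLE_START> (.*) <TITLE_END>', pre_output_trimmed).group(1).strip().capitalize()
output_instructions = re.search('<INSTR_START> (.*) <INSTR_END>', pre_output_trimmed).group(1).split(' <NEXT_INSTR> ')
# output_ingredients == ['Tomate maduro', 'Queso rallado']; output_title == 'Tomate con queso'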
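
One caveat in the new check_wrong_ingredients: str.strip('De ') treats its argument as a set of characters rather than a literal prefix, so it strips any leading or trailing run of 'D', 'e' and ' '; for instance 'De espinacas'.strip('De ') yields 'spinacas'. A prefix-safe variant (a sketch, not part of the commit; str.removeprefix requires Python 3.9+):

def check_wrong_ingredients(ingredients):
    # Remove the literal leading 'De ' instead of stripping the character set {'D', 'e', ' '}.
    new_ingredients = []
    for ingredient in ingredients:
        if ingredient.startswith('De '):
            new_ingredients.append(ingredient.removeprefix('De ').capitalize())
        else:
            new_ingredients.append(ingredient)
    return new_ingredients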