# Commit 20066dd by jucendrero: "Second functional version"
import html
import logging
import re

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

from resources import banner, error_html_response
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s: [%(levelname)s]: %(message)s',
)

# Checkpoint identifier passed to from_pretrained; tokenizer and model are
# both loaded from it once, at import time.
model_checkpoint = 'gastronomia-para-to2/gastronomia_para_to2'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForCausalLM.from_pretrained(model_checkpoint)

# Section markers that must all be present before <RECIPE_END> for a
# generation to be accepted (see rerun_model_output). <RECIPE_START> and
# <RECIPE_END> are handled separately and are intentionally not listed.
special_tokens = [
    # User-provided input ingredients.
    '<INPUT_START>', '<NEXT_INPUT>', '<INPUT_END>',
    # Recipe title.
    '<TITLE_START>', '<TITLE_END>',
    # Generated ingredient list.
    '<INGR_START>', '<NEXT_INGR>', '<INGR_END>',
    # Generated step-by-step instructions.
    '<INSTR_START>', '<NEXT_INSTR>', '<INSTR_END>',
]
def frame_html_response(html_response):
    """Wrap an HTML document in a sandboxed <iframe> for display in Gradio.

    The document is placed in the iframe's ``srcdoc`` attribute. That
    attribute is single-quoted, so the document is HTML-escaped first: an
    unescaped ``'`` would end the attribute early, and an unescaped ``&`` would
    be decoded as an entity reference. The browser unescapes the attribute
    value before rendering, so the framed page is displayed unchanged.

    Args:
        html_response: complete HTML document as a string.

    Returns:
        An ``<iframe>`` HTML snippet (string).
    """
    # NOTE(review): the sandbox grants both allow-scripts and allow-same-origin;
    # together they let framed content remove its own sandbox. Consider
    # dropping one if the framed pages do not need scripts — confirm.
    srcdoc = html.escape(html_response, quote=True)
    return f"""<iframe style="width: 100%; height: 800px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def check_special_tokens_order(pre_output):
    """Return True if the recipe section markers in ``pre_output`` are in order.

    Expected layout: input section, ingredients section, instructions section,
    then title. For start/end markers only the first occurrence is considered.
    Separator markers (``<NEXT_...>``) may repeat, but every occurrence must
    lie strictly between its section's start and end markers.

    Absent markers are not reliably detected here (``str.find`` returns -1);
    callers should check that all markers are present first.
    """
    layout = (
        ('<INPUT_START>', '<NEXT_INPUT>', '<INPUT_END>'),
        ('<INGR_START>', '<NEXT_INGR>', '<INGR_END>'),
        ('<INSTR_START>', '<NEXT_INSTR>', '<INSTR_END>'),
        ('<TITLE_START>', None, '<TITLE_END>'),
    )
    # Each span is (first position, last position) of a marker. Within a span,
    # first <= last always holds, so only the gaps between spans need checking.
    spans = []
    for opener, separator, closer in layout:
        first_open = pre_output.find(opener)
        spans.append((first_open, first_open))
        if separator is not None:
            spans.append((pre_output.find(separator), pre_output.rfind(separator)))
        first_close = pre_output.find(closer)
        spans.append((first_close, first_close))
    return all(
        previous_last < current_first
        for (_, previous_last), (current_first, _) in zip(spans, spans[1:])
    )
def make_html_response(title, ingredients, instructions):
    """Render a recipe as a standalone HTML document.

    All text is HTML-escaped because it comes from model output, which can
    echo arbitrary user-typed ingredients.

    Args:
        title: recipe title, shown as the main heading.
        ingredients: iterable of ingredient strings, rendered as <ul>.
        instructions: iterable of step strings, rendered as <ol>.

    Returns:
        The HTML document as a string.
    """
    def _as_list(tag, items):
        # One <li> per item; an empty iterable gives an empty list element
        # rather than a single blank bullet.
        entries = ''.join(f'<li>{html.escape(item)}</li>' for item in items)
        return f'<{tag}>{entries}</{tag}>'

    safe_title = html.escape(title)
    ingredients_html_list = _as_list('ul', ingredients)
    instructions_html_list = _as_list('ol', instructions)
    html_response = f'''
<!DOCTYPE html>
<html>
<body>
<h1>{safe_title}</h1>
<h2>Ingredientes</h2>
{ingredients_html_list}
<h2>Instrucciones</h2>
{instructions_html_list}
</body>
</html>
'''
    return html_response
def rerun_model_output(pre_output):
    """Tell whether a decoded generation is unusable and must be resampled.

    Returns False (keep it) only when ``pre_output`` contains ``<RECIPE_END>``
    and the text before that marker:
      * contains every marker in ``special_tokens``,
      * passes ``check_special_tokens_order``, and
      * has at least 75 whitespace-separated words.
    Otherwise returns True. Each rejection reason except ``None`` input is
    logged at INFO level.
    """
    if pre_output is None:
        return True
    if '<RECIPE_END>' not in pre_output:
        logging.info('<RECIPE_END> not in pre_output')
        return True

    recipe_body = pre_output[:pre_output.find('<RECIPE_END>')]
    # Evaluated in order; the first failing rule decides and is logged.
    rejection_rules = (
        (lambda body: any(token not in body for token in special_tokens),
         'Not all special tokens are in preoutput'),
        (lambda body: not check_special_tokens_order(body),
         'Special tokens are unordered in preoutput'),
        (lambda body: len(body.split()) < 75,
         'Length of the recipe is <75'),
    )
    for is_rejected, reason in rejection_rules:
        if is_rejected(recipe_body):
            logging.info(reason)
            return True
    return False
def check_wrong_ingredients(ingredients):
    """Remove a spurious leading "De " from generated ingredient names.

    Ingredients are expected to be capitalized already (make_recipe
    capitalizes them before calling this), so only the exact prefix "De " is
    matched. The remainder is re-capitalized. Other ingredients are returned
    unchanged, in the original order.

    Note: ``str.strip('De ')`` must not be used here. It strips any of the
    characters 'D', 'e' and ' ' from BOTH ends, which turns "De leche" into
    "lech".
    """
    prefix = 'De '
    return [
        ingredient[len(prefix):].lstrip().capitalize()
        if ingredient.startswith(prefix) else ingredient
        for ingredient in ingredients
    ]
def _build_prompt(input_ingredients):
    """Turn the user's ingredient text into the model prompt, ending at <INGR_START>."""
    # " y " (Spanish "and") is treated as another separator.
    normalized = re.sub(' y ', ', ', input_ingredients)
    # Split on bare commas and trim, so "a,b" and "a, b" yield the same inputs.
    items = [item.strip() for item in normalized.split(',') if item.strip()]
    prompt = '<RECIPE_START> '
    prompt += '<INPUT_START> ' + ' <NEXT_INPUT> '.join(items) + ' <INPUT_END> '
    prompt += '<INGR_START> '
    return prompt


def _generate_valid_output(tokenized_input, max_attempts=3, candidates_per_attempt=3):
    """Sample until a candidate passes rerun_model_output; return it, or None if all fail."""
    for attempt in range(1, max_attempts + 1):
        sequences = model.generate(
            **tokenized_input,
            max_length=600,
            do_sample=True,
            top_p=0.92,
            top_k=50,
            num_return_sequences=candidates_per_attempt,
        )
        # Every returned sequence is a usable sample; check them all before
        # paying for another generate() call.
        for sequence in sequences:
            candidate = tokenizer.decode(sequence, skip_special_tokens=False)
            if not rerun_model_output(candidate):
                return candidate
        logging.info('Attempt %d/%d produced no valid recipe', attempt, max_attempts)
    return None


def _extract_section(text, start_token, end_token):
    """Return the stripped text between the first start_token and the next end_token ('' if absent)."""
    # Non-greedy and DOTALL: tolerant of newlines and missing spaces around the
    # markers, and stops at the first closing marker (matching the first-occurrence
    # semantics of check_special_tokens_order).
    match = re.search(re.escape(start_token) + r'(.*?)' + re.escape(end_token), text, re.DOTALL)
    return match.group(1).strip() if match else ''


def _split_items(section, separator_token):
    """Split a section on its separator marker, trimming items and dropping empty ones."""
    return [item.strip() for item in section.split(separator_token) if item.strip()]


def make_recipe(input_ingredients):
    """Generate a recipe for the given ingredients and return it as an iframe HTML snippet.

    Used as the Gradio callback. ``input_ingredients`` is free text with
    ingredients separated by commas or " y ". If no valid generation is
    obtained after the retry budget in _generate_valid_output, the framed
    ``error_html_response`` is returned instead.
    """
    logging.info('Received inputs: %s', input_ingredients)
    tokenized_input = tokenizer(_build_prompt(input_ingredients), return_tensors='pt')

    pre_output = _generate_valid_output(tokenized_input)
    if pre_output is None:
        return frame_html_response(error_html_response)
    recipe_text = pre_output[:pre_output.find('<RECIPE_END>')]

    raw_ingredients = _split_items(
        _extract_section(recipe_text, '<INGR_START>', '<INGR_END>'), '<NEXT_INGR>')
    ingredients = check_wrong_ingredients([item.capitalize() for item in raw_ingredients])
    # Deduplicate after normalization (so "sal"/"Sal" collapse); dict.fromkeys
    # keeps the generated order, unlike set().
    ingredients = list(dict.fromkeys(ingredients))

    title = _extract_section(recipe_text, '<TITLE_START>', '<TITLE_END>').capitalize()
    instructions = _split_items(
        _extract_section(recipe_text, '<INSTR_START>', '<INSTR_END>'), '<NEXT_INSTR>')

    return frame_html_response(make_html_response(title, ingredients, instructions))
# NOTE(review): gr.inputs / gr.outputs and launch(enable_queue=...) belong to
# the older Gradio API and were removed in later major versions. Keep as-is for
# the pinned Gradio version; migrate (gr.Textbox / gr.HTML / iface.queue())
# before upgrading — TODO confirm which version the Space pins.
iface = gr.Interface(
    fn=make_recipe,
    # One free-text field with the comma-separated ingredients.
    inputs=[
        gr.inputs.Textbox(
            lines=1,
            placeholder='ingrediente_1, ingrediente_2, ..., ingrediente_n',
            label='Dime con qué ingredientes quieres que cocinemos hoy y te sugeriremos una receta tan pronto como nuestros fogones estén libres',
        ),
    ],
    # make_recipe returns an <iframe> snippet, displayed as raw HTML.
    outputs=[
        gr.outputs.HTML(label="¡Esta es mi propuesta para ti! ¡Buen provecho!"),
    ],
    examples=[
        ['salmón, zumo de naranja, aceite de oliva, sal, pimienta'],
        ['harina, azúcar, huevos, chocolate, levadura Royal'],
    ],
    description=banner,
)
iface.launch(enable_queue=True)