File size: 6,079 Bytes
9e46f6d
 
 
 
20066dd
 
9e46f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20066dd
9e46f6d
 
 
20066dd
9e46f6d
 
20066dd
9e46f6d
 
20066dd
9e46f6d
 
 
 
 
20066dd
 
 
 
 
 
 
9e46f6d
 
 
20066dd
9e46f6d
 
 
 
 
 
20066dd
9e46f6d
20066dd
 
9e46f6d
20066dd
 
 
 
 
 
 
 
9e46f6d
20066dd
 
 
 
 
 
9e46f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import html
import logging
import re

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

from resources import banner, error_html_response
# Configure the root logger: "<time>: [<LEVEL>]: <message>" lines at INFO and above.
logging.basicConfig(format='%(asctime)s: [%(levelname)s]: %(message)s', level=logging.INFO)

# Checkpoint identifier handed to both from_pretrained() calls below.
model_checkpoint = 'gastronomia-para-to2/gastronomia_para_to2'
# The tokenizer and the causal language model are created once, when this module
# is executed; make_recipe() reads these module-level names on every call.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForCausalLM.from_pretrained(model_checkpoint)

# Markers that rerun_model_output() requires to be present in a generated
# recipe. The list order is not significant for that membership check.
special_tokens = [
    # ingredients supplied by the user (written into the prompt by make_recipe)
    '<INPUT_START>', '<NEXT_INPUT>', '<INPUT_END>',
    # recipe title
    '<TITLE_START>', '<TITLE_END>',
    # ingredient list of the recipe
    '<INGR_START>', '<NEXT_INGR>', '<INGR_END>',
    # preparation steps
    '<INSTR_START>', '<NEXT_INSTR>', '<INSTR_END>',
]

def frame_html_response(html_response):
    """Wrap an HTML document in an ``<iframe>`` through its ``srcdoc`` attribute.

    Args:
        html_response: The complete HTML document (a string) to display inside
            the frame.

    Returns:
        The ``<iframe>`` markup as a string.
    """
    # srcdoc is an attribute value delimited by single quotes, so the embedded
    # document must be entity-escaped: a raw apostrophe in the recipe text would
    # otherwise terminate the attribute and truncate the framed page. The
    # browser decodes the entities again before parsing the framed document.
    escaped_document = html.escape(html_response, quote=True)
    return f"""<iframe style="width: 100%; height: 800px" name="result" allow="midi; geolocation; microphone; camera; 
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms 
    allow-scripts allow-same-origin allow-popups 
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" 
    allowpaymentrequest="" frameborder="0" srcdoc='{escaped_document}'></iframe>"""


def check_special_tokens_order(pre_output):
    """Tell whether the special tokens appear in the expected sequence.

    The expected layout is: input section, ingredient section, instruction
    section, then the title. Inside each of the first three sections the
    START token must precede the first separator, and the last separator
    must precede the END token.

    Args:
        pre_output: Text to inspect.

    Returns:
        True when every position in the layout increases as required,
        False otherwise. A missing token is located at -1 by
        ``str.find``/``str.rfind`` and therefore normally breaks the ordering.
    """
    # (token, search_from_right). Entries searched from the right are the
    # *last* separator of a section; they may coincide with the first one,
    # so they are compared to their predecessor with <= instead of <.
    layout = (
        ('<INPUT_START>', False),
        ('<NEXT_INPUT>', False),
        ('<NEXT_INPUT>', True),
        ('<INPUT_END>', False),
        ('<INGR_START>', False),
        ('<NEXT_INGR>', False),
        ('<NEXT_INGR>', True),
        ('<INGR_END>', False),
        ('<INSTR_START>', False),
        ('<NEXT_INSTR>', False),
        ('<NEXT_INSTR>', True),
        ('<INSTR_END>', False),
        ('<TITLE_START>', False),
        ('<TITLE_END>', False),
    )

    previous = None
    for token, from_right in layout:
        if from_right:
            current = pre_output.rfind(token)
        else:
            current = pre_output.find(token)
        if previous is not None:
            if from_right:
                if previous > current:
                    return False
            elif previous >= current:
                return False
        previous = current
    return True


def make_html_response(title, ingredients, instructions):
    """Render a recipe as a small standalone HTML document.

    Args:
        title: Recipe title as plain text.
        ingredients: Iterable of ingredient strings, shown as a bullet list.
        instructions: Iterable of step strings, shown as a numbered list.

    Returns:
        The HTML document as a string.
    """
    def as_list_items(entries):
        # The entries are plain text, not markup: escape them so characters
        # such as <, > or & are displayed literally instead of being parsed
        # as HTML.
        return '</li><li>'.join(html.escape(entry) for entry in entries)

    safe_title = html.escape(title)
    ingredients_html_list = '<ul><li>' + as_list_items(ingredients) + '</li></ul>'
    instructions_html_list = '<ol><li>' + as_list_items(instructions) + '</li></ol>'

    html_response = f'''
                    <!DOCTYPE html>
                    <html>
                    <body>
                    <h1>{safe_title}</h1>

                    <h2>Ingredientes</h2>
                    {ingredients_html_list}

                    <h2>Instrucciones</h2>
                    {instructions_html_list}

                    </body>
                    </html>
    '''
    return html_response


def rerun_model_output(pre_output):
    """Decide whether a generated text must be discarded and generated again.

    Args:
        pre_output: Decoded model output, or None when nothing has been
            generated yet.

    Returns:
        True when the text is unusable: it is None, has no '<RECIPE_END>',
        lacks one of the tokens in ``special_tokens``, has them out of order,
        or contains fewer than 75 whitespace-separated words before
        '<RECIPE_END>'. False when the text passes every check.

    NOTE(review): every entry of ``special_tokens`` is mandatory, so a text
    without any '<NEXT_INPUT>' (e.g. a single input ingredient) is always
    rejected - confirm this is intended.
    """
    if pre_output is None:
        return True

    end_marker_at = pre_output.find('<RECIPE_END>')
    if end_marker_at == -1:
        logging.info('<RECIPE_END> not in pre_output')
        return True

    # Only the text before the end-of-recipe marker is validated.
    recipe_text = pre_output[:end_marker_at]

    if any(token not in recipe_text for token in special_tokens):
        logging.info('Not all special tokens are in preoutput')
        return True

    if not check_special_tokens_order(recipe_text):
        logging.info('Special tokens are unordered in preoutput')
        return True

    if len(recipe_text.split()) < 75:
        logging.info('Length of the recipe is <75')
        return True

    return False

def check_wrong_ingredients(ingredients):
    """Remove a spurious leading 'De ' from ingredient names.

    Args:
        ingredients: Iterable of ingredient strings.

    Returns:
        A new list in the same order. Entries that start with 'De ' have that
        prefix removed and are re-capitalized; all other entries are returned
        unchanged.
    """
    prefix = 'De '
    cleaned = []
    for ingredient in ingredients:
        if ingredient.startswith(prefix):
            # Slice the prefix off. str.strip('De ') treats its argument as a
            # set of characters and strips them from BOTH ends, which turned
            # e.g. 'De aceite' into 'Aceit'.
            cleaned.append(ingredient[len(prefix):].strip().capitalize())
        else:
            cleaned.append(ingredient)
    return cleaned


def _build_prompt(input_ingredients):
    """Turn the user's ingredient text into the model prompt, or None if it holds no ingredient."""
    # " y " is accepted as a separator in addition to commas.
    normalized = re.sub(' y ', ', ', input_ingredients or '')
    # Split on bare commas and strip, so "a,b" and "a ,  b" work like "a, b".
    items = [item.strip() for item in normalized.split(',')]
    items = [item for item in items if item]
    if not items:
        return None
    return ('<RECIPE_START> <INPUT_START> '
            + ' <NEXT_INPUT> '.join(items)
            + ' <INPUT_END> <INGR_START> ')


def _extract_section(recipe_text, start_token, end_token):
    """Return the text between two special tokens, or None when the pair is not found."""
    pattern = re.escape(start_token) + ' (.*) ' + re.escape(end_token)
    # DOTALL so that a line break inside the section does not defeat the match.
    match = re.search(pattern, recipe_text, flags=re.DOTALL)
    return match.group(1) if match else None


def make_recipe(input_ingredients):
    """Generate a recipe for the given ingredients and return it as framed HTML.

    Args:
        input_ingredients: Ingredient names in one string, separated by commas
            and/or " y ".

    Returns:
        The string produced by ``frame_html_response``: the rendered recipe,
        or ``error_html_response`` when the input holds no ingredient, when no
        acceptable text is produced within three generation attempts, or when
        the accepted text cannot be split into its sections.
    """
    logging.info('Received inputs: %s', input_ingredients)

    prompt = _build_prompt(input_ingredients)
    if prompt is None:
        logging.info('No ingredients found in the input')
        return frame_html_response(error_html_response)
    tokenized_input = tokenizer(prompt, return_tensors='pt')

    pre_output = None
    for _ in range(3):
        sequences = model.generate(**tokenized_input,
                                   max_length=600,
                                   do_sample=True,
                                   top_p=0.92,
                                   top_k=50,
                                   num_return_sequences=3)
        # Each call returns several sampled sequences; examine all of them
        # rather than only the first before paying for another generation.
        for sequence in sequences:
            candidate = tokenizer.decode(sequence, skip_special_tokens=False)
            if not rerun_model_output(candidate):
                pre_output = candidate
                break
        if pre_output is not None:
            break
    if pre_output is None:
        return frame_html_response(error_html_response)

    pre_output_trimmed = pre_output[:pre_output.find('<RECIPE_END>')]
    raw_ingredients = _extract_section(pre_output_trimmed, '<INGR_START>', '<INGR_END>')
    raw_title = _extract_section(pre_output_trimmed, '<TITLE_START>', '<TITLE_END>')
    raw_instructions = _extract_section(pre_output_trimmed, '<INSTR_START>', '<INSTR_END>')
    if None in (raw_ingredients, raw_title, raw_instructions):
        # The validation above does not require the spaces around the tokens
        # that the extraction pattern needs, so a miss is possible here.
        logging.info('Could not extract every section of the recipe')
        return frame_html_response(error_html_response)

    output_ingredients = [part.strip().capitalize() for part in raw_ingredients.split('<NEXT_INGR>')]
    output_ingredients = check_wrong_ingredients([part for part in output_ingredients if part])
    # De-duplicate after normalization and keep the generated order
    # (a set would shuffle the ingredients between runs).
    output_ingredients = list(dict.fromkeys(output_ingredients))

    output_title = raw_title.strip().capitalize()

    output_instructions = [step.strip() for step in raw_instructions.split('<NEXT_INSTR>')]
    output_instructions = [step for step in output_instructions if step]

    html_response = make_html_response(output_title, output_ingredients, output_instructions)

    return frame_html_response(html_response)


# Web UI: one text box for the ingredient list in, one HTML panel out.
# make_recipe() handles each submission.
iface = gr.Interface(
    fn=make_recipe,
    inputs=[
        gr.inputs.Textbox(
            lines=1,
            placeholder='ingrediente_1, ingrediente_2, ..., ingrediente_n',
            label=('Dime con qué ingredientes quieres que cocinemos hoy y te sugeriremos '
                   'una receta tan pronto como nuestros fogones estén libres'),
        ),
    ],
    outputs=[
        gr.outputs.HTML(label="¡Esta es mi propuesta para ti! ¡Buen provecho!"),
    ],
    # Sample ingredient lists offered as examples in the interface.
    examples=[
        ['salmón, zumo de naranja, aceite de oliva, sal, pimienta'],
        ['harina, azúcar, huevos, chocolate, levadura Royal'],
    ],
    description=banner,
)
iface.launch(enable_queue=True)