jucendrero committed on
Commit
20066dd
1 Parent(s): 9e46f6d

Second functional version

Files changed (1)
  1. app.py +31 -34
app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import re
 from resources import banner, error_html_response
+import logging
+logging.basicConfig(format='%(asctime)s: [%(levelname)s]: %(message)s', level=logging.INFO)
 
 model_checkpoint = 'gastronomia-para-to2/gastronomia_para_to2'
 tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
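Note: this first hunk shows only the tokenizer being loaded; the model that model.generate is called on later is loaded elsewhere in app.py and is untouched by this commit. For orientation, a minimal sketch of that loading step, assuming the same checkpoint name (not part of the diff):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed setup (not in this commit): load the causal language model
# from the same checkpoint used for the tokenizer above.
model_checkpoint = 'gastronomia-para-to2/gastronomia_para_to2'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForCausalLM.from_pretrained(model_checkpoint)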
@@ -71,64 +73,59 @@ def rerun_model_output(pre_output):
     if pre_output is None:
         return True
     elif not '<RECIPE_END>' in pre_output:
-        print('<RECIPE_END> not in pre_output')
+        logging.info('<RECIPE_END> not in pre_output')
         return True
     pre_output_trimmed = pre_output[:pre_output.find('<RECIPE_END>')]
     if not all(special_token in pre_output_trimmed for special_token in special_tokens):
-        print('Not all special tokens are in preoutput')
+        logging.info('Not all special tokens are in preoutput')
         return True
     elif not check_special_tokens_order(pre_output_trimmed):
-        print('Special tokens are unordered in preoutput')
+        logging.info('Special tokens are unordered in preoutput')
         return True
     elif len(pre_output_trimmed.split())<75:
-        print('Length of the recipe is <75')
+        logging.info('Length of the recipe is <75')
         return True
     else:
         return False
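rerun_model_output depends on a special_tokens list and a check_special_tokens_order helper that are defined elsewhere in app.py and not touched by this commit. A hypothetical sketch of what they might look like, only to make the checks above easier to follow; the real definitions and the exact token order in the repository may differ:

# Hypothetical sketch, not taken from the repository.
special_tokens = ['<RECIPE_START>', '<INPUT_START>', '<INPUT_END>',
                  '<INGR_START>', '<INGR_END>',
                  '<INSTR_START>', '<INSTR_END>',
                  '<TITLE_START>', '<TITLE_END>']

def check_special_tokens_order(pre_output_trimmed):
    # Assumes every marker is present (the all(...) check above runs first):
    # the generation is well-formed only if the markers appear in this order.
    positions = [pre_output_trimmed.find(token) for token in special_tokens]
    return positions == sorted(positions)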
88
 
89
-
90
- def generate_output(tokenized_input):
91
- pre_output = None
92
- while rerun_model_output(pre_output):
93
- output = model.generate(**tokenized_input,
94
- max_length=600,
95
- do_sample=True,
96
- top_p=0.92,
97
- top_k=50,
98
- # no_repeat_ngram_size=2,
99
- num_return_sequences=3)
100
- pre_output = tokenizer.decode(output[0], skip_special_tokens=False)
101
- pre_output_trimmed = pre_output[:pre_output.find('<RECIPE_END>')]
102
- return pre_output_trimmed
103
-
104
-
 def check_wrong_ingredients(ingredients):
-    if ingredients is None:
-        return True
-    if any(ingredient.startswith('De') for ingredient in ingredients):
-        print('At least one ingredient starts with De')
-        return True
+    new_ingredients = []
+    for ingredient in ingredients:
+        if ingredient.startswith('De '):
+            new_ingredients.append(ingredient.strip('De ').capitalize())
+        else:
+            new_ingredients.append(ingredient)
+    return new_ingredients
 
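The rewritten check_wrong_ingredients no longer forces a regeneration; it now repairs ingredients that start with 'De ' in place. One caveat worth keeping in mind when reading it: str.strip('De ') removes any leading or trailing run of the characters 'D', 'e' and ' ', not just the literal 'De ' prefix, so a trailing 'e' is stripped too. A quick check of the committed behaviour (sample ingredients, not repository data):

print(check_wrong_ingredients(['De pollo', 'Sal']))
# ['Pollo', 'Sal']
print(check_wrong_ingredients(['De tomate']))
# ['Tomat']  <- the trailing 'e' is also removed by strip('De ')

If only the prefix should go, ingredient[3:].capitalize() or, on Python 3.9+, ingredient.removeprefix('De ').capitalize() would avoid that edge case.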
 
 def make_recipe(input_ingredients):
+    logging.info(f'Received inputs: {input_ingredients}')
     input_ingredients = re.sub(' y ', ', ', input_ingredients)
     input = '<RECIPE_START> '
     input += '<INPUT_START> ' + ' <NEXT_INPUT> '.join(input_ingredients.split(', ')) + ' <INPUT_END> '
     input += '<INGR_START> '
     tokenized_input = tokenizer(input, return_tensors='pt')
 
-    output_ingredients = None
+    pre_output = None
     i = 0
-    while check_wrong_ingredients(output_ingredients):
-        if i == 4:
+    while rerun_model_output(pre_output):
+        if i == 3:
             return frame_html_response(error_html_response)
-        pre_output_trimmed = generate_output(tokenized_input)
-        output_ingredients = re.search('<INGR_START> (.*) <INGR_END>', pre_output_trimmed).group(1)
-        output_ingredients = output_ingredients.split(' <NEXT_INGR> ')
-        output_ingredients = list(set([output_ingredient.strip() for output_ingredient in output_ingredients]))
-        output_ingredients = [output_ing.capitalize() for output_ing in output_ingredients]
+        output = model.generate(**tokenized_input,
+                                max_length=600,
+                                do_sample=True,
+                                top_p=0.92,
+                                top_k=50,
+                                # no_repeat_ngram_size=3,
+                                num_return_sequences=3)
+        pre_output = tokenizer.decode(output[0], skip_special_tokens=False)
         i += 1
-
+    pre_output_trimmed = pre_output[:pre_output.find('<RECIPE_END>')]
+    output_ingredients = re.search('<INGR_START> (.*) <INGR_END>', pre_output_trimmed).group(1)
+    output_ingredients = output_ingredients.split(' <NEXT_INGR> ')
+    output_ingredients = list(set([output_ingredient.strip() for output_ingredient in output_ingredients]))
+    output_ingredients = [output_ing.capitalize() for output_ing in output_ingredients]
+    output_ingredients = check_wrong_ingredients(output_ingredients)
     output_title = re.search('<TITLE_START> (.*) <TITLE_END>', pre_output_trimmed).group(1).strip().capitalize()
     output_instructions = re.search('<INSTR_START> (.*) <INSTR_END>', pre_output_trimmed).group(1)
     output_instructions = output_instructions.split(' <NEXT_INSTR> ')
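The tail of make_recipe recovers the recipe parts from the decoded text using the <..._START>/<..._END> markers. A small self-contained illustration of that parsing step on a synthetic model output (the recipe text below is invented for the example):

import re

pre_output_trimmed = (
    '<RECIPE_START> <INPUT_START> arroz <NEXT_INPUT> pollo <INPUT_END> '
    '<INGR_START> 200 g de arroz <NEXT_INGR> 2 muslos de pollo <INGR_END> '
    '<INSTR_START> Dorar el pollo. <NEXT_INSTR> Añadir el arroz y cocer. <INSTR_END> '
    '<TITLE_START> arroz con pollo <TITLE_END> '
)

output_title = re.search('<TITLE_START> (.*) <TITLE_END>', pre_output_trimmed).group(1).strip().capitalize()
output_ingredients = re.search('<INGR_START> (.*) <INGR_END>', pre_output_trimmed).group(1).split(' <NEXT_INGR> ')
output_instructions = re.search('<INSTR_START> (.*) <INSTR_END>', pre_output_trimmed).group(1).split(' <NEXT_INSTR> ')

print(output_title)         # Arroz con pollo
print(output_ingredients)   # ['200 g de arroz', '2 muslos de pollo']
print(output_instructions)  # ['Dorar el pollo.', 'Añadir el arroz y cocer.']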
 
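The Gradio UI itself (banner, frame_html_response and the interface definition) sits further down in app.py and is not part of this commit. Purely for orientation, a minimal sketch of how a function like make_recipe is typically exposed as an HTML-returning Gradio app; the component choices and the frame_html_response body below are assumptions, not the repository's code:

import gradio as gr

def frame_html_response(html_response):
    # Assumed helper: wrap the recipe (or the error message) in the app's HTML frame.
    return f'<div class="recipe">{html_response}</div>'

demo = gr.Interface(
    fn=make_recipe,                  # the function reworked in this commit
    inputs=gr.Textbox(label='Ingredientes, separados por comas'),
    outputs=gr.HTML(),
)

if __name__ == '__main__':
    demo.launch()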