TimInf commited on
Commit
bcc13b5
·
verified ·
1 Parent(s): beaa316

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -128
app.py CHANGED
@@ -5,17 +5,18 @@ import numpy as np
5
  import random
6
  import json
7
 
8
- # Model loading (same as before)
9
  bert_model_name = "alexdseo/RecipeBERT"
10
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
11
  bert_model = AutoModel.from_pretrained(bert_model_name)
12
- bert_model.eval()
13
 
 
14
  MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
15
  t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
16
  t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
17
 
18
- # Token mapping for T5 model output processing
19
  special_tokens = t5_tokenizer.all_special_tokens
20
  tokens_map = {
21
  "<sep>": "--",
@@ -23,12 +24,12 @@ tokens_map = {
23
  }
24
 
25
  def get_embedding(text):
26
- """Computes embedding for a text with Mean Pooling over all tokens"""
27
  inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
28
  with torch.no_grad():
29
  outputs = bert_model(**inputs)
30
 
31
- # Mean Pooling - take average of all token embeddings
32
  attention_mask = inputs['attention_mask']
33
  token_embeddings = outputs.last_hidden_state
34
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
@@ -38,18 +39,18 @@ def get_embedding(text):
38
  return (sum_embeddings / sum_mask).squeeze(0)
39
 
40
  def average_embedding(embedding_list):
41
- """Computes the average of a list of embeddings"""
42
  tensors = torch.stack([emb for _, emb in embedding_list])
43
  return tensors.mean(dim=0)
44
 
45
  def get_cosine_similarity(vec1, vec2):
46
- """Computes the cosine similarity between two vectors"""
47
  if torch.is_tensor(vec1):
48
  vec1 = vec1.detach().numpy()
49
  if torch.is_tensor(vec2):
50
  vec2 = vec2.detach().numpy()
51
 
52
- # Make sure vectors have the right shape (flatten if necessary)
53
  vec1 = vec1.flatten()
54
  vec2 = vec2.flatten()
55
 
@@ -57,98 +58,108 @@ def get_cosine_similarity(vec1, vec2):
57
  norm_a = np.linalg.norm(vec1)
58
  norm_b = np.linalg.norm(vec2)
59
 
60
- # Avoid division by zero
61
  if norm_a == 0 or norm_b == 0:
62
  return 0
63
 
64
  return dot_product / (norm_a * norm_b)
65
 
66
  def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
67
- """Computes combined score considering both similarity to average and individual ingredients"""
68
  results = []
69
 
70
  for name, emb in embedding_list:
71
- # Similarity to average vector
72
  avg_similarity = get_cosine_similarity(query_vector, emb)
73
 
74
- # Average similarity to individual ingredients
75
  individual_similarities = [get_cosine_similarity(good_emb, emb)
76
  for _, good_emb in all_good_embeddings]
77
  avg_individual_similarity = sum(individual_similarities) / len(individual_similarities)
78
 
79
- # Combined score (weighted average)
80
  combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
81
 
82
  results.append((name, emb, combined_score))
83
 
84
- # Sort by combined score (descending)
85
  results.sort(key=lambda x: x[2], reverse=True)
86
  return results
87
 
88
  def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
89
  """
90
- Finds the best ingredients based on RecipeBERT embeddings.
 
 
 
 
 
 
 
 
 
91
  """
92
- # Ensure no duplicates in lists
93
  required_ingredients = list(set(required_ingredients))
94
  available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
95
 
96
- # Special case: If no required ingredients, randomly select one from available ingredients
97
  if not required_ingredients and available_ingredients:
98
  random_ingredient = random.choice(available_ingredients)
99
  required_ingredients = [random_ingredient]
100
  available_ingredients = [i for i in available_ingredients if i != random_ingredient]
 
101
 
102
- # If still no ingredients or already at max capacity
103
  if not required_ingredients or len(required_ingredients) >= max_ingredients:
104
  return required_ingredients[:max_ingredients]
105
 
106
- # If no additional ingredients available
107
  if not available_ingredients:
108
  return required_ingredients
109
 
110
- # Calculate embeddings for all ingredients
111
  embed_required = [(e, get_embedding(e)) for e in required_ingredients]
112
  embed_available = [(e, get_embedding(e)) for e in available_ingredients]
113
 
114
- # Number of ingredients to add
115
  num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))
116
 
117
- # Copy required ingredients to final list
118
  final_ingredients = embed_required.copy()
119
 
120
- # Add best ingredients
121
  for _ in range(num_to_add):
122
- # Calculate average vector of current combination
123
  avg = average_embedding(final_ingredients)
124
 
125
- # Calculate combined scores for all candidates
126
  candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
127
 
128
- # If no candidates left, break
129
  if not candidates:
130
  break
131
 
132
- # Choose best ingredient
133
  best_name, best_embedding, _ = candidates[0]
134
 
135
- # Add best ingredient to final list
136
  final_ingredients.append((best_name, best_embedding))
137
 
138
- # Remove ingredient from available ingredients
139
  embed_available = [item for item in embed_available if item[0] != best_name]
140
 
141
- # Extract only ingredient names
142
  return [name for name, _ in final_ingredients]
143
 
144
  def skip_special_tokens(text, special_tokens):
145
- """Removes special tokens from text"""
146
  for token in special_tokens:
147
  text = text.replace(token, "")
148
  return text
149
 
150
  def target_postprocessing(texts, special_tokens):
151
- """Post-processes generated text"""
152
  if not isinstance(texts, list):
153
  texts = [texts]
154
 
@@ -164,29 +175,31 @@ def target_postprocessing(texts, special_tokens):
164
  return new_texts
165
 
166
  def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
167
- """Validates if the recipe contains approximately the expected ingredients."""
 
 
168
  recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
169
  expected_count = len(expected_ingredients)
170
  return abs(recipe_count - expected_count) == tolerance
171
 
172
  def generate_recipe_with_t5(ingredients_list, max_retries=5):
173
- """Generates a recipe using the T5 recipe generation model with validation."""
174
  original_ingredients = ingredients_list.copy()
175
 
176
  for attempt in range(max_retries):
177
  try:
178
- # For retries after the first attempt, shuffle the ingredients
179
  if attempt > 0:
180
  current_ingredients = original_ingredients.copy()
181
  random.shuffle(current_ingredients)
182
  else:
183
  current_ingredients = ingredients_list
184
 
185
- # Format ingredients as a comma-separated string
186
  ingredients_string = ", ".join(current_ingredients)
187
  prefix = "items: "
188
 
189
- # Generation settings
190
  generation_kwargs = {
191
  "max_length": 512,
192
  "min_length": 64,
@@ -194,8 +207,9 @@ def generate_recipe_with_t5(ingredients_list, max_retries=5):
194
  "top_k": 60,
195
  "top_p": 0.95
196
  }
 
197
 
198
- # Tokenize input
199
  inputs = t5_tokenizer(
200
  prefix + ingredients_string,
201
  max_length=256,
@@ -204,21 +218,21 @@ def generate_recipe_with_t5(ingredients_list, max_retries=5):
204
  return_tensors="jax"
205
  )
206
 
207
- # Generate text
208
  output_ids = t5_model.generate(
209
  input_ids=inputs.input_ids,
210
  attention_mask=inputs.attention_mask,
211
  **generation_kwargs
212
  )
213
 
214
- # Decode and post-process
215
  generated = output_ids.sequences
216
  generated_text = target_postprocessing(
217
  t5_tokenizer.batch_decode(generated, skip_special_tokens=False),
218
  special_tokens
219
  )[0]
220
 
221
- # Parse sections
222
  recipe = {}
223
  sections = generated_text.split("\n")
224
  for section in sections:
@@ -232,177 +246,192 @@ def generate_recipe_with_t5(ingredients_list, max_retries=5):
232
  directions_text = section.replace("directions:", "").strip()
233
  recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()]
234
 
235
- # If title is missing, create one
236
  if "title" not in recipe:
237
- recipe["title"] = f"Recipe with {', '.join(current_ingredients[:3])}"
238
 
239
- # Ensure all sections exist
240
  if "ingredients" not in recipe:
241
  recipe["ingredients"] = current_ingredients
242
  if "directions" not in recipe:
243
- recipe["directions"] = ["No directions generated"]
244
 
245
- # Validate the recipe
246
  if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
 
247
  return recipe
248
  else:
 
249
  if attempt == max_retries - 1:
 
250
  return recipe
251
 
252
  except Exception as e:
 
253
  if attempt == max_retries - 1:
254
  return {
255
- "title": f"Recipe with {original_ingredients[0] if original_ingredients else 'ingredients'}",
256
  "ingredients": original_ingredients,
257
- "directions": ["Error generating recipe instructions"]
258
  }
259
 
260
- # Fallback
261
  return {
262
- "title": f"Recipe with {original_ingredients[0] if original_ingredients else 'ingredients'}",
263
  "ingredients": original_ingredients,
264
- "directions": ["Error generating recipe instructions"]
265
  }
266
 
267
  def flutter_api_generate_recipe(ingredients_data):
268
  """
269
- Flutter-friendly API function that processes JSON input
270
- and returns structured JSON output matching your original Flask API
271
  """
272
  try:
273
- # Parse input - handle both string and dict formats
274
  if isinstance(ingredients_data, str):
275
  data = json.loads(ingredients_data)
276
  else:
277
- data = ingredients_data
278
-
279
- # Extract parameters (same as your Flask API)
280
  required_ingredients = data.get('required_ingredients', [])
281
  available_ingredients = data.get('available_ingredients', [])
282
-
283
- # Backward compatibility
284
  if data.get('ingredients') and not required_ingredients:
285
  required_ingredients = data.get('ingredients', [])
286
-
287
  max_ingredients = data.get('max_ingredients', 7)
288
  max_retries = data.get('max_retries', 5)
289
-
290
  if not required_ingredients and not available_ingredients:
291
- return json.dumps({"error": "No ingredients provided"})
292
-
293
- # Find optimal ingredients
294
  optimized_ingredients = find_best_ingredients(
295
  required_ingredients,
296
- available_ingredients,
297
  max_ingredients
298
  )
299
-
300
- # Generate recipe
301
  recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
302
-
303
- # Return in exact same format as your Flask API
304
  result = {
305
  'title': recipe['title'],
306
- 'ingredients': recipe['ingredients'],
307
  'directions': recipe['directions'],
308
  'used_ingredients': optimized_ingredients
309
  }
310
-
311
  return json.dumps(result)
312
-
313
  except Exception as e:
314
- return json.dumps({"error": f"Error in recipe generation: {str(e)}"})
315
 
316
- def gradio_ui_generate_recipe(required_ingredients_text, available_ingredients_text, max_ingredients, max_retries):
317
- """Gradio UI function for web interface"""
318
  try:
319
- # Parse text inputs
320
  required_ingredients = [ing.strip() for ing in required_ingredients_text.split(',') if ing.strip()]
321
  available_ingredients = [ing.strip() for ing in available_ingredients_text.split(',') if ing.strip()]
322
-
323
- # Create data dict in Flutter API format
324
- data = {
325
  'required_ingredients': required_ingredients,
326
  'available_ingredients': available_ingredients,
327
- 'max_ingredients': max_ingredients,
328
- 'max_retries': max_retries
329
  }
330
-
331
- # Use the same function as Flutter API
332
- result_json = flutter_api_generate_recipe(data)
 
 
 
 
 
 
 
 
333
  result = json.loads(result_json)
334
-
335
  if 'error' in result:
336
  return result['error'], "", "", ""
337
-
338
- # Format for Gradio display
339
  ingredients_list = '\n'.join([f"• {ing}" for ing in result['ingredients']])
340
  directions_list = '\n'.join([f"{i+1}. {dir}" for i, dir in enumerate(result['directions'])])
341
  used_ingredients = ', '.join(result['used_ingredients'])
342
-
343
  return (
344
  result['title'],
345
- ingredients_list,
346
  directions_list,
347
  used_ingredients
348
  )
349
-
350
  except Exception as e:
351
- return f"Error: {str(e)}", "", "", ""
352
-
353
- # Create Gradio Interface
354
- with gr.Blocks(title="AI Recipe Generator") as demo:
355
- gr.Markdown("# 🍳 AI Recipe Generator")
356
- gr.Markdown("Generate recipes using AI with intelligent ingredient combination!")
357
-
358
- with gr.Tab("Web Interface"):
 
359
  with gr.Row():
360
  with gr.Column():
361
  required_ing = gr.Textbox(
362
- label="Required Ingredients (comma-separated)",
363
- placeholder="chicken, rice, onion",
364
  lines=2
365
  )
366
  available_ing = gr.Textbox(
367
- label="Available Ingredients (comma-separated)",
368
- placeholder="garlic, tomato, pepper, herbs",
369
  lines=2
370
  )
371
- max_ing = gr.Slider(3, 10, value=7, step=1, label="Maximum Ingredients")
372
- max_retries = gr.Slider(1, 10, value=5, step=1, label="Max Retries")
373
- generate_btn = gr.Button("Generate Recipe", variant="primary")
374
-
 
 
375
  with gr.Column():
376
- title_output = gr.Textbox(label="Recipe Title", interactive=False)
377
- ingredients_output = gr.Textbox(label="Ingredients", lines=8, interactive=False)
378
- directions_output = gr.Textbox(label="Directions", lines=10, interactive=False)
379
- used_ingredients_output = gr.Textbox(label="Used Ingredients", interactive=False)
380
-
381
- generate_btn.click(
382
- fn=gradio_ui_generate_recipe,
383
- inputs=[required_ing, available_ing, max_ing, max_retries],
384
- outputs=[title_output, ingredients_output, directions_output, used_ingredients_output]
385
- )
386
-
387
- with gr.Tab("API Testing"):
388
- gr.Markdown("### Test the Flutter API")
389
- gr.Markdown("This tab uses the same function that Flutter apps will call via API")
390
-
391
  api_input = gr.Textbox(
392
- label="JSON Input (Flutter API Format)",
393
  placeholder='{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic"], "max_ingredients": 6}',
394
  lines=4
395
  )
396
- api_output = gr.Textbox(label="JSON Output", lines=15, interactive=False)
397
- api_test_btn = gr.Button("Test API", variant="secondary")
398
-
399
  api_test_btn.click(
400
  fn=flutter_api_generate_recipe,
401
  inputs=[api_input],
402
  outputs=[api_output],
403
- api_name="generate_recipe_for_flutter" # <-- Hinzufügen dieser Zeile
404
  )
405
-
406
  gr.Examples(
407
  examples=[
408
  ['{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic", "tomato"], "max_ingredients": 6}'],
@@ -412,4 +441,4 @@ with gr.Blocks(title="AI Recipe Generator") as demo:
412
  )
413
 
414
  if __name__ == "__main__":
415
- demo.launch()
 
5
  import random
6
  import json
7
 
8
+ # Lade RecipeBERT Modell (für semantische Zutat-Kombination)
9
  bert_model_name = "alexdseo/RecipeBERT"
10
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
11
  bert_model = AutoModel.from_pretrained(bert_model_name)
12
+ bert_model.eval() # Setze das Modell in den Evaluationsmodus
13
 
14
+ # Lade T5 Rezeptgenerierungsmodell
15
  MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
16
  t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
17
  t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
18
 
19
+ # Token Mapping für die T5 Modell-Ausgabe
20
  special_tokens = t5_tokenizer.all_special_tokens
21
  tokens_map = {
22
  "<sep>": "--",
 
24
  }
25
 
26
  def get_embedding(text):
27
+ """Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens"""
28
  inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
29
  with torch.no_grad():
30
  outputs = bert_model(**inputs)
31
 
32
+ # Mean Pooling - Mittelwert aller Token-Embeddings
33
  attention_mask = inputs['attention_mask']
34
  token_embeddings = outputs.last_hidden_state
35
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
 
39
  return (sum_embeddings / sum_mask).squeeze(0)
40
 
41
  def average_embedding(embedding_list):
42
+ """Berechnet den Durchschnitt einer Liste von Embeddings"""
43
  tensors = torch.stack([emb for _, emb in embedding_list])
44
  return tensors.mean(dim=0)
45
 
46
  def get_cosine_similarity(vec1, vec2):
47
+ """Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren"""
48
  if torch.is_tensor(vec1):
49
  vec1 = vec1.detach().numpy()
50
  if torch.is_tensor(vec2):
51
  vec2 = vec2.detach().numpy()
52
 
53
+ # Stelle sicher, dass die Vektoren die richtige Form haben (flachen sie bei Bedarf ab)
54
  vec1 = vec1.flatten()
55
  vec2 = vec2.flatten()
56
 
 
58
  norm_a = np.linalg.norm(vec1)
59
  norm_b = np.linalg.norm(vec2)
60
 
61
+ # Division durch Null vermeiden
62
  if norm_a == 0 or norm_b == 0:
63
  return 0
64
 
65
  return dot_product / (norm_a * norm_b)
66
 
67
  def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
68
+ """Berechnet einen kombinierten Score unter Berücksichtigung der Ähnlichkeit zum Durchschnitt und zu einzelnen Zutaten"""
69
  results = []
70
 
71
  for name, emb in embedding_list:
72
+ # Ähnlichkeit zum Durchschnittsvektor
73
  avg_similarity = get_cosine_similarity(query_vector, emb)
74
 
75
+ # Durchschnittliche Ähnlichkeit zu einzelnen Zutaten
76
  individual_similarities = [get_cosine_similarity(good_emb, emb)
77
  for _, good_emb in all_good_embeddings]
78
  avg_individual_similarity = sum(individual_similarities) / len(individual_similarities)
79
 
80
+ # Kombinierter Score (gewichteter Durchschnitt)
81
  combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
82
 
83
  results.append((name, emb, combined_score))
84
 
85
+ # Sortiere nach kombiniertem Score (absteigend)
86
  results.sort(key=lambda x: x[2], reverse=True)
87
  return results
88
 
89
  def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
90
  """
91
+ Findet die besten Zutaten basierend auf RecipeBERT Embeddings.
92
+
93
+ Args:
94
+ required_ingredients (list): Benötigte Zutaten, die verwendet werden müssen
95
+ available_ingredients (list): Verfügbare Zutaten zur Auswahl
96
+ max_ingredients (int): Maximale Anzahl von Zutaten für das Rezept
97
+ avg_weight (float): Gewicht für den Durchschnittsvektor
98
+
99
+ Returns:
100
+ list: Die optimale Kombination von Zutaten
101
  """
102
+ # Stelle sicher, dass keine Duplikate in den Listen sind
103
  required_ingredients = list(set(required_ingredients))
104
  available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
105
 
106
+ # Sonderfall: Wenn keine benötigten Zutaten vorhanden sind, wähle zufällig eine aus den verfügbaren Zutaten
107
  if not required_ingredients and available_ingredients:
108
  random_ingredient = random.choice(available_ingredients)
109
  required_ingredients = [random_ingredient]
110
  available_ingredients = [i for i in available_ingredients if i != random_ingredient]
111
+ # print(f"Keine benötigten Zutaten angegeben. Zufällig ausgewählt: {random_ingredient}")
112
 
113
+ # Wenn immer noch keine Zutaten vorhanden oder bereits maximale Kapazität erreicht ist
114
  if not required_ingredients or len(required_ingredients) >= max_ingredients:
115
  return required_ingredients[:max_ingredients]
116
 
117
+ # Wenn keine zusätzlichen Zutaten verfügbar sind
118
  if not available_ingredients:
119
  return required_ingredients
120
 
121
+ # Berechne Embeddings für alle Zutaten
122
  embed_required = [(e, get_embedding(e)) for e in required_ingredients]
123
  embed_available = [(e, get_embedding(e)) for e in available_ingredients]
124
 
125
+ # Anzahl der hinzuzufügenden Zutaten
126
  num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))
127
 
128
+ # Kopiere benötigte Zutaten in die endgültige Liste
129
  final_ingredients = embed_required.copy()
130
 
131
+ # Füge die besten Zutaten hinzu
132
  for _ in range(num_to_add):
133
+ # Berechne den Durchschnittsvektor der aktuellen Kombination
134
  avg = average_embedding(final_ingredients)
135
 
136
+ # Berechne kombinierte Scores für alle Kandidaten
137
  candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
138
 
139
+ # Wenn keine Kandidaten mehr übrig sind, breche ab
140
  if not candidates:
141
  break
142
 
143
+ # Wähle die beste Zutat
144
  best_name, best_embedding, _ = candidates[0]
145
 
146
+ # Füge die beste Zutat zur endgültigen Liste hinzu
147
  final_ingredients.append((best_name, best_embedding))
148
 
149
+ # Entferne die Zutat aus den verfügbaren Zutaten
150
  embed_available = [item for item in embed_available if item[0] != best_name]
151
 
152
+ # Extrahiere nur die Zutatennamen
153
  return [name for name, _ in final_ingredients]
154
 
155
  def skip_special_tokens(text, special_tokens):
156
+ """Entfernt spezielle Tokens aus dem Text"""
157
  for token in special_tokens:
158
  text = text.replace(token, "")
159
  return text
160
 
161
  def target_postprocessing(texts, special_tokens):
162
+ """Post-processed generierten Text"""
163
  if not isinstance(texts, list):
164
  texts = [texts]
165
 
 
175
  return new_texts
176
 
177
  def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
178
+ """
179
+ Validiert, ob das Rezept ungefähr die erwarteten Zutaten enthält.
180
+ """
181
  recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
182
  expected_count = len(expected_ingredients)
183
  return abs(recipe_count - expected_count) == tolerance
184
 
185
  def generate_recipe_with_t5(ingredients_list, max_retries=5):
186
+ """Generiert ein Rezept mit dem T5 Rezeptgenerierungsmodell mit Validierung."""
187
  original_ingredients = ingredients_list.copy()
188
 
189
  for attempt in range(max_retries):
190
  try:
191
+ # Für Wiederholungsversuche nach dem ersten Versuch, mische die Zutaten
192
  if attempt > 0:
193
  current_ingredients = original_ingredients.copy()
194
  random.shuffle(current_ingredients)
195
  else:
196
  current_ingredients = ingredients_list
197
 
198
+ # Formatiere Zutaten als kommaseparierten String
199
  ingredients_string = ", ".join(current_ingredients)
200
  prefix = "items: "
201
 
202
+ # Generationseinstellungen
203
  generation_kwargs = {
204
  "max_length": 512,
205
  "min_length": 64,
 
207
  "top_k": 60,
208
  "top_p": 0.95
209
  }
210
+ # print(f"Versuch {attempt + 1}: {prefix + ingredients_string}")
211
 
212
+ # Tokenisiere Eingabe
213
  inputs = t5_tokenizer(
214
  prefix + ingredients_string,
215
  max_length=256,
 
218
  return_tensors="jax"
219
  )
220
 
221
+ # Generiere Text
222
  output_ids = t5_model.generate(
223
  input_ids=inputs.input_ids,
224
  attention_mask=inputs.attention_mask,
225
  **generation_kwargs
226
  )
227
 
228
+ # Dekodieren und Nachbearbeiten
229
  generated = output_ids.sequences
230
  generated_text = target_postprocessing(
231
  t5_tokenizer.batch_decode(generated, skip_special_tokens=False),
232
  special_tokens
233
  )[0]
234
 
235
+ # Abschnitte parsen
236
  recipe = {}
237
  sections = generated_text.split("\n")
238
  for section in sections:
 
246
  directions_text = section.replace("directions:", "").strip()
247
  recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()]
248
 
249
+ # Wenn der Titel fehlt, erstelle einen
250
  if "title" not in recipe:
251
+ recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"
252
 
253
+ # Stelle sicher, dass alle Abschnitte existieren
254
  if "ingredients" not in recipe:
255
  recipe["ingredients"] = current_ingredients
256
  if "directions" not in recipe:
257
+ recipe["directions"] = ["Keine Anweisungen generiert"]
258
 
259
+ # Validiere das Rezept
260
  if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
261
+ # print(f"Erfolg bei Versuch {attempt + 1}: Rezept hat die richtige Anzahl von Zutaten")
262
  return recipe
263
  else:
264
+ # print(f"Versuch {attempt + 1} fehlgeschlagen: Erwartet {len(original_ingredients)} Zutaten, erhalten {len(recipe['ingredients'])}")
265
  if attempt == max_retries - 1:
266
+ # print("Maximale Wiederholungsversuche erreicht, letztes generiertes Rezept wird zurückgegeben")
267
  return recipe
268
 
269
  except Exception as e:
270
+ # print(f"Fehler bei der Rezeptgenerierung Versuch {attempt + 1}: {str(e)}")
271
  if attempt == max_retries - 1:
272
  return {
273
+ "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
274
  "ingredients": original_ingredients,
275
+ "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
276
  }
277
 
278
+ # Fallback (sollte nicht erreicht werden)
279
  return {
280
+ "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
281
  "ingredients": original_ingredients,
282
+ "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
283
  }
284
 
285
  def flutter_api_generate_recipe(ingredients_data):
286
  """
287
+ Flutter-freundliche API-Funktion, die JSON-Eingaben verarbeitet
288
+ und strukturierte JSON-Ausgaben zurückgibt, die deiner ursprünglichen Flask-API entsprechen.
289
  """
290
  try:
291
+ # Eingabe parsen - behandle sowohl String- als auch Dict-Formate
292
  if isinstance(ingredients_data, str):
293
  data = json.loads(ingredients_data)
294
  else:
295
+ data = ingredients_data # Ist bereits ein Dict (z.B. von Gradio UI)
296
+
297
+ # Parameter extrahieren (wie deine ursprüngliche Flask-API)
298
  required_ingredients = data.get('required_ingredients', [])
299
  available_ingredients = data.get('available_ingredients', [])
300
+
301
+ # Abwärtskompatibilität: Wenn nur 'ingredients' angegeben ist, behandle es als required_ingredients
302
  if data.get('ingredients') and not required_ingredients:
303
  required_ingredients = data.get('ingredients', [])
304
+
305
  max_ingredients = data.get('max_ingredients', 7)
306
  max_retries = data.get('max_retries', 5)
307
+
308
  if not required_ingredients and not available_ingredients:
309
+ return json.dumps({"error": "Keine Zutaten angegeben"})
310
+
311
+ # Optimale Zutaten finden
312
  optimized_ingredients = find_best_ingredients(
313
  required_ingredients,
314
+ available_ingredients,
315
  max_ingredients
316
  )
317
+
318
+ # Rezept mit optimierten Zutaten generieren
319
  recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
320
+
321
+ # Für die Flutter-App formatieren - strukturiertes Format
322
  result = {
323
  'title': recipe['title'],
324
+ 'ingredients': recipe['ingredients'],
325
  'directions': recipe['directions'],
326
  'used_ingredients': optimized_ingredients
327
  }
328
+
329
  return json.dumps(result)
330
+
331
  except Exception as e:
332
+ return json.dumps({"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"})
333
 
334
+ def gradio_ui_generate_recipe(required_ingredients_text, available_ingredients_text, max_ingredients_val, max_retries_val):
335
+ """Gradio UI Funktion für die Web-Oberfläche"""
336
  try:
337
+ # Text-Eingaben parsen
338
  required_ingredients = [ing.strip() for ing in required_ingredients_text.split(',') if ing.strip()]
339
  available_ingredients = [ing.strip() for ing in available_ingredients_text.split(',') if ing.strip()]
340
+
341
+ # Erstelle ein Dictionary im Format der Flutter API
342
+ data_for_flutter_api = {
343
  'required_ingredients': required_ingredients,
344
  'available_ingredients': available_ingredients,
345
+ 'max_ingredients': max_ingredients_val, # Verwende den Parameter aus dem Slider
346
+ 'max_retries': max_retries_val # Verwende den Parameter aus dem Slider
347
  }
348
+
349
+ # --- WICHTIG: Wandle das Python-Dictionary in einen JSON-String um,
350
+ # da flutter_api_generate_recipe intern dieses Format erwartet,
351
+ # wenn es über einen Gradio-Input aufgerufen wird, der einen String liefert.
352
+ # Dies simuliert den JSON-String, den die Flutter-App senden würde.
353
+ data_json_string = json.dumps(data_for_flutter_api)
354
+ # ---------------------------------------------------------------------
355
+
356
+ # Verwende dieselbe Funktion wie die Flutter API
357
+ result_json = flutter_api_generate_recipe(data_json_string) # <-- Hier die Änderung
358
+
359
  result = json.loads(result_json)
360
+
361
  if 'error' in result:
362
  return result['error'], "", "", ""
363
+
364
+ # Für die Gradio-Anzeige formatieren
365
  ingredients_list = '\n'.join([f"• {ing}" for ing in result['ingredients']])
366
  directions_list = '\n'.join([f"{i+1}. {dir}" for i, dir in enumerate(result['directions'])])
367
  used_ingredients = ', '.join(result['used_ingredients'])
368
+
369
  return (
370
  result['title'],
371
+ ingredients_list,
372
  directions_list,
373
  used_ingredients
374
  )
375
+
376
  except Exception as e:
377
+ # Fehlermeldung für die Gradio UI
378
+ return f"Fehler: {str(e)}", "", "", ""
379
+
380
+ # Erstelle die Gradio Oberfläche
381
+ with gr.Blocks(title="AI Rezept Generator") as demo:
382
+ gr.Markdown("# 🍳 AI Rezept Generator")
383
+ gr.Markdown("Generiere Rezepte mit KI und intelligenter Zutat-Kombination!")
384
+
385
+ with gr.Tab("Web-Oberfläche"):
386
  with gr.Row():
387
  with gr.Column():
388
  required_ing = gr.Textbox(
389
+ label="Benötigte Zutaten (kommasepariert)",
390
+ placeholder="Hähnchen, Reis, Zwiebel",
391
  lines=2
392
  )
393
  available_ing = gr.Textbox(
394
+ label="Verfügbare Zutaten (kommasepariert, optional)",
395
+ placeholder="Knoblauch, Tomate, Pfeffer, Kräuter",
396
  lines=2
397
  )
398
+ # Die Parameter-Namen für Slider müssen mit den Argumenten der Gradio UI Funktion übereinstimmen
399
+ max_ing = gr.Slider(3, 10, value=7, step=1, label="Maximale Zutaten")
400
+ max_retries = gr.Slider(1, 10, value=5, step=1, label="Max. Wiederholungsversuche")
401
+
402
+ generate_btn = gr.Button("Rezept generieren", variant="primary")
403
+
404
  with gr.Column():
405
+ title_output = gr.Textbox(label="Rezepttitel", interactive=False)
406
+ ingredients_output = gr.Textbox(label="Zutaten", lines=8, interactive=False)
407
+ directions_output = gr.Textbox(label="Anweisungen", lines=10, interactive=False)
408
+ used_ingredients_output = gr.Textbox(label="Verwendete Zutaten", interactive=False)
409
+
410
+ generate_btn.click(
411
+ fn=gradio_ui_generate_recipe,
412
+ inputs=[required_ing, available_ing, max_ing, max_retries], # Hier die Slider-Komponenten übergeben
413
+ outputs=[title_output, ingredients_output, directions_output, used_ingredients_output]
414
+ )
415
+
416
+ with gr.Tab("API-Test"):
417
+ gr.Markdown("### Teste die Flutter API")
418
+ gr.Markdown("Dieser Tab verwendet dieselbe Funktion, die Flutter-Apps über die API aufrufen werden.")
419
+
420
  api_input = gr.Textbox(
421
+ label="JSON-Eingabe (Flutter API-Format)",
422
  placeholder='{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic"], "max_ingredients": 6}',
423
  lines=4
424
  )
425
+ api_output = gr.Textbox(label="JSON-Ausgabe", lines=15, interactive=False)
426
+ api_test_btn = gr.Button("API testen", variant="secondary")
427
+
428
  api_test_btn.click(
429
  fn=flutter_api_generate_recipe,
430
  inputs=[api_input],
431
  outputs=[api_output],
432
+ api_name="generate_recipe_for_flutter" # <-- Dies ist der von Flutter verwendete API-Name
433
  )
434
+
435
  gr.Examples(
436
  examples=[
437
  ['{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic", "tomato"], "max_ingredients": 6}'],
 
441
  )
442
 
443
  if __name__ == "__main__":
444
+ demo.launch()