File size: 10,754 Bytes
d71c8ab 1787e4b 288724d 1787e4b bcc13b5 1787e4b 288724d 1787e4b bcc13b5 1787e4b 288724d d71c8ab 1787e4b 288724d d71c8ab 288724d d71c8ab 288724d d71c8ab 288724d d71c8ab 288724d d71c8ab 288724d d71c8ab 288724d d71c8ab 288724d d71c8ab 288724d d71c8ab 288724d d71c8ab 288724d d71c8ab 288724d d71c8ab bcc13b5 d71c8ab bcc13b5 d71c8ab 288724d d71c8ab bcc13b5 d71c8ab bcc13b5 d71c8ab bcc13b5 d71c8ab bcc13b5 d71c8ab 88f38f2 d71c8ab 88f38f2 d71c8ab 88f38f2 1787e4b 288724d 1787e4b d71c8ab bcc13b5 d71c8ab 88f38f2 288724d bcc13b5 288724d bcc13b5 288724d f61a7ca 288724d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
import torch
import numpy as np
import random
import json
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# FastAPI-only module: intentionally no Gradio imports here.

# Load the RecipeBERT model, used to embed ingredient names for the semantic
# ingredient combination in find_best_ingredients. This runs at import time,
# so importing the module loads (and, if not cached, may download) the model.
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval()  # inference only: switch the model to evaluation mode

# Load the T5 recipe-generation model (Flax weights), also at import time.
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Special tokens that target_postprocessing strips from decoded T5 output.
special_tokens = t5_tokenizer.all_special_tokens
# Maps the model's structural markers to plain text: "<sep>" becomes the "--"
# list-item separator and "<section>" becomes the newline between the
# title / ingredients / directions sections parsed in generate_recipe_with_t5.
tokens_map = {
    "<sep>": "--",
    "<section>": "\n"
}

# Helper functions: embedding and similarity utilities, greedy ingredient
# selection, T5 output post-processing, and recipe generation.
def get_embedding(text):
    """Return the mean-pooled RecipeBERT embedding of *text*.

    The token vectors of the model's last hidden state are averaged, with
    padding positions excluded via the attention mask. The leading batch
    dimension is squeezed away when it has size 1 (single-string input).
    """
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        hidden = bert_model(**encoded).last_hidden_state
    # Broadcast the (batch, seq) mask over the hidden dimension.
    mask = encoded["attention_mask"].unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * mask).sum(dim=1)
    # Clamp guards against division by zero for an all-padding row.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).squeeze(0)
def average_embedding(embedding_list):
    """Return the element-wise mean of the embeddings in *embedding_list*.

    *embedding_list* holds ``(name, embedding)`` pairs; the names are ignored.
    """
    vectors = [vector for _name, vector in embedding_list]
    return torch.mean(torch.stack(vectors), dim=0)
def get_cosine_similarity(vec1, vec2):
    """Return the cosine similarity between two vectors.

    Accepts torch tensors (detached and moved to CPU), NumPy arrays, or any
    array-like such as a plain list. Inputs of any shape are flattened first.

    Returns:
        The similarity as a Python ``float``, or ``0.0`` when either vector
        has zero norm (the similarity is undefined in that case).
    """
    def _as_flat_array(vec):
        if torch.is_tensor(vec):
            # .cpu() makes this work for tensors on any device.
            vec = vec.detach().cpu().numpy()
        # np.asarray also accepts lists, which have no .flatten().
        return np.asarray(vec).ravel()

    a = _as_flat_array(vec1)
    b = _as_flat_array(vec2)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    # Return a plain float on both paths (previously int 0 vs numpy scalar).
    return float(np.dot(a, b) / (norm_a * norm_b))
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Score candidate embeddings against the current selection and rank them.

    Each candidate's score is
    ``avg_weight * cos(query_vector, candidate)
    + (1 - avg_weight) * mean(cos(selected, candidate) for selected in all_good_embeddings)``.

    Args:
        query_vector: reference vector (find_best_ingredients passes the
            average of the already-selected embeddings).
        embedding_list: ``(name, embedding)`` candidates to score.
        all_good_embeddings: ``(name, embedding)`` pairs already selected.
            It must be non-empty whenever there are candidates, because its
            length is a divisor.
        avg_weight: weight given to the similarity with ``query_vector``.

    Returns:
        List of ``(name, embedding, score)`` tuples, highest score first.
    """
    selected_vectors = [vec for _, vec in all_good_embeddings]
    weight_rest = 1 - avg_weight

    def _score(candidate):
        pairwise = [get_cosine_similarity(sel, candidate) for sel in selected_vectors]
        mean_pairwise = sum(pairwise) / len(pairwise)
        return avg_weight * get_cosine_similarity(query_vector, candidate) + weight_rest * mean_pairwise

    scored = [(name, emb, _score(emb)) for name, emb in embedding_list]
    return sorted(scored, key=lambda entry: entry[2], reverse=True)
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """Pick a semantically coherent set of at most ``max_ingredients`` ingredients.

    Required ingredients are always kept, in first-seen order with duplicates
    removed. If none are given, one available ingredient is chosen at random
    as the seed. Remaining slots are filled greedily: at each step the
    available ingredient with the highest combined score against the current
    selection (see ``get_combined_scores``) is added.

    Args:
        required_ingredients: ingredients that must appear in the result.
        available_ingredients: optional candidates for the remaining slots.
        max_ingredients: upper bound on the result length; negative values
            are treated as 0.
        avg_weight: forwarded to ``get_combined_scores``.

    Returns:
        List of ingredient names: the required ones first, then the added
        ones in the order they were chosen.
    """
    limit = max(max_ingredients, 0)
    # dict.fromkeys de-duplicates while keeping the caller's order. A set would
    # make the order, and therefore which items the truncation below keeps,
    # depend on string hashing.
    required = list(dict.fromkeys(required_ingredients))
    required_set = set(required)
    available = list(dict.fromkeys(i for i in available_ingredients if i not in required_set))

    if not required and available:
        seed = random.choice(available)
        required = [seed]
        available = [i for i in available if i != seed]

    if not required or len(required) >= limit:
        return required[:limit]
    if not available:
        return required

    selected = [(name, get_embedding(name)) for name in required]
    candidates = [(name, get_embedding(name)) for name in available]
    num_to_add = min(limit - len(required), len(candidates))
    for _ in range(num_to_add):
        centroid = average_embedding(selected)
        ranked = get_combined_scores(centroid, candidates, selected, avg_weight)
        if not ranked:
            break
        best_name, best_embedding, _ = ranked[0]
        selected.append((best_name, best_embedding))
        candidates = [item for item in candidates if item[0] != best_name]
    return [name for name, _ in selected]
def skip_special_tokens(text, special_tokens):
    """Return *text* with every occurrence of each special token removed.

    Tokens are removed one after another in the given order, so the result
    can depend on that order when tokens overlap.
    """
    result = text
    for token_str in special_tokens:
        result = result.replace(token_str, "")
    return result
def target_postprocessing(texts, special_tokens):
    """Clean decoded T5 output.

    Strips *special_tokens* from each text, then replaces the structural
    markers in ``tokens_map`` with their plain-text equivalents.

    Args:
        texts: a single string or a list of strings.
        special_tokens: tokens to remove before the marker replacement.

    Returns:
        A list of cleaned strings, even for a single-string input.
    """
    batch = texts if isinstance(texts, list) else [texts]

    def _restore_markup(fragment):
        for marker, replacement in tokens_map.items():
            fragment = fragment.replace(marker, replacement)
        return fragment

    return [_restore_markup(skip_special_tokens(item, special_tokens)) for item in batch]
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Check that a generated recipe lists about the expected number of ingredients.

    Empty or whitespace-only entries in ``recipe_ingredients`` are not counted.

    Args:
        recipe_ingredients: ingredient lines parsed from the generated recipe.
        expected_ingredients: ingredients the recipe was requested for.
        tolerance: maximum allowed difference between the two counts.

    Returns:
        True if the counts differ by at most ``tolerance``.
    """
    recipe_count = sum(1 for ing in recipe_ingredients if ing and ing.strip())
    # "<=" rather than "==": with tolerance > 0 an exact match must still pass.
    return abs(recipe_count - len(expected_ingredients)) <= tolerance
def _fallback_recipe(ingredients):
    """Return the placeholder recipe used when no model output is available.

    The title names the first ingredient (or "Zutaten" if there is none), and
    the directions carry a fixed error message.
    """
    return {
        "title": f"Rezept mit {ingredients[0] if ingredients else 'Zutaten'}",
        "ingredients": ingredients,
        "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
    }


def _parse_generated_recipe(generated_text, current_ingredients):
    """Parse post-processed T5 output into a recipe dict.

    Lines starting with "title:", "ingredients:" or "directions:" fill the
    matching key, and list items are separated by "--". Keys the output did
    not provide get defaults derived from *current_ingredients*.
    """
    recipe = {}
    for raw_section in generated_text.split("\n"):
        section = raw_section.strip()
        if section.startswith("title:"):
            # Slice off the label rather than str.replace, which would also
            # delete later occurrences of the label inside the text.
            recipe["title"] = section[len("title:"):].strip().capitalize()
        elif section.startswith("ingredients:"):
            items = section[len("ingredients:"):].strip().split("--")
            recipe["ingredients"] = [item.strip().capitalize() for item in items if item.strip()]
        elif section.startswith("directions:"):
            steps = section[len("directions:"):].strip().split("--")
            recipe["directions"] = [step.strip().capitalize() for step in steps if step.strip()]
    if "title" not in recipe:
        recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"
    if "ingredients" not in recipe:
        recipe["ingredients"] = current_ingredients
    if "directions" not in recipe:
        recipe["directions"] = ["Keine Anweisungen generiert"]
    return recipe


def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generate a recipe for *ingredients_list* with the T5 model.

    Up to ``max_retries`` attempts are made. The first attempt uses the
    ingredients in the given order. Later attempts shuffle them and pass a
    fresh sampling key. An attempt is accepted as soon as the recipe lists as
    many non-blank ingredients as were requested (see
    ``validate_recipe_ingredients``). If no attempt passes, the last
    attempt's recipe is returned anyway.

    Args:
        ingredients_list: ingredient names (any sequence) to cook with.
        max_retries: number of generation attempts.

    Returns:
        Dict with "title" (str), "ingredients" (list) and "directions"
        (list). If every attempt raises, or ``max_retries`` < 1, a placeholder
        recipe built from the input ingredients is returned instead.
    """
    # list() also accepts tuples and decouples the result from the caller's list.
    original_ingredients = list(ingredients_list)
    prefix = "items: "
    generation_kwargs = {
        "max_length": 512, "min_length": 64, "do_sample": True,
        "top_k": 60, "top_p": 0.95
    }
    for attempt in range(max_retries):
        try:
            current_ingredients = original_ingredients.copy()
            if attempt > 0:
                random.shuffle(current_ingredients)
            inputs = t5_tokenizer(
                prefix + ", ".join(current_ingredients), max_length=256, padding="max_length",
                truncation=True, return_tensors="jax"
            )
            sampling_kwargs = {}
            if attempt > 0:
                import jax  # local import: only needed to build the PRNG key
                # Flax `generate` falls back to PRNGKey(0) when no key is given,
                # so a retry with an unchanged ingredient order would otherwise
                # sample exactly the same text again.
                sampling_kwargs["prng_key"] = jax.random.PRNGKey(random.randrange(2**31))
            output_ids = t5_model.generate(
                input_ids=inputs.input_ids, attention_mask=inputs.attention_mask,
                **generation_kwargs, **sampling_kwargs
            )
            decoded = t5_tokenizer.batch_decode(output_ids.sequences, skip_special_tokens=False)
            generated_text = target_postprocessing(decoded, special_tokens)[0]
            recipe = _parse_generated_recipe(generated_text, current_ingredients)
            if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
                return recipe
            if attempt == max_retries - 1:
                # Out of retries: return the model output rather than a placeholder.
                return recipe
        except Exception as exc:  # broad on purpose: the API must always get a recipe back
            print(f"WARNING: recipe generation attempt {attempt + 1}/{max_retries} failed: {exc}")
    return _fallback_recipe(original_ingredients)
def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
    """Core logic for handling one recipe-generation request.

    Ingredient names (expected to be strings) are stripped of surrounding
    whitespace, and blank entries are dropped before selection. For example,
    ``["", " "]`` counts as "no ingredients" instead of reaching the model.

    Args:
        required_ingredients: ingredients that must be used (None allowed).
        available_ingredients: optional extra ingredients (None allowed).
        max_ingredients: upper bound passed to ``find_best_ingredients``.
        max_retries: attempts passed to ``generate_recipe_with_t5``.

    Returns:
        On success, a dict with "title", "ingredients", "directions" and
        "used_ingredients" (the selected ingredient names). On failure, a
        dict with a single "error" message.
    """
    def _clean(names):
        return [name.strip() for name in (names or []) if name and name.strip()]

    required = _clean(required_ingredients)
    available = _clean(available_ingredients)
    if not required and not available:
        return {"error": "Keine Zutaten angegeben"}
    try:
        optimized_ingredients = find_best_ingredients(required, available, max_ingredients)
        recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
        return {
            'title': recipe['title'],
            'ingredients': recipe['ingredients'],
            'directions': recipe['directions'],
            'used_ingredients': optimized_ingredients
        }
    except Exception as e:
        return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
# --- FastAPI implementation ---
app = FastAPI()


class RecipeRequest(BaseModel):
    """JSON body accepted by ``POST /generate_recipe``."""

    required_ingredients: list[str] = []
    available_ingredients: list[str] = []
    max_ingredients: int = 7
    max_retries: int = 5
    # Backward compatibility: older clients may send the ingredient list as a
    # top-level "ingredients" field. It must be declared, because pydantic
    # ignores undeclared keys by default; without the declaration the
    # endpoint's legacy fallback could never see the value.
    ingredients: list[str] = []
@app.post("/generate_recipe")  # Simple endpoint called by the Flutter app
async def generate_recipe_api(request_data: RecipeRequest):
    """Standard REST endpoint for the Flutter app.

    Takes the JSON body described by ``RecipeRequest`` and returns the result
    of ``process_recipe_request_logic`` as JSON. The result is either a recipe
    or an ``{"error": ...}`` dict; both are sent with JSONResponse's default
    200 status.
    """
    from fastapi.concurrency import run_in_threadpool

    required = request_data.required_ingredients
    # Backward compatibility: use a top-level "ingredients" list when the
    # request model declares one and no required ingredients were sent.
    legacy = getattr(request_data, "ingredients", None)
    if not required and legacy:
        required = legacy
    # Model inference is blocking CPU work. Awaiting it in the thread pool keeps
    # the event loop free to serve other requests in the meantime.
    # NOTE(review): concurrent requests can now run inference in parallel;
    # confirm the Space has enough memory for that.
    result_dict = await run_in_threadpool(
        process_recipe_request_logic,
        required,
        request_data.available_ingredients,
        request_data.max_ingredients,
        request_data.max_retries,
    )
    return JSONResponse(content=result_dict)
# This setup has no Gradio UI, only the FastAPI API,
# which should make the Space start more reliably.
print("INFO: FastAPI application script finished execution and defined 'app' variable.")
# There is no `if __name__ == "__main__":` block. Per the original author,
# Hugging Face Spaces would ignore it, because the Space starts the Uvicorn
# server directly, which looks up the 'app' variable.
|