import streamlit as st
import pandas as pd
import chromadb
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from PIL import Image
from io import BytesIO
import requests
from huggingface_hub import login

# --- 1. Load Recipes Dataset ---
@st.cache_data
def load_recipes():
    try:
        recipes_df = pd.read_csv("recipes.csv")
        recipes_df = recipes_df.rename(columns={"recipe_name": "title", "directions": "instructions"})
        recipes_df = recipes_df[['title', 'ingredients', 'instructions', 'img_src']]
        recipes_df.fillna("", inplace=True)
        # Normalize ingredients for embedding: lowercase and strip punctuation
        recipes_df["ingredients"] = recipes_df["ingredients"].str.lower().str.replace(r'[^\w\s]', '', regex=True)
        recipes_df["combined_text"] = recipes_df["title"] + " " + recipes_df["ingredients"]
        return recipes_df
    except Exception as e:
        st.error(f"⚠ Error loading recipes: {e}")
        return pd.DataFrame()

recipes_df = load_recipes()

# --- 2. Load SentenceTransformer Model ---
@st.cache_resource
def load_embedding_model():
    return SentenceTransformer("all-mpnet-base-v2")

embedding_model = load_embedding_model()

# --- 3. Initialize ChromaDB ---
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="recipe_collection")

# --- 4. Generate & Store Embeddings ---
def get_sentence_transformer_embeddings(text):
    return embedding_model.encode(text).tolist()

try:
    existing_data = collection.get()
    existing_ids = set(existing_data["ids"]) if existing_data and "ids" in existing_data else set()
except Exception as e:
    st.error(f"⚠ ChromaDB Error: {e}")
    existing_ids = set()

# Embed and store only recipes that are not already in the collection
for index, row in recipes_df.iterrows():
    recipe_id = str(index)
    if recipe_id in existing_ids:
        continue
    embedding = get_sentence_transformer_embeddings(row["combined_text"])
    if embedding:
        collection.add(embeddings=[embedding], documents=[row["combined_text"]], ids=[recipe_id])

# --- 5. Retrieve Similar Recipes ---
def retrieve_recipes(query, top_k=3):
    query_embedding = get_sentence_transformer_embeddings(query)
    results = collection.query(query_embeddings=[query_embedding], n_results=top_k)
    if results and "ids" in results and results["ids"] and results["ids"][0]:
        recipe_indices = [int(doc_id) for doc_id in results["ids"][0] if doc_id.isdigit()]
        return recipes_df.iloc[recipe_indices] if recipe_indices else None
    return None

# --- Hugging Face Authentication ---
hf_token = st.secrets.get("key")  # .get() avoids a KeyError so the missing-token check below can actually run
if hf_token is None:
    raise ValueError("Hugging Face token is missing. Add it as a secret in your Space.")
login(token=hf_token)

# --- 6. Load Mistral-7B-Instruct ---
@st.cache_resource
def load_mistral_model():
    model_name = "mistralai/Mistral-7B-Instruct-v0.3"
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)
    return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150)

mistral_model = load_mistral_model()

# --- 7. Answer Question Using Mistral ---
def answer_question(query, context=""):
    greetings = ["hi", "hello", "hey", "greetings", "how are you", "what's up"]
    query_cleaned = query.lower().strip()

    # Handle greetings
    if query_cleaned in greetings:
        return "Hello! I'm here to assist with recipes and food-related questions. 🍽️ What would you like to know?"

    # Retrieve relevant recipe
    related_recipes = retrieve_recipes(query, top_k=1)
    if related_recipes is None or related_recipes.empty:
        return "I specialize in recipes! 🍽️ Feel free to ask me about ingredients, cooking methods, or meal ideas. 😊"

    # If found, use its instructions as context
    context = related_recipes.iloc[0]['instructions']
    prompt = f"Context: {context}\n\nQuestion: {query}\nAnswer:"
    # return_full_text=False keeps the model from echoing the prompt back in its answer
    response = mistral_model(prompt, return_full_text=False)
    if isinstance(response, list) and response:
        return response[0].get("generated_text", "I'm not sure, but I can help with recipes! 😊").strip()
    return "I'm not sure, but I can help with recipes! 😊"
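# --- Optional: chat-template prompting (sketch) ---
# The plain "Context / Question / Answer" prompt above works, but Mistral-7B-Instruct is trained
# on a chat format, so building the prompt with the tokenizer's chat template often yields cleaner
# completions. This helper is only a sketch and is not wired into answer_question(); the name
# build_chat_prompt is illustrative, not part of the app above.
def build_chat_prompt(tokenizer, query, context=""):
    messages = [{"role": "user", "content": f"Context: {context}\n\nQuestion: {query}"}]
    # apply_chat_template wraps the message in the [INST] ... [/INST] markers Mistral-Instruct expects
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)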
😊" # If found, use its instructions as context context = related_recipes.iloc[0]['instructions'] prompt = f"Context: {context}\n\nQuestion: {query}\nAnswer:" response = mistral_model(prompt) if isinstance(response, list) and response: return response[0].get("generated_text", "I'm not sure, but I can help with recipes! 😊").strip() return "I'm not sure, but I can help with recipes! 😊" # --- 8. Classify Query Type --- @st.cache_resource def load_classifier(): return pipeline("zero-shot-classification", model="facebook/bart-large-mnli", use_auth_token=True) classifier = load_classifier() def classify_query(query): recipe_keywords = ["make", "cook", "bake", "recipe", "prepare"] if any(keyword in query.lower() for keyword in recipe_keywords): return "Recipe Search" labels = ["Q&A", "Recipe Search"] result = classifier(query, candidate_labels=labels, multi_label=False) return result.get("labels", ["Q&A"])[0] # --- 9. Display Image --- def display_image(image_url, recipe_name): try: if not isinstance(image_url, str) or not image_url.startswith("http"): raise ValueError("Invalid or missing image URL") response = requests.get(image_url, timeout=5) response.raise_for_status() image = Image.open(BytesIO(response.content)) st.image(image, caption=recipe_name, use_container_width=True) except requests.exceptions.RequestException as e: st.warning(f"⚠ Image fetch error: {e}") placeholder_url = "https://via.placeholder.com/300?text=No+Image" st.image(placeholder_url, caption=recipe_name, use_container_width=True) # --- 10. Streamlit UI --- st.title("🍽️ AI Recipe & Q&A Assistant (Powered by Mistral-7B)") user_query = st.text_input("Enter your question or recipe search query:", "", key="main_query_input") if "retrieved_recipes" not in st.session_state: st.session_state["retrieved_recipes"] = None if st.button("Ask AI"): if user_query: # Handle greetings separately greeting_response = answer_question(user_query) if greeting_response.startswith("Hello!"): st.subheader("🤖 AI Answer:") st.write(greeting_response) else: # Classify query intent = classify_query(user_query) if intent == "Q&A": st.subheader("🤖 AI Answer:") response = answer_question(user_query) st.write(response) elif intent == "Recipe Search": retrieved_recipes = retrieve_recipes(user_query) if retrieved_recipes is not None and not retrieved_recipes.empty: st.session_state["retrieved_recipes"] = retrieved_recipes st.subheader("🍴 Found Recipes:") for index, recipe in retrieved_recipes.iterrows(): st.markdown(f"### {recipe['title']}") st.write(f"**Ingredients:** {recipe['ingredients']}") st.write(f"**Instructions:** {recipe['instructions']}") display_image(recipe.get('img_src', ''), recipe['title']) else: st.warning("⚠️ No relevant recipes found.") else: st.warning("❌ Unable to classify the query.")