import os
from flask import Flask, render_template, request, jsonify, session
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq
import numpy as np
import logging
from transformers import AutoTokenizer, AutoModel  # Keep these
import torch
import torch.nn.functional as F

# Configure logging
logging.basicConfig(level=logging.INFO)

# --- Flask App Setup (MUST come before routes or app-dependent code) ---
app = Flask(__name__)
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'a_default_secret_key_please_change')
# --- Initialize Models ---
device = torch.device("cpu")  # Force CPU for free tier
if torch.cuda.is_available():
    device = torch.device("cuda")  # Should not happen on free tier
logging.info(f"Using device: {device}")

tokenizer = None
model = None
client = None
try:
    # Load tokenizer and model from the Hugging Face Hub using transformers
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
    # Re-add from_tf=True here for AutoModel.from_pretrained
    model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2', from_tf=True).to(device)
    logging.info("Tokenizer and AutoModel loaded successfully with from_tf=True.")
except Exception as e:
    logging.error(f"Error loading Transformer models: {e}")
    tokenizer = None
    model = None
# Initialize the Groq client
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    logging.error("GROQ_API_KEY environment variable not set.")
    client = None
else:
    client = Groq(api_key=groq_api_key)
    logging.info("Groq client initialized.")
# --- Helper function for Mean Pooling ---
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float().to(token_embeddings.device)
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
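# Illustrative sketch (not called by the app): mean pooling averages token
# embeddings while the attention mask zeroes out padding positions, so padded
# tokens do not dilute the sentence vector. Shapes below are assumed for
# all-MiniLM-L6-v2 (hidden size 384).
def _mean_pooling_example():
    dummy_output = (torch.randn(1, 4, 384),)          # (batch, seq_len, hidden)
    dummy_mask = torch.tensor([[1, 1, 1, 0]])          # last position is padding
    pooled = mean_pooling(dummy_output, dummy_mask)    # -> tensor of shape (1, 384)
    return pooled.shape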
# --- Function to get embedding ---
def get_embedding(text):
    if tokenizer is None or model is None:
        logging.error("Embedding models not loaded. Cannot generate embedding.")
        return None
    try:
        encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
        sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
        sentence_embedding = F.normalize(sentence_embedding, p=2, dim=1)
        return sentence_embedding.cpu().numpy()[0]
    except Exception as e:
        logging.error(f"Error generating embedding: {e}")
        return None
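# Illustrative sketch (assumed helper, not called by the app): because the
# embeddings above are L2-normalized, a plain dot product between two of them
# equals their cosine similarity.
def _embedding_similarity_example(text_a, text_b):
    emb_a = get_embedding(text_a)
    emb_b = get_embedding(text_b)
    if emb_a is None or emb_b is None:
        return None
    return float(np.dot(emb_a, emb_b))  # in [-1, 1]; higher means more similar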
# --- Memory Management Functions (rely on get_embedding) ---
# ... (add_to_memory, retrieve_relevant_memory, construct_prompt, trim_memory, summarize_memory - these remain the same, calling get_embedding) ...
def add_to_memory(mem_list, role, content):
    if not content or not content.strip():
        logging.warning(f"Attempted to add empty content to memory for role: {role}")
        return mem_list
    embedding = get_embedding(content)
    if embedding is not None:
        mem_list.append({"role": role, "content": content, "embedding": embedding.tolist()})
    else:
        logging.warning(f"Failed to get embedding for message: {content[:50]}...")
        mem_list.append({"role": role, "content": content, "embedding": None})
    return mem_list
def retrieve_relevant_memory(mem_list, user_input, top_k=5):
    if not mem_list or tokenizer is None or model is None:
        return []
    user_embedding = get_embedding(user_input)
    if user_embedding is None:
        logging.error("Failed to get user input embedding for retrieval.")
        return []
    valid_memory_items = []
    memory_embeddings_np = []
    for m in mem_list:
        if m.get("embedding") is not None and isinstance(m["embedding"], list):
            try:
                np_embedding = np.array(m["embedding"])
                if np_embedding.shape == (model.config.hidden_size,):  # Use model config for dimension
                    valid_memory_items.append(m)
                    memory_embeddings_np.append(np_embedding)
                else:
                    logging.warning(f"Embedding dimension mismatch for memory entry: {m['content'][:50]}...")
            except Exception as conv_e:
                logging.warning(f"Could not convert embedding for memory entry: {m['content'][:50]}... Error: {conv_e}")
    if not valid_memory_items:
        return []
    similarities = cosine_similarity([user_embedding], np.array(memory_embeddings_np))[0]
    relevant_messages_sorted = sorted(zip(similarities, valid_memory_items), key=lambda x: x[0], reverse=True)
    return [m[1] for m in relevant_messages_sorted[:top_k]]
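# Illustrative sketch (assumed data, not called by the app): add_to_memory()
# attaches embeddings, and retrieve_relevant_memory() then ranks those entries
# by cosine similarity against the new user input.
def _retrieval_example():
    memory = add_to_memory([], "user", "My favourite colour is teal.")
    memory = add_to_memory(memory, "assistant", "Noted, teal it is.")
    # Expected to surface the 'teal' message first (assuming models loaded).
    return retrieve_relevant_memory(memory, "What colour do I like?", top_k=1)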
def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000):
    relevant_memory_items = retrieve_relevant_memory(mem_list, user_input)
    relevant_content_set = {m["content"] for m in relevant_memory_items if "content" in m}
    messages_for_api = []
    messages_for_api.append({"role": "system", "content": "You are a helpful and friendly AI assistant."})
    current_prompt_tokens = len(messages_for_api[0]["content"].split())
    context_messages = []
    for msg in mem_list:
        if "content" in msg and msg["content"] in relevant_content_set and msg["role"] in ["user", "assistant", "system"]:
            msg_text = f'{msg["role"]}: {msg["content"]}\n'
            msg_tokens = len(msg_text.split())
            if current_prompt_tokens + msg_tokens > max_tokens_in_prompt:
                break
            context_messages.append({"role": msg["role"], "content": msg["content"]})
            current_prompt_tokens += msg_tokens
    messages_for_api.extend(context_messages)
    user_input_tokens = len(user_input.split())
    if current_prompt_tokens + user_input_tokens > max_tokens_in_prompt and len(messages_for_api) > 1:
        logging.warning("User input exceeds max_tokens_in_prompt with existing context. Context may be truncated.")
    messages_for_api.append({"role": "user", "content": user_input})
    return messages_for_api
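# Illustrative shape of the list returned above (contents assumed):
#   [{"role": "system", "content": "You are a helpful and friendly AI assistant."},
#    {"role": "user", "content": "<relevant earlier user message>"},
#    {"role": "assistant", "content": "<relevant earlier reply>"},
#    {"role": "user", "content": "<current user_input>"}]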
def trim_memory(mem_list, max_size=50):
    while len(mem_list) > max_size:
        mem_list.pop(0)
    return mem_list
def summarize_memory(mem_list):
    if not mem_list or client is None:
        logging.warning("Memory is empty or Groq client not initialized. Cannot summarize.")
        return []
    long_term_memory = " ".join([m["content"] for m in mem_list if "content" in m])
    if not long_term_memory.strip():
        logging.warning("Memory content is empty. Cannot summarize.")
        return []
    try:
        summary_completion = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=[
                {"role": "system", "content": "Summarize the following conversation for key points. Keep it concise."},
                {"role": "user", "content": long_term_memory},
            ],
            max_tokens=500,
        )
        summary_text = summary_completion.choices[0].message.content
        logging.info("Memory summarized.")
        return [{"role": "system", "content": f"Previous conversation summary: {summary_text}"}]
    except Exception as e:
        logging.error(f"Error summarizing memory: {e}")
        return mem_list
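# Illustrative sketch (assumed usage; the routes below do not call this):
# summarize_memory() can condense an over-long history into a single system
# message that seeds a fresh memory list, e.g.
#   if len(current_memory_serializable) > 40:
#       current_memory_serializable = summarize_memory(current_memory_serializable)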
# --- Flask Routes (MUST come AFTER app is defined) ---
@app.route('/')
def index():
    if 'chat_memory' not in session:
        session['chat_memory'] = []
    return render_template('index.html')
@app.route('/chat', methods=['POST'])  # route path assumed
def chat():
    # Check if Groq client AND embedding models are initialized
    if client is None or tokenizer is None or model is None:
        status_code = 500
        error_message = "Chatbot backend is not fully initialized (API key or embedding models missing)."
        logging.error(error_message)
        return jsonify({"response": error_message}), status_code
    user_input = request.json.get('message')
    if not user_input or not user_input.strip():
        return jsonify({"response": "Please enter a message."}), 400
    current_memory_serializable = session.get('chat_memory', [])
    messages_for_api = construct_prompt(current_memory_serializable, user_input)
    try:
        completion = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=messages_for_api,
            temperature=0.6,
            max_tokens=1024,
            top_p=0.95,
            stream=False,
            stop=None,
        )
        ai_response_content = completion.choices[0].message.content
    except Exception as e:
        logging.error(f"Error calling Groq API: {e}")
        ai_response_content = "Sorry, I encountered an error when trying to respond. Please try again later."
    current_memory_serializable = add_to_memory(current_memory_serializable, "user", user_input)
    current_memory_serializable = add_to_memory(current_memory_serializable, "assistant", ai_response_content)
    current_memory_serializable = trim_memory(current_memory_serializable, max_size=20)
    session['chat_memory'] = current_memory_serializable
    return jsonify({"response": ai_response_content})
@app.route('/clear', methods=['POST'])  # route path assumed
def clear_memory():
    session['chat_memory'] = []
    logging.info("Chat memory cleared.")
    return jsonify({"status": "Memory cleared."})
# --- Running the App ---
if __name__ == '__main__':
    # Using Uvicorn instead of Waitress
    logging.info("Starting Uvicorn server...")
    port = int(os.environ.get('PORT', 7860))
    # Flask is a WSGI app, so tell Uvicorn to use its WSGI interface rather
    # than treating the app as an ASGI callable.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=port, interface="wsgi")