# chat_memory / app.py
# Swaroop Ingavale - commit 02ce8d2 ("Update")
import logging
import os
import secrets

import numpy as np
import torch
import torch.nn.functional as F
from flask import Flask, jsonify, render_template, request, session
from groq import Groq
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer  # Keep these
# Configure logging
logging.basicConfig(level=logging.INFO)

# --- Flask App Setup --- (MUST come before routes or app-dependent code) ---
app = Flask(__name__)

# Flask signs the session cookie with SECRET_KEY. A hard-coded fallback is
# visible to anyone who reads this source and would let them forge session
# cookies, so when no key is configured (or it is empty) fall back to a random
# per-process key instead. Trade-off: sessions do not survive a restart.
_secret_key = os.environ.get('SECRET_KEY')
if not _secret_key:
    logging.warning(
        "SECRET_KEY environment variable not set; using a random per-process key "
        "(sessions will not survive restarts)."
    )
    _secret_key = secrets.token_hex(32)
app.config['SECRET_KEY'] = _secret_key
# --- Initialize Models ---
# Use CUDA when a GPU is visible, otherwise the CPU (the free-tier deployment
# is expected to be CPU-only).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# Each of these stays None if its initialisation fails; the /chat route checks
# them and answers with an error instead of the app crashing at import time.
tokenizer = None
model = None
client = None

try:
    # Load tokenizer and model from HuggingFace Hub using transformers
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
    # from_tf=True was deliberately re-added: the weights are loaded from the
    # repo's TensorFlow checkpoint (per the transformers docs this needs
    # TensorFlow installed).
    model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2', from_tf=True).to(device)
    logging.info("Tokenizer and AutoModel loaded successfully with from_tf=True.")
except Exception as e:
    logging.exception("Error loading Transformer models: %s", e)
    # Reset both so callers never see a tokenizer without its model.
    tokenizer = None
    model = None

# Initialize the Groq client
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    logging.error("GROQ_API_KEY environment variable not set.")
else:
    try:
        client = Groq(api_key=groq_api_key)
        logging.info("Groq client initialized.")
    except Exception as e:
        # Mirror the model-loading path: keep client as None so /chat reports
        # the backend as not initialized instead of the import failing.
        logging.exception("Error initializing Groq client: %s", e)
        client = None
# --- Helper function for Mean Pooling ---
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring padded positions.

    ``model_output[0]`` holds the per-token hidden states and ``attention_mask``
    marks real tokens with 1 and padding with 0. Positions are weighted by the
    mask and summed along dim 1, then divided by the number of real tokens. The
    divisor is clamped to 1e-9 so an all-padding row yields zeros, not NaN.
    """
    hidden_states = model_output[0]
    weights = (
        attention_mask.unsqueeze(-1)
        .expand(hidden_states.size())
        .float()
        .to(hidden_states.device)
    )
    summed = (hidden_states * weights).sum(dim=1)
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / token_counts
# --- Function to get embedding ---
def get_embedding(text):
    """Embed *text* with the loaded sentence model.

    The text is tokenized, run through the model without gradient tracking,
    mean-pooled over real tokens and L2-normalised. Returns the first row of the
    result as a numpy array. Returns ``None`` (after logging an error) when the
    tokenizer/model are not loaded or when any step raises.
    """
    models_ready = tokenizer is not None and model is not None
    if not models_ready:
        logging.error("Embedding models not loaded. Cannot generate embedding.")
        return None
    try:
        batch = tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = model(**batch)
            pooled = mean_pooling(outputs, batch['attention_mask'])
            unit_vectors = F.normalize(pooled, p=2, dim=1)
        first_row = unit_vectors.cpu().numpy()[0]
        return first_row
    except Exception as e:
        logging.error(f"Error generating embedding: {e}")
        return None
# --- Memory Management Functions ---
# add_to_memory and retrieve_relevant_memory embed text via get_embedding();
# construct_prompt builds on retrieve_relevant_memory; trim_memory and
# summarize_memory work on the stored entries directly.
def add_to_memory(mem_list, role, content):
    """Append one ``{"role", "content", "embedding"}`` entry to *mem_list* in place.

    Missing or whitespace-only content is not stored; a warning is logged instead.
    The embedding is stored as a plain Python list (so the entry can be kept in
    the session), or ``None`` when embedding fails. Returns *mem_list* itself.
    """
    is_blank = not content or not content.strip()
    if is_blank:
        logging.warning(f"Attempted to add empty content to memory for role: {role}")
        return mem_list

    vector = get_embedding(content)
    if vector is None:
        logging.warning(f"Failed to get embedding for message: {content[:50]}...")
        stored_vector = None
    else:
        stored_vector = vector.tolist()

    mem_list.append({"role": role, "content": content, "embedding": stored_vector})
    return mem_list
def retrieve_relevant_memory(mem_list, user_input, top_k=5):
    """Return up to *top_k* memory entries most similar to *user_input*.

    Only entries whose ``"embedding"`` is a list convertible to a numpy array of
    shape ``(model.config.hidden_size,)`` take part. Other entries are skipped,
    with a warning for dimension mismatches and conversion errors. Candidates
    are ranked by cosine similarity to the embedding of *user_input*, highest
    first. Returns ``[]`` when memory is empty, the models are not loaded, the
    query cannot be embedded, or no entry is usable.
    """
    if not mem_list or tokenizer is None or model is None:
        return []

    query_vec = get_embedding(user_input)
    if query_vec is None:
        logging.error("Failed to get user input embedding for retrieval.")
        return []

    candidates = []
    vectors = []
    for entry in mem_list:
        raw = entry.get("embedding")
        if not isinstance(raw, list):
            continue
        try:
            vec = np.array(raw)
            if vec.shape != (model.config.hidden_size,):  # Use model config for dimension
                logging.warning(f"Embedding dimension mismatch for memory entry: {entry['content'][:50]}...")
                continue
            candidates.append(entry)
            vectors.append(vec)
        except Exception as conv_e:
            logging.warning(f"Could not convert embedding for memory entry: {entry['content'][:50]}... Error: {conv_e}")

    if not candidates:
        return []

    scores = cosine_similarity([query_vec], np.array(vectors))[0]
    ranked = sorted(zip(scores, candidates), key=lambda pair: pair[0], reverse=True)
    return [entry for _, entry in ranked[:top_k]]
def construct_prompt(mem_list, user_input, max_tokens_in_prompt=1000):
    """Build the chat-completion ``messages`` list for *user_input*.

    The list starts with a fixed system prompt. It then includes, in their
    original chronological order, the memory entries (roles user/assistant/system)
    whose content matches one of the entries returned by
    ``retrieve_relevant_memory``. Context stops at the first entry that would
    push the whitespace-token estimate past *max_tokens_in_prompt*. The new user
    message is always appended last; a warning is logged if it overflows the
    budget while context is present.
    """
    relevant_items = retrieve_relevant_memory(mem_list, user_input)
    wanted_contents = set()
    for item in relevant_items:
        if "content" in item:
            wanted_contents.add(item["content"])

    system_message = {"role": "system", "content": "You are a helpful and friendly AI assistant."}
    tokens_used = len(system_message["content"].split())

    allowed_roles = ("user", "assistant", "system")
    context = []
    for entry in mem_list:
        if "content" not in entry or entry["content"] not in wanted_contents or entry["role"] not in allowed_roles:
            continue
        # Rough token estimate: whitespace-separated words of "role: content".
        cost = len(f'{entry["role"]}: {entry["content"]}\n'.split())
        if tokens_used + cost > max_tokens_in_prompt:
            break
        context.append({"role": entry["role"], "content": entry["content"]})
        tokens_used += cost

    messages = [system_message, *context]
    if tokens_used + len(user_input.split()) > max_tokens_in_prompt and len(messages) > 1:
        logging.warning("User input exceeds max_tokens_in_prompt with existing context. Context may be truncated.")
    messages.append({"role": "user", "content": user_input})
    return messages
def trim_memory(mem_list, max_size=50):
    """Drop the oldest entries of *mem_list* in place so at most *max_size* remain.

    Entries are removed from the front of the list (the oldest ones, since new
    entries are appended). A single slice deletion replaces repeated
    ``pop(0)`` calls, each of which shifted the whole list. A negative
    *max_size* is treated as 0. Returns the same list object.
    """
    excess = len(mem_list) - max(max_size, 0)
    if excess > 0:
        del mem_list[:excess]
    return mem_list
def summarize_memory(mem_list):
    """Condense *mem_list* into a single system message using the Groq API.

    On success returns ``[{"role": "system", "content": "Previous conversation
    summary: ..."}]``. Returns ``[]`` when there is nothing to summarize (empty
    memory, or no entry with non-blank text content). When a summary cannot be
    produced (Groq client not initialized, API error, empty reply), *mem_list*
    is returned unchanged. That way a caller that replaces its memory with the
    result never loses history.
    """
    if not mem_list:
        logging.warning("Memory is empty. Nothing to summarize.")
        return []
    if client is None:
        # Previously returned [] here, discarding all history; keep it instead,
        # consistent with the API-error path below.
        logging.warning("Groq client not initialized. Cannot summarize.")
        return mem_list

    # Label each turn with its role so the summarizer can tell speakers apart;
    # skip None/blank contents, which would break or pad the join.
    transcript = "\n".join(
        f"{m.get('role', 'unknown')}: {m['content']}"
        for m in mem_list
        if isinstance(m.get("content"), str) and m["content"].strip()
    )
    if not transcript.strip():
        logging.warning("Memory content is empty. Cannot summarize.")
        return []

    try:
        summary_completion = client.chat.completions.create(
            model="llama-3.1-8b-instruct-fpt",
            messages=[
                {"role": "system", "content": "Summarize the following conversation for key points. Keep it concise."},
                {"role": "user", "content": transcript},
            ],
            max_tokens=500,
        )
        summary_text = summary_completion.choices[0].message.content
    except Exception as e:
        logging.exception("Error summarizing memory: %s", e)
        return mem_list

    if not summary_text:
        logging.warning("Summarizer returned no content. Keeping original memory.")
        return mem_list
    logging.info("Memory summarized.")
    return [{"role": "system", "content": f"Previous conversation summary: {summary_text}"}]
# --- Flask Routes --- (MUST come AFTER app is defined) ---
@app.route('/')
def index():
    """Serve the chat page, seeding an empty chat memory for new sessions."""
    has_memory = 'chat_memory' in session
    if not has_memory:
        session['chat_memory'] = []
    page = render_template('index.html')
    return page
@app.route('/chat', methods=['POST'])
def chat():
    """Answer one chat message using retrieved session memory as context.

    Expects a JSON body ``{"message": "<text>"}`` and returns
    ``{"response": "<text>"}``. Responds with 500 when the Groq client or the
    embedding models are not initialized, and 400 when the message is missing,
    not a string, or blank. On success the user/assistant turns are appended to
    ``session['chat_memory']``, which is then trimmed via trim_memory(max_size=20).

    NOTE(review): Flask's default session is a signed cookie, and browsers cap
    cookies at roughly 4 KB. Every memory entry carries a full embedding vector,
    so the session very likely overflows that limit and is silently dropped.
    Confirm, and consider server-side storage or not persisting embeddings.
    """
    # Check if Groq client AND embedding models are initialized
    if client is None or tokenizer is None or model is None:
        error_message = "Chatbot backend is not fully initialized (API key or embedding models missing)."
        logging.error(error_message)
        return jsonify({"response": error_message}), 500

    # silent=True: a missing/non-JSON body yields None instead of raising.
    payload = request.get_json(silent=True)
    user_input = payload.get('message') if isinstance(payload, dict) else None
    if not isinstance(user_input, str) or not user_input.strip():
        return jsonify({"response": "Please enter a message."}), 400

    current_memory_serializable = session.get('chat_memory', [])
    messages_for_api = construct_prompt(current_memory_serializable, user_input)
    try:
        completion = client.chat.completions.create(
            model="llama-3.1-8b-instruct-fpt",
            messages=messages_for_api,
            temperature=0.6,
            max_tokens=1024,
            top_p=0.95,
            stream=False,
            stop=None,
        )
        ai_response_content = completion.choices[0].message.content
    except Exception as e:
        logging.exception("Error calling Groq API: %s", e)
        # Do not persist the failed turn: storing the apology as an assistant
        # message would let retrieval surface it as context later.
        return jsonify({"response": "Sorry, I encountered an error when trying to respond. Please try again later."})

    current_memory_serializable = add_to_memory(current_memory_serializable, "user", user_input)
    current_memory_serializable = add_to_memory(current_memory_serializable, "assistant", ai_response_content)
    current_memory_serializable = trim_memory(current_memory_serializable, max_size=20)
    session['chat_memory'] = current_memory_serializable
    return jsonify({"response": ai_response_content})
@app.route('/clear_memory', methods=['POST'])
def clear_memory():
    """Reset this session's chat memory to an empty list and confirm it."""
    session['chat_memory'] = list()
    logging.info("Chat memory cleared.")
    payload = {"status": "Memory cleared."}
    return jsonify(payload)
# --- Running the App ---
if __name__ == '__main__':
    # Serve the Flask app with Uvicorn (instead of Waitress).
    logging.info("Starting Uvicorn server...")
    port = int(os.environ.get('PORT', 7860))
    import uvicorn
    # Flask is a WSGI app. Uvicorn's default interface="auto" only chooses
    # between ASGI3 and ASGI2, so without this flag the Flask app would be
    # invoked as an ASGI2 app and every request would fail. Uvicorn documents
    # its WSGI mode as deprecated; an ASGI adapter (e.g. a2wsgi) or a WSGI
    # server would be the long-term option.
    uvicorn.run(app, host="0.0.0.0", port=port, interface="wsgi")