artecnosomatic commited on
Commit
c5b2741
Β·
1 Parent(s): f493575

Fix python-dotenv compatibility issue and add advanced conversation model

Browse files
Files changed (4) hide show
  1. app.py +42 -15
  2. spaces_requirements.txt +7 -1
  3. src/conversation_model.py +406 -0
  4. src/main.py +129 -31
app.py CHANGED
@@ -5,12 +5,18 @@ This is the main entry point for the Spaces deployment.
5
  """
6
 
7
  import os
8
- from dotenv import load_dotenv
9
  from flask import Flask, render_template, request, jsonify
10
  from src.main import MemoryAI
11
 
12
- # Load environment variables
13
- load_dotenv()
 
 
 
 
 
 
 
14
 
15
  # Initialize Flask app
16
  app = Flask(__name__)
@@ -48,26 +54,47 @@ def chat():
48
  """Get AI response to user input."""
49
  data = request.json
50
  user_input = data.get('message', '')
 
51
 
52
  if not user_input.strip():
53
  return jsonify({'error': 'Empty message'}), 400
54
 
55
- # Add user input to memories
56
- ai.add_memory(f"User: {user_input}")
57
 
58
- # Get recent memories for context
59
- recent_context = "\n".join(ai.get_recent_memories(3))
60
- full_prompt = f"{recent_context}\n\nUser: {user_input}\nAI:"
61
-
62
- # Generate AI response
63
- response = ai.generate_response(full_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- # Add AI response to memories
66
- ai.add_memory(f"AI: {response}")
67
 
 
68
  return jsonify({
69
- 'response': response,
70
- 'memory_count': len(ai.memories)
 
 
 
 
 
 
 
71
  })
72
 
73
  @app.route('/api/save', methods=['POST'])
 
5
  """
6
 
7
  import os
 
8
  from flask import Flask, render_template, request, jsonify
9
  from src.main import MemoryAI
10
 
11
+ # Load environment variables (with fallback if dotenv not available)
12
+ try:
13
+ from dotenv import load_dotenv
14
+ load_dotenv()
15
+ except ImportError:
16
+ print("⚠️ python-dotenv not available, using environment variables directly")
17
+ # Set default values if .env not loaded
18
+ if not os.getenv("MODEL_NAME"):
19
+ os.environ["MODEL_NAME"] = "microsoft/DialoGPT-small"
20
 
21
  # Initialize Flask app
22
  app = Flask(__name__)
 
54
  """Get AI response to user input."""
55
  data = request.json
56
  user_input = data.get('message', '')
57
+ conversation_history = data.get('history', [])
58
 
59
  if not user_input.strip():
60
  return jsonify({'error': 'Empty message'}), 400
61
 
62
+ # Generate AI response with conversation history
63
+ response = ai.generate_response(user_input, conversation_history=conversation_history)
64
 
65
+ return jsonify({
66
+ 'response': response,
67
+ 'memory_count': len(ai.memories),
68
+ 'conversation_stats': ai.get_conversation_stats() if hasattr(ai, 'get_conversation_stats') else {}
69
+ })
70
+
71
+ @app.route('/api/summary', methods=['GET'])
72
+ def get_summary():
73
+ """Get conversation summary."""
74
+ summary = ai.get_conversation_summary()
75
+ return jsonify({'summary': summary})
76
+
77
+ @app.route('/api/similar', methods=['POST'])
78
+ def find_similar():
79
+ """Find similar memories."""
80
+ data = request.json
81
+ query = data.get('query', '')
82
+ top_k = data.get('top_k', 3)
83
 
84
+ if not query.strip():
85
+ return jsonify({'error': 'Empty query'}), 400
86
 
87
+ similar = ai.find_similar_memories(query, top_k)
88
  return jsonify({
89
+ 'similar_memories': [{'text': text, 'similarity': float(score)} for text, score in similar],
90
+ 'query': query
91
+ })
92
+
93
+ @app.route('/api/reset', methods=['POST'])
94
+ def reset_conversation():
95
+ """Reset conversation state."""
96
+ ai.reset_conversation()
97
+ return jsonify({'status': 'success', 'message': 'Conversation reset'})
98
  })
99
 
100
  @app.route('/api/save', methods=['POST'])
spaces_requirements.txt CHANGED
@@ -7,6 +7,12 @@ flask==3.1.2
7
  python-dotenv==1.0.1
8
  datasets==4.4.2
9
  accelerate==1.12.0
 
 
10
  blinker==1.9.0
11
  itsdangerous==2.2.0
12
- werkzeug==3.1.5
 
 
 
 
 
7
  python-dotenv==1.0.1
8
  datasets==4.4.2
9
  accelerate==1.12.0
10
+ sentence-transformers==2.2.2
11
+ scikit-learn==1.5.0
12
  blinker==1.9.0
13
  itsdangerous==2.2.0
14
+ werkzeug==3.1.5
15
+ # Ensure proper Flask dependencies
16
+ click==8.1.7
17
+ jinja2==3.1.4
18
+ markupsafe==2.1.5
src/conversation_model.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Advanced Conversation Model for MemoryAI
4
+
5
+ This module provides enhanced conversation capabilities with:
6
+ - Multi-turn dialog management
7
+ - Context-aware response generation
8
+ - Personality and style control
9
+ - Emotion detection and response
10
+ - Topic tracking and continuity
11
+ """
12
+
13
+ import os
14
+ import re
15
+ import random
16
+ from typing import List, Dict, Optional, Tuple
17
+ from datetime import datetime
18
+ import numpy as np
19
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
20
+ from sentence_transformers import SentenceTransformer
21
+ from sklearn.metrics.pairwise import cosine_similarity
22
+ import torch
23
+
24
+ # Check for GPU availability
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
+ class ConversationModel:
28
+ """
29
+ Advanced conversation model with memory and context awareness.
30
+
31
+ Features:
32
+ - Multi-turn conversation handling
33
+ - Context-aware responses
34
+ - Emotion detection
35
+ - Topic tracking
36
+ - Personality control
37
+ """
38
+
39
+ def __init__(self, model_name: str = "facebook/blenderbot-400M-distill",
40
+ embedding_model: str = "all-MiniLM-L6-v2"):
41
+ """
42
+ Initialize the conversation model.
43
+
44
+ Args:
45
+ model_name: Hugging Face model name for conversation
46
+ embedding_model: Model for semantic embeddings
47
+ """
48
+ self.model_name = model_name
49
+ self.embedding_model_name = embedding_model
50
+
51
+ # Load models
52
+ self.tokenizer = None
53
+ self.model = None
54
+ self.embedding_model = None
55
+ self.conversation_pipeline = None
56
+
57
+ self.load_models()
58
+
59
+ # Conversation state
60
+ self.conversation_history = []
61
+ self.current_topic = "general"
62
+ self.user_emotion = "neutral"
63
+ self.conversation_length = 0
64
+
65
+ # Personality settings
66
+ self.personality = {
67
+ "friendliness": 0.8,
68
+ "humor": 0.6,
69
+ "formality": 0.3,
70
+ "verbosity": 0.7,
71
+ "curiosity": 0.9
72
+ }
73
+
74
+ # Response enhancements
75
+ self.response_enhancers = {
76
+ "greetings": ["Hello!", "Hi there!", "Hey!", "Greetings!", "Nice to see you!"],
77
+ "goodbyes": ["Goodbye!", "See you later!", "Take care!", "Bye!", "Have a great day!"],
78
+ "agreements": ["Yes!", "Absolutely!", "I agree!", "Exactly!", "You're right!"],
79
+ "disagreements": ["I see your point, but...", "That's interesting, however...",
80
+ "I understand, but I think...", "That's a good perspective, but..."],
81
+ "questions": ["What do you think about that?", "Does that make sense?",
82
+ "How does that sound?", "What's your opinion?"]
83
+ }
84
+
85
+ def load_models(self):
86
+ """Load the conversation and embedding models."""
87
+ try:
88
+ print(f"Loading conversation model: {self.model_name}")
89
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
90
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(device)
91
+
92
+ # Create conversation pipeline
93
+ self.conversation_pipeline = pipeline(
94
+ "conversational",
95
+ model=self.model,
96
+ tokenizer=self.tokenizer,
97
+ device=0 if torch.cuda.is_available() else -1
98
+ )
99
+
100
+ print(f"Loading embedding model: {self.embedding_model_name}")
101
+ self.embedding_model = SentenceTransformer(self.embedding_model_name)
102
+
103
+ print("βœ… Models loaded successfully!")
104
+
105
+ except Exception as e:
106
+ print(f"❌ Error loading models: {e}")
107
+ # Fallback to simpler model
108
+ print("Falling back to basic conversation model...")
109
+ self.model_name = "microsoft/DialoGPT-small"
110
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
111
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(device)
112
+ self.conversation_pipeline = pipeline(
113
+ "conversational",
114
+ model=self.model,
115
+ tokenizer=self.tokenizer,
116
+ device=0 if torch.cuda.is_available() else -1
117
+ )
118
+
119
+ def detect_emotion(self, text: str) -> str:
120
+ """Detect emotion in user input."""
121
+ # Simple emotion detection based on keywords
122
+ text_lower = text.lower()
123
+
124
+ happy_keywords = ["happy", "joy", "excited", "great", "awesome", "wonderful", "love"]
125
+ sad_keywords = ["sad", "unhappy", "depressed", "terrible", "awful", "hate"]
126
+ angry_keywords = ["angry", "mad", "furious", "annoyed", "frustrated"]
127
+
128
+ if any(keyword in text_lower for keyword in happy_keywords):
129
+ return "happy"
130
+ elif any(keyword in text_lower for keyword in sad_keywords):
131
+ return "sad"
132
+ elif any(keyword in text_lower for keyword in angry_keywords):
133
+ return "angry"
134
+ else:
135
+ return "neutral"
136
+
137
+ def detect_topic(self, text: str) -> str:
138
+ """Detect the topic of conversation."""
139
+ text_lower = text.lower()
140
+
141
+ topic_keywords = {
142
+ "technology": ["tech", "computer", "software", "hardware", "ai", "machine learning"],
143
+ "sports": ["sports", "game", "football", "basketball", "soccer", "tennis"],
144
+ "movies": ["movie", "film", "cinema", "actor", "actress", "director"],
145
+ "music": ["music", "song", "band", "artist", "concert", "album"],
146
+ "travel": ["travel", "vacation", "trip", "hotel", "flight", "destination"],
147
+ "food": ["food", "restaurant", "cooking", "recipe", "cuisine", "dish"],
148
+ "work": ["work", "job", "career", "office", "meeting", "project"],
149
+ "personal": ["life", "family", "friend", "relationship", "feeling", "emotion"]
150
+ }
151
+
152
+ for topic, keywords in topic_keywords.items():
153
+ if any(keyword in text_lower for keyword in keywords):
154
+ return topic
155
+
156
+ return "general"
157
+
158
+ def generate_response(self, user_input: str, conversation_history: List[Dict] = None) -> str:
159
+ """
160
+ Generate a response to user input with full conversation context.
161
+
162
+ Args:
163
+ user_input: The user's message
164
+ conversation_history: Previous conversation turns
165
+
166
+ Returns:
167
+ Generated response string
168
+ """
169
+ if conversation_history is None:
170
+ conversation_history = []
171
+
172
+ # Update conversation state
173
+ self.user_emotion = self.detect_emotion(user_input)
174
+ self.current_topic = self.detect_topic(user_input)
175
+ self.conversation_length += 1
176
+
177
+ # Add current input to history
178
+ conversation_history.append({"role": "user", "content": user_input})
179
+
180
+ try:
181
+ # Generate response using the conversation model
182
+ response = self.conversation_pipeline(
183
+ conversation_history,
184
+ max_length=150,
185
+ temperature=0.7,
186
+ top_p=0.9,
187
+ repetition_penalty=1.2,
188
+ num_return_sequences=1,
189
+ do_sample=True # Enable sampling for temperature/top_p to work
190
+ )
191
+
192
+ # Handle different response formats
193
+ if isinstance(response, list) and len(response) > 0:
194
+ if 'generated_text' in response[0]:
195
+ generated_text = response[0]['generated_text']
196
+ elif 'text' in response[0]:
197
+ generated_text = response[0]['text']
198
+ else:
199
+ # Try to get the first available text
200
+ generated_text = str(response[0].get('generated_response', response[0].get('response', '')))
201
+ else:
202
+ generated_text = str(response)
203
+
204
+ # Clean and enhance the response
205
+ enhanced_response = self.enhance_response(generated_text, user_input)
206
+
207
+ # Add to conversation history
208
+ conversation_history.append({"role": "assistant", "content": enhanced_response})
209
+
210
+ return enhanced_response
211
+
212
+ except Exception as e:
213
+ print(f"Error generating response: {e}")
214
+ return self.get_fallback_response(user_input)
215
+
216
+ def enhance_response(self, response: str, user_input: str) -> str:
217
+ """Enhance the generated response based on context and personality."""
218
+ # Clean up the response
219
+ response = self.clean_response(response)
220
+
221
+ # Add personality traits
222
+ response = self.add_personality(response)
223
+
224
+ # Make it more conversational
225
+ response = self.make_conversational(response, user_input)
226
+
227
+ return response
228
+
229
+ def clean_response(self, response: str) -> str:
230
+ """Clean up the generated response text."""
231
+ # Remove special tokens and cleanup
232
+ response = response.strip()
233
+ response = re.sub(r'\s+', ' ', response)
234
+ response = re.sub(r'[""\'\']', '', response)
235
+
236
+ # Capitalize first letter and add period if missing
237
+ if response and response[0].islower():
238
+ response = response[0].upper() + response[1:]
239
+
240
+ if response and response[-1] not in ['.', '!', '?']:
241
+ response += '.'
242
+
243
+ return response
244
+
245
+ def add_personality(self, response: str) -> str:
246
+ """Add personality traits to the response."""
247
+ # Add friendliness
248
+ if self.personality["friendliness"] > 0.7:
249
+ friendly_phrases = ["by the way", "I think", "in my opinion",
250
+ "that's interesting", "I'd say"]
251
+ if random.random() < 0.3: # 30% chance to add friendly phrase
252
+ phrase = random.choice(friendly_phrases)
253
+ response = f"{phrase}, {response}"
254
+
255
+ # Add humor if appropriate
256
+ if self.personality["humor"] > 0.5 and self.user_emotion in ["happy", "neutral"]:
257
+ if random.random() < 0.2: # 20% chance to add humor
258
+ humor_tags = ["πŸ˜„", "😊", "🀣", "πŸ˜†"]
259
+ response += " " + random.choice(humor_tags)
260
+
261
+ return response
262
+
263
+ def make_conversational(self, response: str, user_input: str) -> str:
264
+ """Make the response more conversational and context-aware."""
265
+ # Add context references
266
+ if self.conversation_length > 1:
267
+ context_phrases = [
268
+ "As we were discussing",
269
+ "Regarding what you mentioned",
270
+ "Building on that idea",
271
+ "That reminds me"
272
+ ]
273
+ if random.random() < 0.25:
274
+ response = f"{random.choice(context_phrases)}, {response}"
275
+
276
+ # Add follow-up questions
277
+ if random.random() < 0.4: # 40% chance to add a follow-up
278
+ follow_ups = [
279
+ "What do you think about that?",
280
+ "Does that make sense?",
281
+ "How does that sound to you?",
282
+ "Would you like me to elaborate?"
283
+ ]
284
+ response += " " + random.choice(follow_ups)
285
+
286
+ return response
287
+
288
+ def get_fallback_response(self, user_input: str) -> str:
289
+ """Get a fallback response when model generation fails."""
290
+ fallback_responses = [
291
+ "That's an interesting question! Let me think about that...",
292
+ "I'm not sure I understand completely. Could you elaborate?",
293
+ "That's a complex topic. What specifically would you like to know?",
294
+ "I'd love to help with that. Can you provide more details?",
295
+ "That's fascinating! Tell me more about what you're thinking."
296
+ ]
297
+
298
+ return random.choice(fallback_responses)
299
+
300
+ def get_conversation_summary(self) -> str:
301
+ """Get a summary of the current conversation."""
302
+ if not self.conversation_history:
303
+ return "No conversation history yet."
304
+
305
+ summary = f"Conversation Summary:\n"
306
+ summary += f"- Topic: {self.current_topic}\n"
307
+ summary += f"- User Emotion: {self.user_emotion}\n"
308
+ summary += f"- Duration: {self.conversation_length} turns\n"
309
+ summary += f"- Main Points:\n"
310
+
311
+ # Extract key points from conversation
312
+ for i, turn in enumerate(self.conversation_history):
313
+ role = "You" if turn["role"] == "user" else "AI"
314
+ summary += f" {i+1}. {role}: {turn['content'][:50]}...\n"
315
+
316
+ return summary
317
+
318
+ def find_similar_conversations(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]:
319
+ """Find similar conversations from history using semantic search."""
320
+ if not self.conversation_history or not self.embedding_model:
321
+ return []
322
+
323
+ try:
324
+ # Get embedding for the query
325
+ query_embedding = self.embedding_model.encode([query])
326
+
327
+ # Get embeddings for conversation history
328
+ history_texts = [turn["content"] for turn in self.conversation_history if turn["role"] == "user"]
329
+ history_embeddings = self.embedding_model.encode(history_texts)
330
+
331
+ # Calculate similarities
332
+ similarities = cosine_similarity(query_embedding, history_embeddings)[0]
333
+
334
+ # Get top k similar conversations
335
+ top_indices = np.argsort(similarities)[-top_k:][::-1]
336
+
337
+ similar_conversations = []
338
+ for idx in top_indices:
339
+ similar_conversations.append((history_texts[idx], similarities[idx]))
340
+
341
+ return similar_conversations
342
+
343
+ except Exception as e:
344
+ print(f"Error in semantic search: {e}")
345
+ return []
346
+
347
+ def reset_conversation(self):
348
+ """Reset the conversation state."""
349
+ self.conversation_history = []
350
+ self.current_topic = "general"
351
+ self.user_emotion = "neutral"
352
+ self.conversation_length = 0
353
+ print("Conversation reset successfully!")
354
+
355
+ def get_conversation_stats(self) -> Dict:
356
+ """Get statistics about the current conversation."""
357
+ return {
358
+ "length": self.conversation_length,
359
+ "current_topic": self.current_topic,
360
+ "user_emotion": self.user_emotion,
361
+ "personality": self.personality,
362
+ "model": self.model_name
363
+ }
364
+
365
+ # Example usage and testing
366
+ if __name__ == "__main__":
367
+ print("πŸ€– Advanced Conversation Model - Testing")
368
+ print("=" * 50)
369
+
370
+ # Initialize the conversation model
371
+ conv_model = ConversationModel()
372
+
373
+ # Test conversation
374
+ print("Starting test conversation...")
375
+
376
+ conversation = []
377
+
378
+ # Test inputs
379
+ test_inputs = [
380
+ "Hello! How are you doing today?",
381
+ "I'm really excited about the new AI technologies!",
382
+ "What do you think about machine learning?",
383
+ "Can you tell me more about neural networks?",
384
+ "That was very helpful, thank you!"
385
+ ]
386
+
387
+ for user_input in test_inputs:
388
+ print(f"\nπŸ‘€ User: {user_input}")
389
+ response = conv_model.generate_response(user_input, conversation)
390
+ print(f"πŸ€– AI: {response}")
391
+
392
+ # Show conversation stats
393
+ stats = conv_model.get_conversation_stats()
394
+ print(f"πŸ“Š Topic: {stats['current_topic']} | Emotion: {stats['user_emotion']}")
395
+
396
+ # Show conversation summary
397
+ print(f"\n{conv_model.get_conversation_summary()}")
398
+
399
+ # Test semantic search
400
+ print("\nπŸ” Testing semantic search...")
401
+ similar = conv_model.find_similar_conversations("AI technologies", top_k=2)
402
+ print("Similar conversations found:")
403
+ for text, score in similar:
404
+ print(f" - '{text[:30]}...' (similarity: {score:.3f})")
405
+
406
+ print("\nβœ… Conversation model testing complete!")
src/main.py CHANGED
@@ -8,12 +8,23 @@ import os
8
  from dotenv import load_dotenv
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
10
  import torch
 
11
 
12
- # Load environment variables
13
- load_dotenv()
 
 
 
 
 
 
 
 
 
 
14
 
15
  class MemoryAI:
16
- def __init__(self):
17
  """Initialize the AI model and memory system."""
18
  self.model_name = os.getenv("MODEL_NAME", "gpt2")
19
  self.max_memory = int(os.getenv("MAX_MEMORY_ENTRIES", 100))
@@ -29,19 +40,33 @@ class MemoryAI:
29
  # Initialize memory storage
30
  self.memories = []
31
 
32
- # Load model and tokenizer
33
- print(f"Loading {self.model_name} model...")
34
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
35
- self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
36
 
37
- # Move model to GPU if available
38
- if torch.cuda.is_available():
39
- self.model = self.model.to('cuda')
40
- print("Using CUDA (GPU acceleration)")
41
- else:
42
- print("Using CPU")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- print(f"Initialized {self.model_name} model")
45
  print(f"Memory capacity: {self.max_memory} entries")
46
  print(f"Generation params - Temp: {self.temperature}, Max tokens: {self.max_new_tokens}")
47
 
@@ -53,8 +78,33 @@ class MemoryAI:
53
  self.memories.append(memory_text)
54
  print(f"Memory added. Total memories: {len(self.memories)}")
55
 
56
- def generate_response(self, prompt, max_new_tokens=80):
57
  """Generate a response using the AI model with improved quality."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  # Improved prompt engineering for conversational AI
59
  if "microsoft/DialoGPT" in self.model_name:
60
  # DialoGPT uses a different format
@@ -66,7 +116,7 @@ class MemoryAI:
66
  inputs = self.tokenizer(improved_prompt, return_tensors="pt")
67
 
68
  # Move inputs to same device as model
69
- if next(self.model.parameters()).is_cuda:
70
  inputs = {k: v.to('cuda') for k, v in inputs.items()}
71
 
72
  # Generate with better parameters
@@ -151,44 +201,62 @@ class MemoryAI:
151
 
152
  def converse(self):
153
  """Start a conversation loop with the AI."""
154
- print("Starting conversation... Type 'quit' to exit.")
 
155
  print("Type '!memories' to see recent memories, '!clear' to clear memories")
 
 
 
 
 
156
 
157
  while True:
158
- user_input = input("You: ")
159
 
160
  if user_input.lower() == 'quit':
 
161
  break
162
 
163
  # Handle special commands
164
  if user_input.lower() == '!memories':
165
  recent_memories = self.get_recent_memories()
166
- print("Recent memories:")
167
  for i, memory in enumerate(recent_memories, 1):
168
  print(f" {i}. {memory}")
169
  continue
170
 
171
  if user_input.lower() == '!clear':
172
  self.clear_memories()
 
173
  continue
174
 
175
- # Add user input to memories
176
- self.add_memory(f"User: {user_input}")
177
-
178
- # Get recent memories for context
179
- recent_context = "\n".join(self.get_recent_memories(3))
180
- full_prompt = f"{recent_context}\n\nUser: {user_input}\nAI:"
181
 
182
- # Generate AI response
183
- response = self.generate_response(full_prompt)
184
- print(f"AI: {response}")
 
185
 
186
- # Add AI response to memories
187
- self.add_memory(f"AI: {response}")
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  def get_available_models(self):
190
  """Get a list of commonly available models."""
191
- return [
192
  "gpt2",
193
  "distilgpt2",
194
  "gpt2-medium",
@@ -198,6 +266,36 @@ class MemoryAI:
198
  "microsoft/DialoGPT-small",
199
  "microsoft/DialoGPT-medium"
200
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  def save_memories(self):
203
  """Save memories to a file."""
 
8
  from dotenv import load_dotenv
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
10
  import torch
11
+ from typing import List, Dict, Optional
12
 
13
+ # Import our advanced conversation model
14
+ try:
15
+ from src.conversation_model import ConversationModel
16
+ except ImportError:
17
+ # Fallback import for direct execution
18
+ from conversation_model import ConversationModel
19
+
20
+ # Load environment variables (with fallback if dotenv not available)
21
+ try:
22
+ load_dotenv()
23
+ except ImportError:
24
+ print("⚠️ python-dotenv not available, using default values")
25
 
26
  class MemoryAI:
27
+ def __init__(self, use_advanced_model: bool = True):
28
  """Initialize the AI model and memory system."""
29
  self.model_name = os.getenv("MODEL_NAME", "gpt2")
30
  self.max_memory = int(os.getenv("MAX_MEMORY_ENTRIES", 100))
 
40
  # Initialize memory storage
41
  self.memories = []
42
 
43
+ # Initialize conversation model
44
+ self.use_advanced_model = use_advanced_model
45
+ self.conversation_model = None
 
46
 
47
+ if use_advanced_model:
48
+ try:
49
+ print("Loading advanced conversation model...")
50
+ self.conversation_model = ConversationModel()
51
+ print("βœ… Advanced conversation model loaded!")
52
+ except Exception as e:
53
+ print(f"❌ Error loading advanced model: {e}")
54
+ print("Falling back to basic model...")
55
+ self.use_advanced_model = False
56
+
57
+ # Load basic model as fallback
58
+ if not self.use_advanced_model:
59
+ print(f"Loading basic {self.model_name} model...")
60
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
61
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
62
+
63
+ # Move model to GPU if available
64
+ if torch.cuda.is_available():
65
+ self.model = self.model.to('cuda')
66
+ print("Using CUDA (GPU acceleration)")
67
+ else:
68
+ print("Using CPU")
69
 
 
70
  print(f"Memory capacity: {self.max_memory} entries")
71
  print(f"Generation params - Temp: {self.temperature}, Max tokens: {self.max_new_tokens}")
72
 
 
78
  self.memories.append(memory_text)
79
  print(f"Memory added. Total memories: {len(self.memories)}")
80
 
81
+ def generate_response(self, prompt, max_new_tokens=80, conversation_history=None):
82
  """Generate a response using the AI model with improved quality."""
83
+ # Use advanced conversation model if available
84
+ if self.use_advanced_model and self.conversation_model:
85
+ try:
86
+ # Convert memory to conversation history format
87
+ conv_history = []
88
+ if conversation_history:
89
+ for entry in conversation_history:
90
+ conv_history.append({"role": entry.get("role", "user"),
91
+ "content": entry.get("content", entry.get("text", ""))})
92
+
93
+ # Generate response using advanced model
94
+ response = self.conversation_model.generate_response(prompt, conv_history)
95
+
96
+ # Add to memories
97
+ self.add_memory(f"User: {prompt}")
98
+ self.add_memory(f"AI: {response}")
99
+
100
+ return response
101
+
102
+ except Exception as e:
103
+ print(f"Advanced model error: {e}")
104
+ # Fallback to basic model
105
+ pass
106
+
107
+ # Fallback to basic model
108
  # Improved prompt engineering for conversational AI
109
  if "microsoft/DialoGPT" in self.model_name:
110
  # DialoGPT uses a different format
 
116
  inputs = self.tokenizer(improved_prompt, return_tensors="pt")
117
 
118
  # Move inputs to same device as model
119
+ if hasattr(self, 'model') and next(self.model.parameters()).is_cuda:
120
  inputs = {k: v.to('cuda') for k, v in inputs.items()}
121
 
122
  # Generate with better parameters
 
201
 
202
  def converse(self):
203
  """Start a conversation loop with the AI."""
204
+ print("πŸ€– MemoryAI - Advanced Conversation Mode")
205
+ print("Type 'quit' to exit.")
206
  print("Type '!memories' to see recent memories, '!clear' to clear memories")
207
+ print("Type '!summary' for conversation summary, '!reset' to reset conversation")
208
+ print("=" * 60)
209
+
210
+ # Initialize conversation history for advanced model
211
+ conversation_history = []
212
 
213
  while True:
214
+ user_input = input("πŸ‘€ You: ")
215
 
216
  if user_input.lower() == 'quit':
217
+ print("πŸ€– AI: Goodbye! Have a great day!")
218
  break
219
 
220
  # Handle special commands
221
  if user_input.lower() == '!memories':
222
  recent_memories = self.get_recent_memories()
223
+ print("πŸ“š Recent memories:")
224
  for i, memory in enumerate(recent_memories, 1):
225
  print(f" {i}. {memory}")
226
  continue
227
 
228
  if user_input.lower() == '!clear':
229
  self.clear_memories()
230
+ print("πŸ—‘οΈ Memories cleared!")
231
  continue
232
 
233
+ if user_input.lower() == '!summary' and self.use_advanced_model:
234
+ summary = self.get_conversation_summary()
235
+ print(f"πŸ“Š {summary}")
236
+ continue
 
 
237
 
238
+ if user_input.lower() == '!reset':
239
+ self.reset_conversation()
240
+ conversation_history = []
241
+ continue
242
 
243
+ if user_input.strip():
244
+ # Generate response with conversation history
245
+ response = self.generate_response(user_input, conversation_history=conversation_history)
246
+ print(f"πŸ€– AI: {response}")
247
+
248
+ # Update conversation history
249
+ conversation_history.append({"role": "user", "content": user_input})
250
+ conversation_history.append({"role": "assistant", "content": response})
251
+
252
+ # Show conversation stats if using advanced model
253
+ if self.use_advanced_model and self.conversation_model:
254
+ stats = self.conversation_model.get_conversation_stats()
255
+ print(f"πŸ“Š Topic: {stats['current_topic']} | Emotion: {stats['user_emotion']}")
256
 
257
  def get_available_models(self):
258
  """Get a list of commonly available models."""
259
+ models = [
260
  "gpt2",
261
  "distilgpt2",
262
  "gpt2-medium",
 
266
  "microsoft/DialoGPT-small",
267
  "microsoft/DialoGPT-medium"
268
  ]
269
+
270
+ # Add advanced conversation models
271
+ if self.use_advanced_model:
272
+ models.extend([
273
+ "facebook/blenderbot-400M-distill",
274
+ "facebook/blenderbot-1B-distill",
275
+ "microsoft/DialoGPT-large"
276
+ ])
277
+
278
+ return models
279
+
280
+ def get_conversation_summary(self) -> str:
281
+ """Get a summary of the current conversation."""
282
+ if not self.use_advanced_model or not self.conversation_model:
283
+ return "Conversation summary available only with advanced model."
284
+
285
+ return self.conversation_model.get_conversation_summary()
286
+
287
+ def find_similar_memories(self, query: str, top_k: int = 3) -> list:
288
+ """Find memories similar to the query using semantic search."""
289
+ if not self.use_advanced_model or not self.conversation_model:
290
+ return []
291
+
292
+ return self.conversation_model.find_similar_conversations(query, top_k)
293
+
294
+ def reset_conversation(self):
295
+ """Reset the conversation state."""
296
+ if self.use_advanced_model and self.conversation_model:
297
+ self.conversation_model.reset_conversation()
298
+ print("Conversation reset successfully!")
299
 
300
  def save_memories(self):
301
  """Save memories to a file."""