Spaces:
Running
Running
import os | |
import datetime | |
from bson.objectid import ObjectId | |
import logging | |
from pymongo import MongoClient, DESCENDING | |
from dotenv import load_dotenv | |
import copy | |
# Logger shared by every function and method in this module.
logger = logging.getLogger(__name__)

# Pull variables from a .env file into the process environment; this must
# happen before the settings below are read from it.
load_dotenv()

# MongoDB connection settings: environment values win, otherwise local defaults.
MONGO_URI = os.environ.get("MONGO_URI", "mongodb://localhost:27017/")
DATABASE_NAME = os.environ.get("MONGO_DB_NAME", "nutribot_db")

# Cached client and database handle, filled in lazily by get_db().
_mongo_client = _db = None
def get_db():
    """Return the shared MongoDB database handle, creating it on first call.

    The ``MongoClient`` and the ``DATABASE_NAME`` database object are cached in
    the module globals ``_mongo_client`` / ``_db`` and reused by later calls.

    Returns:
        The database object for ``DATABASE_NAME``.

    Raises:
        Whatever ``MongoClient(...)`` or the database lookup raises. The error
        is logged and re-raised, the globals stay unset, and a later call
        retries the initialisation.
    """
    global _mongo_client, _db
    if _db is None:
        client = None
        try:
            client = MongoClient(MONGO_URI)
            database = client[DATABASE_NAME]
        except Exception as e:
            logger.error(f"Lỗi kết nối MongoDB: {e}")
            if client is not None:
                # Release the half-initialised client instead of leaking it.
                client.close()
            raise
        # Publish both globals together. Previously a failing database lookup
        # left _mongo_client set but _db None, and every later call then
        # returned None without retrying.
        _mongo_client, _db = client, database
        # NOTE(review): pymongo's MongoClient connects in the background, so this
        # message does not by itself prove the server is reachable.
        logger.info(f"Đã kết nối đến database: {DATABASE_NAME}")
    return _db
def safe_isoformat(timestamp_obj):
    """Convert a timestamp-like value to a string without raising ordinary errors.

    Args:
        timestamp_obj: ``None``, a string, any object with an ``isoformat()``
            method (e.g. ``datetime``), or any other value.

    Returns:
        ``None`` for ``None``. Strings are returned unchanged. Objects with
        ``isoformat()`` return its result. Anything else returns
        ``str(timestamp_obj)``, or ``None`` if that conversion raises.
    """
    if timestamp_obj is None:
        return None
    if isinstance(timestamp_obj, str):
        return timestamp_obj
    if hasattr(timestamp_obj, 'isoformat'):
        return timestamp_obj.isoformat()
    try:
        return str(timestamp_obj)
    except Exception:
        # Best-effort fallback for a broken __str__. Unlike the former bare
        # ``except:``, this no longer swallows KeyboardInterrupt / SystemExit.
        return None
def safe_datetime(timestamp_obj):
    """Coerce a timestamp-like value to a ``datetime.datetime``.

    Args:
        timestamp_obj: a ``datetime``, an ISO-8601 string (a ``'Z'`` suffix is
            accepted), ``None``, or any other value.

    Returns:
        A ``datetime`` is returned unchanged. A parseable string becomes the
        parsed ``datetime``. ``None``, unparseable strings and any other type
        fall back to ``datetime.datetime.now()``.

    NOTE(review): strings with an offset parse to timezone-aware datetimes,
    while the ``now()`` fallback is naive local time. Comparing the two raises
    ``TypeError``; confirm whether callers mix them.
    """
    if isinstance(timestamp_obj, datetime.datetime):
        return timestamp_obj
    if isinstance(timestamp_obj, str):
        try:
            # fromisoformat() only accepts a 'Z' designator from Python 3.11,
            # so spell UTC as an explicit offset.
            return datetime.datetime.fromisoformat(timestamp_obj.replace('Z', '+00:00'))
        except ValueError:
            # Unparseable text: fall through to the "now" fallback instead of
            # the former bare ``except:``.
            pass
    return datetime.datetime.now()
class Conversation:
    """A chat conversation stored as one document in the ``conversations`` collection.

    Messages live in ``self.messages`` as plain dicts (``_id``, ``role``,
    ``content``, ``timestamp``, ``versions``, ``current_version``, ...).

    Editing a user message (``edit_message``) or regenerating a bot reply
    (``regenerate_response``) appends a new entry to that message's
    ``versions`` list. Each version keeps a serialized snapshot of the messages
    that followed it (``following_messages``), so ``switch_message_version``
    can restore that branch of the conversation.
    """

    def __init__(self, user_id=None, title="Cuộc trò chuyện mới", age_context=None,
                 created_at=None, updated_at=None, is_archived=False, messages=None,
                 conversation_id=None):
        """Build an in-memory conversation; nothing is written to the database.

        ``created_at`` / ``updated_at`` go through ``safe_datetime``, so
        ``None`` or unparseable values become the current time. ``messages``
        defaults to a fresh empty list.
        """
        self.conversation_id = conversation_id
        self.user_id = user_id
        self.title = title
        self.age_context = age_context
        self.created_at = safe_datetime(created_at)
        self.updated_at = safe_datetime(updated_at)
        self.is_archived = is_archived
        self.messages = messages or []

    def _find_message_index(self, message_id):
        """Return the index in ``self.messages`` of the message with ``message_id``, or ``None``.

        IDs are compared as strings, so an ``ObjectId`` matches its hex string.
        """
        target = str(message_id)
        for index, message in enumerate(self.messages):
            if str(message["_id"]) == target:
                return index
        return None

    @classmethod
    def create(cls, user_id, title="Cuộc trò chuyện mới", age_context=None):
        """Create a new conversation, insert it into the database and return its id.

        Args:
            user_id: owner id; a string is converted to ``ObjectId``.
            title: initial conversation title.
            age_context: stored as-is.

        Returns:
            The id returned by ``save()``.

        Raises:
            Re-raises any error from the ``ObjectId`` conversion or ``save()``
            after logging it.
        """
        try:
            if isinstance(user_id, str):
                user_id = ObjectId(user_id)
            conversation = cls(
                user_id=user_id,
                title=title,
                age_context=age_context
            )
            conversation_id = conversation.save()
            conversation.conversation_id = conversation_id
            logger.info(f"Created new conversation: {conversation_id}")
            return conversation_id
        except Exception as e:
            logger.error(f"Error creating conversation: {e}")
            raise

    def serialize_message_for_following(self, message):
        """Serialize ``message`` into an entry of a version's ``following_messages``.

        The message's ``_id`` and any keys not listed here (e.g.
        ``parent_message_id``) are not kept. Timestamps are coerced with
        ``safe_datetime``. ``sources``/``metadata`` and each version's own
        ``following_messages`` are copied by reference.
        """
        serialized = {
            "role": message["role"],
            "content": message["content"],
            "timestamp": safe_datetime(message.get("timestamp")),
            "current_version": message.get("current_version", 1),
            "is_edited": message.get("is_edited", False)
        }
        if "sources" in message:
            serialized["sources"] = message["sources"]
        if "metadata" in message:
            serialized["metadata"] = message["metadata"]
        if "versions" in message and message["versions"]:
            serialized["versions"] = []
            for version in message["versions"]:
                version_data = {
                    "version": version["version"],
                    "content": version["content"],
                    "timestamp": safe_datetime(version.get("timestamp")),
                    "following_messages": version.get("following_messages", [])
                }
                if "sources" in version:
                    version_data["sources"] = version["sources"]
                if "metadata" in version:
                    version_data["metadata"] = version["metadata"]
                serialized["versions"].append(version_data)
        return serialized

    def deserialize_message_from_following(self, serialized_message):
        """Rebuild a live message dict from a ``following_messages`` entry.

        A fresh ``ObjectId`` is assigned as ``_id``. If the entry carries no
        versions, a single version 1 is synthesised from its content,
        timestamp, sources and metadata.
        """
        message = {
            "_id": ObjectId(),
            "role": serialized_message["role"],
            "content": serialized_message["content"],
            "timestamp": safe_datetime(serialized_message.get("timestamp")),
            "current_version": serialized_message.get("current_version", 1),
            "is_edited": serialized_message.get("is_edited", False)
        }
        if "sources" in serialized_message:
            message["sources"] = serialized_message["sources"]
        if "metadata" in serialized_message:
            message["metadata"] = serialized_message["metadata"]
        if "versions" in serialized_message and serialized_message["versions"]:
            message["versions"] = []
            for version_data in serialized_message["versions"]:
                version = {
                    "version": version_data["version"],
                    "content": version_data["content"],
                    "timestamp": safe_datetime(version_data.get("timestamp")),
                    "following_messages": version_data.get("following_messages", [])
                }
                if "sources" in version_data:
                    version["sources"] = version_data["sources"]
                if "metadata" in version_data:
                    version["metadata"] = version_data["metadata"]
                message["versions"].append(version)
        else:
            # No stored versions: the current content becomes version 1.
            message["versions"] = [{
                "version": 1,
                "content": message["content"],
                "timestamp": message["timestamp"],
                "following_messages": []
            }]
            if "sources" in message:
                message["versions"][0]["sources"] = message["sources"]
            if "metadata" in message:
                message["versions"][0]["metadata"] = message["metadata"]
        return message

    def capture_following_messages(self, from_message_index):
        """Return serialized copies of every message after ``from_message_index``."""
        following_messages = [
            self.serialize_message_for_following(self.messages[i])
            for i in range(from_message_index + 1, len(self.messages))
        ]
        logger.info(f"Captured {len(following_messages)} following messages from index {from_message_index}")
        return following_messages

    def restore_following_messages(self, target_message_index, following_messages):
        """Replace the messages after ``target_message_index`` with deserialized snapshots.

        Returns:
            The number of messages appended.
        """
        # Cut the conversation right after the target message.
        self.messages = self.messages[:target_message_index + 1]
        restored_count = 0
        for serialized_message in following_messages:
            self.messages.append(self.deserialize_message_from_following(serialized_message))
            restored_count += 1
        logger.info(f"Restored {restored_count} following messages after index {target_message_index}")
        return restored_count

    def to_dict(self):
        """Convert the conversation to a JSON-serialisable dict.

        Ids are stringified and timestamps rendered with ``safe_isoformat``, so a
        missing or ``None`` timestamp becomes ``None`` rather than ``"None"``.
        Message ``metadata`` and versions' ``following_messages`` are not
        included.

        Returns:
            The dict, or ``None`` (after logging) if any message is malformed.
        """
        try:
            return {
                "id": str(self.conversation_id),
                "user_id": str(self.user_id),
                "title": self.title,
                "age_context": self.age_context,
                "created_at": safe_isoformat(self.created_at),
                "updated_at": safe_isoformat(self.updated_at),
                "is_archived": self.is_archived,
                "messages": [self._message_to_dict(message) for message in self.messages]
            }
        except Exception as e:
            logger.error(f"Lỗi khi convert conversation to dict: {e}")
            return None

    def _message_to_dict(self, message):
        """JSON-friendly view of one message (used by ``to_dict``)."""
        message_id = str(message["_id"])
        message_data = {
            "id": message_id,
            "_id": message_id,
            "role": message["role"],
            "content": message["content"],
            "timestamp": safe_isoformat(message.get("timestamp")),
            "current_version": message.get("current_version", 1),
            "is_edited": message.get("is_edited", False)
        }
        if message.get("versions"):
            message_data["versions"] = []
            for version in message["versions"]:
                version_data = {
                    "content": version["content"],
                    "timestamp": safe_isoformat(version.get("timestamp")),
                    "version": version["version"]
                }
                if "sources" in version:
                    version_data["sources"] = version["sources"]
                message_data["versions"].append(version_data)
        if "sources" in message:
            message_data["sources"] = message["sources"]
        return message_data

    @classmethod
    def from_dict(cls, conversation_dict):
        """Build a ``Conversation`` from a stored document; ``None`` for an empty/missing one."""
        if not conversation_dict:
            return None
        return cls(
            conversation_id=conversation_dict.get("_id"),
            user_id=conversation_dict.get("user_id"),
            title=conversation_dict.get("title"),
            age_context=conversation_dict.get("age_context"),
            created_at=conversation_dict.get("created_at"),
            updated_at=conversation_dict.get("updated_at"),
            is_archived=conversation_dict.get("is_archived", False),
            messages=conversation_dict.get("messages", [])
        )

    def save(self):
        """Insert or update this conversation in the database.

        Sets ``updated_at`` to now. The stored message list uses copies whose
        message and version timestamps are coerced with ``safe_datetime``;
        ``self.messages`` itself is not modified. Without a
        ``conversation_id`` the document is inserted and the generated id is
        recorded; otherwise the existing document is ``$set``.

        Returns:
            The conversation id.

        Raises:
            Re-raises any database error after logging it.
        """
        try:
            db = get_db()
            conversations_collection = db.conversations
            self.updated_at = datetime.datetime.now()
            save_dict = {
                "user_id": self.user_id,
                "title": self.title,
                "age_context": self.age_context,
                "created_at": self.created_at,
                "updated_at": self.updated_at,
                "is_archived": self.is_archived,
                "messages": [self._message_for_storage(message) for message in self.messages]
            }
            if not self.conversation_id:
                insert_result = conversations_collection.insert_one(save_dict)
                self.conversation_id = insert_result.inserted_id
                logger.info(f"Saved new conversation with ID: {self.conversation_id}")
                return self.conversation_id
            update_result = conversations_collection.update_one(
                {"_id": self.conversation_id},
                {"$set": save_dict}
            )
            # matched_count is only available for acknowledged writes.
            if update_result.acknowledged and update_result.matched_count == 0:
                logger.warning(f"No conversation matched _id {self.conversation_id}; nothing was updated")
            logger.info(f"Updated conversation: {self.conversation_id}")
            return self.conversation_id
        except Exception as e:
            logger.error(f"Error saving conversation: {e}")
            raise

    @staticmethod
    def _message_for_storage(message):
        """Copy of ``message`` with its own and its versions' timestamps coerced to datetime.

        The version dicts are copied as well. The previous shallow copy shared
        them, so saving silently rewrote timestamps inside ``self.messages``.
        """
        stored = message.copy()
        stored["timestamp"] = safe_datetime(stored.get("timestamp"))
        if "versions" in stored:
            stored["versions"] = [
                {**version, "timestamp": safe_datetime(version.get("timestamp"))}
                for version in stored["versions"]
            ]
        return stored

    def add_message(self, role, content, sources=None, metadata=None, parent_message_id=None):
        """Append a new message (with its initial version 1) and save.

        Returns:
            The new message's ``_id``.
        """
        timestamp = datetime.datetime.now()
        message = {
            "_id": ObjectId(),
            "role": role,
            "content": content,
            "timestamp": timestamp,
            "versions": [{
                "content": content,
                "timestamp": timestamp,
                "version": 1,
                "following_messages": []
            }],
            "current_version": 1,
            "parent_message_id": parent_message_id,
            "is_edited": False
        }
        if sources:
            message["sources"] = sources
            message["versions"][0]["sources"] = sources
        if metadata:
            message["metadata"] = metadata
            message["versions"][0]["metadata"] = metadata
        self.messages.append(message)
        self.updated_at = timestamp
        self.save()
        logger.info(f"Added message to conversation {self.conversation_id}")
        return message["_id"]

    def edit_message(self, message_id, new_content):
        """Create a new version of a user message and drop everything after it.

        The messages currently following it are first stored in its active
        version's ``following_messages``, so switching back restores them.

        Returns:
            ``(success, status_text)``.
        """
        try:
            message_index = self._find_message_index(message_id)
            if message_index is None:
                return False, "Không tìm thấy tin nhắn"
            message = self.messages[message_index]
            if message["role"] != "user":
                return False, "Chỉ có thể chỉnh sửa tin nhắn của người dùng"
            timestamp = datetime.datetime.now()
            # Snapshot the current branch for the version being replaced.
            following_messages = self.capture_following_messages(message_index)
            if "versions" not in message:
                message["versions"] = [{
                    "content": message["content"],
                    "timestamp": safe_datetime(message.get("timestamp", timestamp)),
                    "version": 1,
                    "following_messages": following_messages
                }]
                if "sources" in message:
                    message["versions"][0]["sources"] = message["sources"]
                if "metadata" in message:
                    message["versions"][0]["metadata"] = message["metadata"]
            else:
                current_version_index = message.get("current_version", 1) - 1
                if 0 <= current_version_index < len(message["versions"]):
                    message["versions"][current_version_index]["following_messages"] = following_messages
            new_version = len(message["versions"]) + 1
            message["versions"].append({
                "content": new_content,
                "timestamp": timestamp,
                "version": new_version,
                "following_messages": []
            })
            message["current_version"] = new_version
            message["content"] = new_content
            message["is_edited"] = True
            # The old branch is preserved in the snapshot; drop it from the live list.
            self.messages = self.messages[:message_index + 1]
            self.updated_at = timestamp
            self.save()
            logger.info(f"Edited message {message_id}, created version {new_version} with {len(following_messages)} following messages")
            return True, "Đã chỉnh sửa tin nhắn thành công"
        except Exception as e:
            logger.error(f"Error editing message: {e}")
            return False, f"Lỗi: {str(e)}"

    def regenerate_bot_response_after_edit(self, user_message_id, new_response, sources=None):
        """Append a bot reply after an edited user message and save.

        The reply is appended at the end of ``self.messages``; this assumes
        ``edit_message`` has already truncated the conversation at the user
        message. The user message's active version then re-captures its
        following messages so it includes the new reply.

        Returns:
            ``(success, status_text)``.
        """
        try:
            user_message_index = self._find_message_index(user_message_id)
            if user_message_index is None:
                return False, "Không tìm thấy tin nhắn user"
            timestamp = datetime.datetime.now()
            bot_message = {
                "_id": ObjectId(),
                "role": "bot",
                "content": new_response,
                "timestamp": timestamp,
                "versions": [{
                    "content": new_response,
                    "timestamp": timestamp,
                    "version": 1,
                    "following_messages": []
                }],
                "current_version": 1,
                "parent_message_id": self.messages[user_message_index]["_id"],
                "is_edited": False
            }
            if sources:
                bot_message["sources"] = sources
                bot_message["versions"][0]["sources"] = sources
            self.messages.append(bot_message)
            user_message = self.messages[user_message_index]
            if "versions" in user_message:
                current_version_index = user_message.get("current_version", 1) - 1
                if 0 <= current_version_index < len(user_message["versions"]):
                    following_messages = self.capture_following_messages(user_message_index)
                    user_message["versions"][current_version_index]["following_messages"] = following_messages
            self.updated_at = timestamp
            self.save()
            logger.info(f"Added bot response after edit for user message {user_message_id}")
            return True, "Đã tạo phản hồi mới"
        except Exception as e:
            logger.error(f"Error adding bot response: {e}")
            return False, f"Lỗi: {str(e)}"

    def regenerate_response(self, message_id, new_response, sources=None):
        """Add a new version to a bot reply and drop everything after it.

        The messages currently following the reply are first stored in its
        active version's ``following_messages``. The message's ``sources`` are
        then made to match the new version: set when ``sources`` is given,
        removed otherwise (as ``switch_message_version`` does) so stale
        citations are not shown with new content.

        Returns:
            ``(success, status_text)``.
        """
        try:
            message_index = self._find_message_index(message_id)
            if message_index is None:
                return False, "Không tìm thấy tin nhắn"
            message = self.messages[message_index]
            if message["role"] != "bot":
                return False, "Chỉ có thể regenerate phản hồi của bot"
            timestamp = datetime.datetime.now()
            following_messages = self.capture_following_messages(message_index)
            if "versions" not in message:
                message["versions"] = [{
                    "content": message["content"],
                    "timestamp": safe_datetime(message.get("timestamp", timestamp)),
                    "version": 1,
                    "sources": message.get("sources", []),
                    "following_messages": following_messages
                }]
                if "metadata" in message:
                    message["versions"][0]["metadata"] = message["metadata"]
            else:
                current_version_index = message.get("current_version", 1) - 1
                if 0 <= current_version_index < len(message["versions"]):
                    message["versions"][current_version_index]["following_messages"] = following_messages
            new_version = len(message["versions"]) + 1
            new_version_data = {
                "content": new_response,
                "timestamp": timestamp,
                "version": new_version,
                "following_messages": []
            }
            if sources:
                new_version_data["sources"] = sources
            message["versions"].append(new_version_data)
            message["current_version"] = new_version
            message["content"] = new_response
            message["is_edited"] = True
            if sources:
                message["sources"] = sources
            else:
                message.pop("sources", None)
            # NOTE(review): message-level "metadata" still belongs to the previous
            # version (switch_message_version would drop it); confirm intent.
            self.messages = self.messages[:message_index + 1]
            self.updated_at = timestamp
            self.save()
            logger.info(f"Regenerated response for message {message_id}, version {new_version} with {len(following_messages)} following messages")
            return True, "Đã tạo phản hồi mới thành công"
        except Exception as e:
            logger.error(f"Error regenerating response: {e}")
            return False, f"Lỗi: {str(e)}"

    def switch_message_version(self, message_id, version_number):
        """Make ``version_number`` the active version of a message and restore its branch.

        The messages currently following the message are first captured into
        the version being left, as ``edit_message`` and
        ``regenerate_response`` do. Previously they were discarded, losing any
        messages added since that version was last captured.

        Content, sources and metadata are then taken from the selected version.
        The messages after it are replaced by that version's
        ``following_messages``, or removed if it has none.

        Returns:
            ``True`` on success. ``False`` if the message or version does not
            exist or an error occurs; errors are logged.
        """
        try:
            message_index = self._find_message_index(message_id)
            if message_index is None:
                logger.error(f"Message not found: {message_id}")
                return False
            message = self.messages[message_index]
            if not message.get("versions") or version_number > len(message["versions"]) or version_number < 1:
                logger.error(f"Version {version_number} not found for message {message_id}")
                return False
            versions = message["versions"]
            current_version = message.get("current_version", 1)
            logger.info(f"Switching message {message_id} from version {current_version} to version {version_number}")
            current_version_index = current_version - 1
            if 0 <= current_version_index < len(versions):
                versions[current_version_index]["following_messages"] = self.capture_following_messages(message_index)
            selected_version = versions[version_number - 1]
            message["current_version"] = version_number
            message["content"] = selected_version["content"]
            if "sources" in selected_version:
                message["sources"] = selected_version["sources"]
            elif "sources" in message:
                del message["sources"]
            if "metadata" in selected_version:
                message["metadata"] = selected_version["metadata"]
            elif "metadata" in message:
                del message["metadata"]
            following_messages = selected_version.get("following_messages", [])
            if following_messages:
                restored_count = self.restore_following_messages(message_index, following_messages)
                logger.info(f"Restored {restored_count} following messages from version {version_number}")
            else:
                self.messages = self.messages[:message_index + 1]
                logger.info(f"No following messages in version {version_number}, truncated conversation at message index {message_index}")
            self.updated_at = datetime.datetime.now()
            self.save()
            logger.info(f"Successfully switched to version {version_number} for message {message_id}")
            return True
        except Exception as e:
            logger.error(f"Error switching message version: {e}")
            return False

    def delete_message_and_following(self, message_id):
        """Remove a message and every message after it, then save.

        Returns:
            ``(success, status_text)``.
        """
        try:
            message_index = self._find_message_index(message_id)
            if message_index is None:
                return False, "Không tìm thấy tin nhắn"
            self.messages = self.messages[:message_index]
            self.updated_at = datetime.datetime.now()
            self.save()
            return True, "Đã xóa tin nhắn và các tin nhắn sau nó"
        except Exception as e:
            logger.error(f"Error deleting message: {e}")
            return False, f"Lỗi: {str(e)}"

    def delete(self):
        """Delete this conversation's document.

        Returns:
            ``True`` after issuing the delete. ``False`` if there is no
            ``conversation_id`` or an error occurs; errors are logged.
        """
        try:
            if self.conversation_id:
                db = get_db()
                conversations_collection = db.conversations
                conversations_collection.delete_one({"_id": self.conversation_id})
                logger.info(f"Đã xóa cuộc hội thoại: {self.conversation_id}")
                return True
            return False
        except Exception as e:
            logger.error(f"Lỗi khi xóa cuộc hội thoại: {e}")
            return False

    @classmethod
    def find_by_id(cls, conversation_id):
        """Load a conversation by id (``ObjectId`` or hex string).

        Returns:
            The ``Conversation``, or ``None`` if not found, the id is invalid,
            or any error occurs; errors are logged.
        """
        try:
            db = get_db()
            conversations_collection = db.conversations
            if isinstance(conversation_id, str):
                conversation_id = ObjectId(conversation_id)
            conversation_dict = conversations_collection.find_one({"_id": conversation_id})
            if conversation_dict:
                return cls.from_dict(conversation_dict)
            return None
        except Exception as e:
            logger.error(f"Lỗi khi tìm cuộc hội thoại: {e}")
            return None

    @classmethod
    def find_by_user(cls, user_id, limit=50, skip=0, include_archived=False):
        """List a user's conversations, most recently updated first.

        Archived conversations are excluded unless ``include_archived`` is
        true. ``skip`` and ``limit`` are applied after sorting.

        Returns:
            A list of ``Conversation`` objects; ``[]`` on error (logged).
        """
        try:
            db = get_db()
            conversations_collection = db.conversations
            if isinstance(user_id, str):
                user_id = ObjectId(user_id)
            query_filter = {"user_id": user_id}
            if not include_archived:
                query_filter["is_archived"] = {"$ne": True}
            logger.info(f"Querying conversations with filter: {query_filter}, limit: {limit}, skip: {skip}")
            conversations_cursor = (
                conversations_collection.find(query_filter)
                .sort("updated_at", DESCENDING)
                .skip(skip)
                .limit(limit)
            )
            conversations_list = list(conversations_cursor)
            logger.info(f"Found {len(conversations_list)} conversations for user {user_id}")
            result = [
                conv_obj
                for conv_obj in (cls.from_dict(conv_dict) for conv_dict in conversations_list)
                if conv_obj
            ]
            logger.info(f"Returning {len(result)} conversation objects")
            return result
        except Exception as e:
            logger.error(f"Error finding conversations by user: {e}")
            return []

    @classmethod
    def count_by_user(cls, user_id, include_archived=False):
        """Count a user's conversations (archived ones only if ``include_archived``).

        Returns:
            The count, or ``0`` on error (logged).
        """
        try:
            db = get_db()
            conversations_collection = db.conversations
            if isinstance(user_id, str):
                user_id = ObjectId(user_id)
            query_filter = {"user_id": user_id}
            if not include_archived:
                query_filter["is_archived"] = {"$ne": True}
            count = conversations_collection.count_documents(query_filter)
            logger.info(f"User {user_id} has {count} conversations (include_archived: {include_archived})")
            return count
        except Exception as e:
            logger.error(f"Error counting conversations: {e}")
            return 0