File size: 10,020 Bytes
ead2510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import os
import uuid
import datetime
from typing import Dict, List, Optional, Any
from pymongo import MongoClient
from bson.objectid import ObjectId

class MongoDBHelper:
    """Helper class for MongoDB operations.

    Wraps a single ``MongoClient`` and exposes helpers over the ``api_keys``,
    ``usage``, ``users``, ``conversations`` and ``messages`` collections of
    one database. Indexes used by these helpers are created on construction.
    """

    def __init__(self, connection_string: Optional[str] = None,
                 db_name: str = "pyscout_ai"):
        """Initialize the MongoDB client and ensure indexes.

        Args:
            connection_string: MongoDB URI. Falls back to the ``MONGODB_URI``
                environment variable when not given (or empty).
            db_name: Name of the database to use; defaults to ``"pyscout_ai"``.

        Raises:
            ValueError: If no connection string is given or configured.
        """
        # Get connection string from env var or use provided one
        self.connection_string = connection_string or os.getenv('MONGODB_URI')

        if not self.connection_string:
            raise ValueError("MongoDB connection string not provided. Set MONGODB_URI environment variable or pass it to constructor.")

        self.client = MongoClient(self.connection_string)
        self.db = self.client.get_database(db_name)

        # Collections
        self.api_keys_collection = self.db.api_keys
        self.usage_collection = self.db.usage
        self.users_collection = self.db.users
        self.conversations_collection = self.db.conversations
        self.messages_collection = self.db.messages

        self._create_indexes()

    def close(self) -> None:
        """Close the underlying MongoClient and release its connections."""
        self.client.close()

    @staticmethod
    def _utcnow() -> datetime.datetime:
        """Return the current time as a timezone-aware UTC datetime.

        Replaces the deprecated ``datetime.utcnow()``. PyMongo stores aware
        datetimes as UTC, so persisted values are the same as before.
        """
        return datetime.datetime.now(datetime.timezone.utc)

    def _create_indexes(self) -> None:
        """Create the indexes backing the lookups and sorts in this class."""
        # API Keys indexes
        self.api_keys_collection.create_index("key", unique=True)
        self.api_keys_collection.create_index("user_id")
        self.api_keys_collection.create_index("created_at")

        # Usage indexes
        self.usage_collection.create_index("api_key")
        self.usage_collection.create_index("timestamp")
        # check_rate_limit filters on api_key AND a timestamp range.
        self.usage_collection.create_index([("api_key", 1), ("timestamp", 1)])

        # Users indexes
        self.users_collection.create_index("email", unique=True)

        # Conversations indexes
        self.conversations_collection.create_index("user_id")
        self.conversations_collection.create_index("created_at")
        # get_user_conversations filters on user_id and sorts by last_message_at.
        self.conversations_collection.create_index(
            [("user_id", 1), ("last_message_at", -1)]
        )

        # Messages indexes
        self.messages_collection.create_index("conversation_id")
        self.messages_collection.create_index("timestamp")
        # get_conversation_history filters on conversation_id, sorts by timestamp.
        self.messages_collection.create_index(
            [("conversation_id", 1), ("timestamp", 1)]
        )

    def create_user(self, email: str, name: str,
                    organization: Optional[str] = None) -> str:
        """Insert a new user document and return its id as a hex string.

        The ``email`` field has a unique index, so inserting an existing email
        makes ``insert_one`` raise.
        """
        user_oid = ObjectId()
        now = self._utcnow()  # one timestamp so created_at == last_active
        self.users_collection.insert_one({
            "_id": user_oid,
            "email": email,
            "name": name,
            "organization": organization,
            "created_at": now,
            "last_active": now,
        })
        return str(user_oid)

    def create_conversation(self, user_id: str,
                            system_prompt: Optional[str] = None) -> str:
        """Insert a new active conversation for ``user_id``; return its id string."""
        conversation_oid = ObjectId()
        now = self._utcnow()
        self.conversations_collection.insert_one({
            "_id": conversation_oid,
            "user_id": user_id,
            "system_prompt": system_prompt,
            "created_at": now,
            "last_message_at": now,
            "is_active": True,
        })
        return str(conversation_oid)

    def add_message(self, conversation_id: str, role: str, content: str,
                    model: Optional[str] = None, tokens: int = 0) -> str:
        """Append a message to a conversation and bump its ``last_message_at``.

        Args:
            conversation_id: Hex string id of the conversation. The message
                stores it as a string; the conversation is looked up by ObjectId.
            role, content, model, tokens: Stored verbatim on the message.

        Returns:
            The new message id as a hex string.

        Raises:
            bson.errors.InvalidId: If ``conversation_id`` is not a valid
                ObjectId string. This is raised before anything is written.
        """
        # Validate/convert first so a bad id does not leave an orphan message.
        conversation_oid = ObjectId(conversation_id)
        message_oid = ObjectId()
        now = self._utcnow()
        self.messages_collection.insert_one({
            "_id": message_oid,
            "conversation_id": conversation_id,
            "role": role,
            "content": content,
            "model": model,
            "tokens": tokens,
            "timestamp": now,
        })

        # Update conversation last_message_at
        self.conversations_collection.update_one(
            {"_id": conversation_oid},
            {"$set": {"last_message_at": now}}
        )

        return str(message_oid)

    def get_conversation_history(self, conversation_id: str) -> List[Dict]:
        """Return all messages of a conversation, oldest first, without ``_id``."""
        return list(self.messages_collection.find(
            {"conversation_id": conversation_id},
            {"_id": 0}
        ).sort("timestamp", 1))

    def get_user_conversations(self, user_id: str, limit: int = 10) -> List[Dict]:
        """Return up to ``limit`` of the user's most recently active conversations.

        Each entry has ``_id`` (converted to str), ``system_prompt``,
        ``created_at`` and ``last_message_at``.
        """
        conversations = list(self.conversations_collection.find(
            {"user_id": user_id},
            {"_id": 1, "system_prompt": 1, "created_at": 1, "last_message_at": 1}
        ).sort("last_message_at", -1).limit(limit))

        # Convert ObjectId to string for JSON serialization
        for conv in conversations:
            conv["_id"] = str(conv["_id"])
        return conversations

    def generate_api_key(self, user_id: str, name: str = "Default API Key") -> str:
        """Generate, store and return a new active API key for a user.

        Keys have the form ``PyScoutAI-<32 hex chars>`` and default limits of
        1000 requests/day and 1,000,000 tokens/day.
        """
        # NOTE(review): keys are stored in plaintext; consider storing a hash.
        api_key = f"PyScoutAI-{uuid.uuid4().hex}"

        self.api_keys_collection.insert_one({
            "key": api_key,
            "user_id": user_id,
            "name": name,
            "created_at": self._utcnow(),
            "last_used": None,
            "is_active": True,
            "rate_limit": {
                "requests_per_day": 1000,
                "tokens_per_day": 1000000
            }
        })

        return api_key

    def validate_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Validate an API key and record its use.

        Returns:
            The raw key document (``_id`` still an ObjectId) if the key exists
            and is active, otherwise ``None``. On success, ``last_used`` is
            updated to the current time.
        """
        if not api_key:
            return None

        key_data = self.api_keys_collection.find_one({"key": api_key, "is_active": True})
        if not key_data:
            return None

        self.api_keys_collection.update_one(
            {"_id": key_data["_id"]},
            {"$set": {"last_used": self._utcnow()}}
        )

        return key_data

    def log_api_usage(self, api_key: str, endpoint: str, tokens: int = 0,
                      model: Optional[str] = None,
                      conversation_id: Optional[str] = None) -> None:
        """Record one API call in the usage collection.

        ``conversation_id`` is only stored when it is truthy.
        """
        usage_data = {
            "api_key": api_key,
            "endpoint": endpoint,
            "tokens": tokens,
            "model": model,
            "timestamp": self._utcnow()
        }
        if conversation_id:
            usage_data["conversation_id"] = conversation_id

        self.usage_collection.insert_one(usage_data)

    def get_user_api_keys(self, user_id: str) -> List[Dict[str, Any]]:
        """Return all API keys (active or not) for a user, ``_id`` as str."""
        keys = list(self.api_keys_collection.find({"user_id": user_id}))
        # Convert ObjectId to string for JSON serialization
        for key in keys:
            key["_id"] = str(key["_id"])
        return keys

    def revoke_api_key(self, api_key: str) -> bool:
        """Deactivate an API key.

        Returns:
            True if a document was modified. Returns False both for unknown
            keys and for keys that were already inactive.
        """
        result = self.api_keys_collection.update_one(
            {"key": api_key},
            {"$set": {"is_active": False}}
        )
        return result.modified_count > 0

    def check_rate_limit(self, api_key: str) -> Dict[str, Any]:
        """Check whether an API key has exceeded its daily limits.

        Usage is counted from the start of the current UTC day. The request
        limit is checked before the token limit.

        Returns:
            ``{"allowed": False, "reason": ...}``, plus ``limit``/``used`` when
            a limit is hit. Otherwise ``{"allowed": True, "requests": {...},
            "tokens": {...}}``, each with ``limit``, ``used`` and ``remaining``.
        """
        key_data = self.api_keys_collection.find_one({"key": api_key, "is_active": True})
        if not key_data:
            return {"allowed": False, "reason": "Invalid API key"}

        rate_limit = key_data.get("rate_limit", {})
        requests_per_day = rate_limit.get("requests_per_day", 1000)
        tokens_per_day = rate_limit.get("tokens_per_day", 1000000)

        # Midnight (UTC) of the current day.
        today_start = self._utcnow().replace(hour=0, minute=0, second=0, microsecond=0)

        # One round trip for both the request count and the token sum.
        usage_result = list(self.usage_collection.aggregate([
            {"$match": {"api_key": api_key, "timestamp": {"$gte": today_start}}},
            {"$group": {
                "_id": None,
                "requests": {"$sum": 1},
                "total_tokens": {"$sum": "$tokens"},
            }},
        ]))
        if usage_result:
            requests_today = usage_result[0]["requests"]
            tokens_today = usage_result[0]["total_tokens"]
        else:
            requests_today = 0
            tokens_today = 0

        if requests_today >= requests_per_day:
            return {
                "allowed": False,
                "reason": "Daily request limit exceeded",
                "limit": requests_per_day,
                "used": requests_today
            }

        if tokens_today >= tokens_per_day:
            return {
                "allowed": False,
                "reason": "Daily token limit exceeded",
                "limit": tokens_per_day,
                "used": tokens_today
            }

        return {
            "allowed": True,
            "requests": {
                "limit": requests_per_day,
                "used": requests_today,
                "remaining": requests_per_day - requests_today
            },
            "tokens": {
                "limit": tokens_per_day,
                "used": tokens_today,
                "remaining": tokens_per_day - tokens_today
            }
        }

    def get_user_stats(self, user_id: str) -> Dict:
        """Return conversation, message and token totals for a user.

        Conversation documents carry no per-conversation counters, so messages
        and tokens are aggregated from the messages collection. The filter is
        the user's conversation ids, which ``add_message`` stores as strings.

        Returns:
            ``{"total_conversations": int, "total_messages": int,
            "total_tokens": int}``
        """
        conversation_ids = [
            str(doc["_id"])
            for doc in self.conversations_collection.find({"user_id": user_id}, {"_id": 1})
        ]
        stats = {
            "total_conversations": len(conversation_ids),
            "total_messages": 0,
            "total_tokens": 0,
        }
        if not conversation_ids:
            return stats

        totals = list(self.messages_collection.aggregate([
            {"$match": {"conversation_id": {"$in": conversation_ids}}},
            {"$group": {
                "_id": None,
                "total_messages": {"$sum": 1},
                "total_tokens": {"$sum": "$tokens"},
            }},
        ]))
        if totals:
            stats["total_messages"] = totals[0]["total_messages"]
            stats["total_tokens"] = totals[0]["total_tokens"]
        return stats