lamhieu commited on
Commit
3150894
·
1 Parent(s): c1964fd

chore: update something

Browse files
Files changed (1) hide show
  1. lightweight_embeddings/router.py +79 -14
lightweight_embeddings/router.py CHANGED
@@ -2,10 +2,11 @@ from __future__ import annotations
2
 
3
  import logging
4
  import os
 
5
  from datetime import datetime
6
  from typing import Dict, List, Union
7
 
8
- from fastapi import APIRouter, BackgroundTasks, HTTPException, Header
9
  from pydantic import BaseModel, Field
10
 
11
  from .analytics import Analytics
@@ -114,30 +115,94 @@ analytics = Analytics(
114
  sync_interval=30 * 60, # 30 minutes
115
  )
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  @router.post("/embeddings", response_model=EmbeddingResponse, tags=["embeddings"])
119
  async def create_embeddings(
120
- request: EmbeddingRequest,
121
  background_tasks: BackgroundTasks,
122
- authorization: str = Header(None)
 
123
  ):
124
  """
125
  Generate embeddings for the given text or image inputs.
126
  """
127
  # Check authorization
128
  expected_token = os.environ.get("ACCESS_TOKEN")
 
 
129
  if expected_token:
130
- if not authorization:
131
- raise HTTPException(status_code=401, detail="Authorization header required")
132
-
133
- # Support both "Bearer <token>" and plain token formats
134
- token = authorization
135
- if authorization.startswith("Bearer "):
136
- token = authorization[7:] # Remove "Bearer " prefix
137
-
138
- if token != expected_token:
139
- raise HTTPException(status_code=401, detail="Invalid authorization token")
140
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  try:
142
  modality = detect_model_kind(request.model)
143
  embeddings = await embeddings_service.generate_embeddings(
 
2
 
3
  import logging
4
  import os
5
+ import time
6
  from datetime import datetime
7
  from typing import Dict, List, Union
8
 
9
+ from fastapi import APIRouter, BackgroundTasks, HTTPException, Header, Request
10
  from pydantic import BaseModel, Field
11
 
12
  from .analytics import Analytics
 
115
  sync_interval=30 * 60, # 30 minutes
116
  )
117
 
118
+ # Rate limiting cache: {ip: [timestamp1, timestamp2, ...]}
119
+ rate_limit_cache: Dict[str, List[float]] = {}
120
+
121
+
122
+ def check_rate_limit(
123
+ client_ip: str, max_requests: int = 4, window_seconds: int = 60
124
+ ) -> bool:
125
+ """
126
+ Check if the client IP has exceeded the rate limit.
127
+ Returns True if request is allowed, False if rate limited.
128
+ """
129
+ current_time = time.time()
130
+
131
+ # Clean up old entries and get current requests
132
+ if client_ip in rate_limit_cache:
133
+ # Remove requests older than the window
134
+ rate_limit_cache[client_ip] = [
135
+ timestamp
136
+ for timestamp in rate_limit_cache[client_ip]
137
+ if current_time - timestamp < window_seconds
138
+ ]
139
+ else:
140
+ rate_limit_cache[client_ip] = []
141
+
142
+ # Check if under limit
143
+ if len(rate_limit_cache[client_ip]) < max_requests:
144
+ # Add current request timestamp
145
+ rate_limit_cache[client_ip].append(current_time)
146
+ return True
147
+
148
+ return False
149
+
150
 
151
  @router.post("/embeddings", response_model=EmbeddingResponse, tags=["embeddings"])
152
  async def create_embeddings(
153
+ request: EmbeddingRequest,
154
  background_tasks: BackgroundTasks,
155
+ fastapi_request: Request,
156
+ authorization: str = Header(None),
157
  ):
158
  """
159
  Generate embeddings for the given text or image inputs.
160
  """
161
  # Check authorization
162
  expected_token = os.environ.get("ACCESS_TOKEN")
163
+ is_authenticated = False
164
+
165
  if expected_token:
166
+ if authorization:
167
+ # Support both "Bearer <token>" and plain token formats
168
+ token = authorization
169
+ if authorization.startswith("Bearer "):
170
+ token = authorization[7:] # Remove "Bearer " prefix
171
+
172
+ if token == expected_token:
173
+ is_authenticated = True
174
+
175
+ # If not authenticated, check rate limit
176
+ if not is_authenticated:
177
+ # Get client IP
178
+ client_ip = fastapi_request.client.host
179
+ if hasattr(fastapi_request.headers, "get"):
180
+ # Check for forwarded IP (in case of proxy)
181
+ forwarded_for = fastapi_request.headers.get("X-Forwarded-For")
182
+ if forwarded_for:
183
+ client_ip = forwarded_for.split(",")[0].strip()
184
+
185
+ real_ip = fastapi_request.headers.get("X-Real-IP")
186
+ if real_ip:
187
+ client_ip = real_ip.strip()
188
+
189
+ # Check rate limit (4 requests per minute)
190
+ if not check_rate_limit(client_ip):
191
+ raise HTTPException(
192
+ status_code=429,
193
+ detail="Rate limit exceeded. Maximum 4 requests per minute for unauthenticated users.",
194
+ )
195
+
196
+ # If no authorization header was provided when ACCESS_TOKEN is set
197
+ if not authorization:
198
+ raise HTTPException(
199
+ status_code=401, detail="Authorization header required"
200
+ )
201
+ else:
202
+ raise HTTPException(
203
+ status_code=401, detail="Invalid authorization token"
204
+ )
205
+
206
  try:
207
  modality = detect_model_kind(request.model)
208
  embeddings = await embeddings_service.generate_embeddings(