Spaces:
Running
Running
chore: update something
Browse files- lightweight_embeddings/router.py +79 -14
lightweight_embeddings/router.py
CHANGED
|
@@ -2,10 +2,11 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
import logging
|
| 4 |
import os
|
|
|
|
| 5 |
from datetime import datetime
|
| 6 |
from typing import Dict, List, Union
|
| 7 |
|
| 8 |
-
from fastapi import APIRouter, BackgroundTasks, HTTPException, Header
|
| 9 |
from pydantic import BaseModel, Field
|
| 10 |
|
| 11 |
from .analytics import Analytics
|
|
@@ -114,30 +115,94 @@ analytics = Analytics(
|
|
| 114 |
sync_interval=30 * 60, # 30 minutes
|
| 115 |
)
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
@router.post("/embeddings", response_model=EmbeddingResponse, tags=["embeddings"])
|
| 119 |
async def create_embeddings(
|
| 120 |
-
request: EmbeddingRequest,
|
| 121 |
background_tasks: BackgroundTasks,
|
| 122 |
-
|
|
|
|
| 123 |
):
|
| 124 |
"""
|
| 125 |
Generate embeddings for the given text or image inputs.
|
| 126 |
"""
|
| 127 |
# Check authorization
|
| 128 |
expected_token = os.environ.get("ACCESS_TOKEN")
|
|
|
|
|
|
|
| 129 |
if expected_token:
|
| 130 |
-
if
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
token
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
try:
|
| 142 |
modality = detect_model_kind(request.model)
|
| 143 |
embeddings = await embeddings_service.generate_embeddings(
|
|
|
|
| 2 |
|
| 3 |
import logging
|
| 4 |
import os
|
| 5 |
+
import time
|
| 6 |
from datetime import datetime
|
| 7 |
from typing import Dict, List, Union
|
| 8 |
|
| 9 |
+
from fastapi import APIRouter, BackgroundTasks, HTTPException, Header, Request
|
| 10 |
from pydantic import BaseModel, Field
|
| 11 |
|
| 12 |
from .analytics import Analytics
|
|
|
|
| 115 |
sync_interval=30 * 60, # 30 minutes
|
| 116 |
)
|
| 117 |
|
| 118 |
+
# Rate limiting cache: {ip: [timestamp1, timestamp2, ...]}
|
| 119 |
+
rate_limit_cache: Dict[str, List[float]] = {}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def check_rate_limit(
|
| 123 |
+
client_ip: str, max_requests: int = 4, window_seconds: int = 60
|
| 124 |
+
) -> bool:
|
| 125 |
+
"""
|
| 126 |
+
Check if the client IP has exceeded the rate limit.
|
| 127 |
+
Returns True if request is allowed, False if rate limited.
|
| 128 |
+
"""
|
| 129 |
+
current_time = time.time()
|
| 130 |
+
|
| 131 |
+
# Clean up old entries and get current requests
|
| 132 |
+
if client_ip in rate_limit_cache:
|
| 133 |
+
# Remove requests older than the window
|
| 134 |
+
rate_limit_cache[client_ip] = [
|
| 135 |
+
timestamp
|
| 136 |
+
for timestamp in rate_limit_cache[client_ip]
|
| 137 |
+
if current_time - timestamp < window_seconds
|
| 138 |
+
]
|
| 139 |
+
else:
|
| 140 |
+
rate_limit_cache[client_ip] = []
|
| 141 |
+
|
| 142 |
+
# Check if under limit
|
| 143 |
+
if len(rate_limit_cache[client_ip]) < max_requests:
|
| 144 |
+
# Add current request timestamp
|
| 145 |
+
rate_limit_cache[client_ip].append(current_time)
|
| 146 |
+
return True
|
| 147 |
+
|
| 148 |
+
return False
|
| 149 |
+
|
| 150 |
|
| 151 |
@router.post("/embeddings", response_model=EmbeddingResponse, tags=["embeddings"])
|
| 152 |
async def create_embeddings(
|
| 153 |
+
request: EmbeddingRequest,
|
| 154 |
background_tasks: BackgroundTasks,
|
| 155 |
+
fastapi_request: Request,
|
| 156 |
+
authorization: str = Header(None),
|
| 157 |
):
|
| 158 |
"""
|
| 159 |
Generate embeddings for the given text or image inputs.
|
| 160 |
"""
|
| 161 |
# Check authorization
|
| 162 |
expected_token = os.environ.get("ACCESS_TOKEN")
|
| 163 |
+
is_authenticated = False
|
| 164 |
+
|
| 165 |
if expected_token:
|
| 166 |
+
if authorization:
|
| 167 |
+
# Support both "Bearer <token>" and plain token formats
|
| 168 |
+
token = authorization
|
| 169 |
+
if authorization.startswith("Bearer "):
|
| 170 |
+
token = authorization[7:] # Remove "Bearer " prefix
|
| 171 |
+
|
| 172 |
+
if token == expected_token:
|
| 173 |
+
is_authenticated = True
|
| 174 |
+
|
| 175 |
+
# If not authenticated, check rate limit
|
| 176 |
+
if not is_authenticated:
|
| 177 |
+
# Get client IP
|
| 178 |
+
client_ip = fastapi_request.client.host
|
| 179 |
+
if hasattr(fastapi_request.headers, "get"):
|
| 180 |
+
# Check for forwarded IP (in case of proxy)
|
| 181 |
+
forwarded_for = fastapi_request.headers.get("X-Forwarded-For")
|
| 182 |
+
if forwarded_for:
|
| 183 |
+
client_ip = forwarded_for.split(",")[0].strip()
|
| 184 |
+
|
| 185 |
+
real_ip = fastapi_request.headers.get("X-Real-IP")
|
| 186 |
+
if real_ip:
|
| 187 |
+
client_ip = real_ip.strip()
|
| 188 |
+
|
| 189 |
+
# Check rate limit (4 requests per minute)
|
| 190 |
+
if not check_rate_limit(client_ip):
|
| 191 |
+
raise HTTPException(
|
| 192 |
+
status_code=429,
|
| 193 |
+
detail="Rate limit exceeded. Maximum 4 requests per minute for unauthenticated users.",
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# If no authorization header was provided when ACCESS_TOKEN is set
|
| 197 |
+
if not authorization:
|
| 198 |
+
raise HTTPException(
|
| 199 |
+
status_code=401, detail="Authorization header required"
|
| 200 |
+
)
|
| 201 |
+
else:
|
| 202 |
+
raise HTTPException(
|
| 203 |
+
status_code=401, detail="Invalid authorization token"
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
try:
|
| 207 |
modality = detect_model_kind(request.model)
|
| 208 |
embeddings = await embeddings_service.generate_embeddings(
|