🐛 Bug: Fix bugs caused by concurrent errors in multiple database write operations.
Browse files
main.py
CHANGED
@@ -282,40 +282,67 @@ def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> Decimal
|
|
282 |
# 返回精确到15位小数的结果
|
283 |
return total_cost.quantize(Decimal('0.000000000000001'))
|
284 |
|
|
|
|
|
|
|
|
|
|
|
285 |
async def update_stats(current_info):
|
286 |
if DISABLE_DATABASE:
|
287 |
return
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
|
301 |
async def update_channel_stats(request_id, provider, model, api_key, success):
|
302 |
if DISABLE_DATABASE:
|
303 |
return
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
|
320 |
class LoggingStreamingResponse(Response):
|
321 |
def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
|
|
|
282 |
# 返回精确到15位小数的结果
|
283 |
return total_cost.quantize(Decimal('0.000000000000001'))
|
284 |
|
285 |
+
from asyncio import Semaphore
|
286 |
+
|
287 |
+
# 创建一个信号量来控制数据库访问
|
288 |
+
db_semaphore = Semaphore(1) # 限制同时只有1个写入操作
|
289 |
+
|
290 |
async def update_stats(current_info):
|
291 |
if DISABLE_DATABASE:
|
292 |
return
|
293 |
+
|
294 |
+
try:
|
295 |
+
# 等待获取数据库访问权限
|
296 |
+
async with db_semaphore:
|
297 |
+
async with async_session() as session:
|
298 |
+
async with session.begin():
|
299 |
+
try:
|
300 |
+
columns = [column.key for column in RequestStat.__table__.columns]
|
301 |
+
filtered_info = {k: v for k, v in current_info.items() if k in columns}
|
302 |
+
new_request_stat = RequestStat(**filtered_info)
|
303 |
+
session.add(new_request_stat)
|
304 |
+
await session.commit()
|
305 |
+
except Exception as e:
|
306 |
+
await session.rollback()
|
307 |
+
logger.error(f"Error updating stats: {str(e)}")
|
308 |
+
if is_debug:
|
309 |
+
import traceback
|
310 |
+
traceback.print_exc()
|
311 |
+
except Exception as e:
|
312 |
+
logger.error(f"Error acquiring database lock: {str(e)}")
|
313 |
+
if is_debug:
|
314 |
+
import traceback
|
315 |
+
traceback.print_exc()
|
316 |
|
317 |
async def update_channel_stats(request_id, provider, model, api_key, success):
|
318 |
if DISABLE_DATABASE:
|
319 |
return
|
320 |
+
|
321 |
+
try:
|
322 |
+
async with db_semaphore:
|
323 |
+
async with async_session() as session:
|
324 |
+
async with session.begin():
|
325 |
+
try:
|
326 |
+
channel_stat = ChannelStat(
|
327 |
+
request_id=request_id,
|
328 |
+
provider=provider,
|
329 |
+
model=model,
|
330 |
+
api_key=api_key,
|
331 |
+
success=success,
|
332 |
+
)
|
333 |
+
session.add(channel_stat)
|
334 |
+
await session.commit()
|
335 |
+
except Exception as e:
|
336 |
+
await session.rollback()
|
337 |
+
logger.error(f"Error updating channel stats: {str(e)}")
|
338 |
+
if is_debug:
|
339 |
+
import traceback
|
340 |
+
traceback.print_exc()
|
341 |
+
except Exception as e:
|
342 |
+
logger.error(f"Error acquiring database lock: {str(e)}")
|
343 |
+
if is_debug:
|
344 |
+
import traceback
|
345 |
+
traceback.print_exc()
|
346 |
|
347 |
class LoggingStreamingResponse(Response):
|
348 |
def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
|