yym68686 commited on
Commit
2f00898
·
1 Parent(s): 2bfe3df

🐛 Bug: Fix bugs caused by concurrent errors in multiple database write operations.

Browse files
Files changed (1) hide show
  1. main.py +54 -27
main.py CHANGED
@@ -282,40 +282,67 @@ def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> Decimal
282
  # 返回精确到15位小数的结果
283
  return total_cost.quantize(Decimal('0.000000000000001'))
284
 
 
 
 
 
 
285
  async def update_stats(current_info):
286
  if DISABLE_DATABASE:
287
  return
288
- # 这里添加更新数据库的逻辑
289
- async with async_session() as session:
290
- async with session.begin():
291
- try:
292
- columns = [column.key for column in RequestStat.__table__.columns]
293
- filtered_info = {k: v for k, v in current_info.items() if k in columns}
294
- new_request_stat = RequestStat(**filtered_info)
295
- session.add(new_request_stat)
296
- await session.commit()
297
- except Exception as e:
298
- await session.rollback()
299
- logger.error(f"Error updating stats: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
300
 
301
  async def update_channel_stats(request_id, provider, model, api_key, success):
302
  if DISABLE_DATABASE:
303
  return
304
- async with async_session() as session:
305
- async with session.begin():
306
- try:
307
- channel_stat = ChannelStat(
308
- request_id=request_id,
309
- provider=provider,
310
- model=model,
311
- api_key=api_key,
312
- success=success,
313
- )
314
- session.add(channel_stat)
315
- await session.commit()
316
- except Exception as e:
317
- await session.rollback()
318
- logger.error(f"Error updating channel stats: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
319
 
320
  class LoggingStreamingResponse(Response):
321
  def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
 
282
  # 返回精确到15位小数的结果
283
  return total_cost.quantize(Decimal('0.000000000000001'))
284
 
285
+ from asyncio import Semaphore
286
+
287
+ # 创建一个信号量来控制数据库访问
288
+ db_semaphore = Semaphore(1) # 限制同时只有1个写入操作
289
+
290
  async def update_stats(current_info):
291
  if DISABLE_DATABASE:
292
  return
293
+
294
+ try:
295
+ # 等待获取数据库访问权限
296
+ async with db_semaphore:
297
+ async with async_session() as session:
298
+ async with session.begin():
299
+ try:
300
+ columns = [column.key for column in RequestStat.__table__.columns]
301
+ filtered_info = {k: v for k, v in current_info.items() if k in columns}
302
+ new_request_stat = RequestStat(**filtered_info)
303
+ session.add(new_request_stat)
304
+ await session.commit()
305
+ except Exception as e:
306
+ await session.rollback()
307
+ logger.error(f"Error updating stats: {str(e)}")
308
+ if is_debug:
309
+ import traceback
310
+ traceback.print_exc()
311
+ except Exception as e:
312
+ logger.error(f"Error acquiring database lock: {str(e)}")
313
+ if is_debug:
314
+ import traceback
315
+ traceback.print_exc()
316
 
317
  async def update_channel_stats(request_id, provider, model, api_key, success):
318
  if DISABLE_DATABASE:
319
  return
320
+
321
+ try:
322
+ async with db_semaphore:
323
+ async with async_session() as session:
324
+ async with session.begin():
325
+ try:
326
+ channel_stat = ChannelStat(
327
+ request_id=request_id,
328
+ provider=provider,
329
+ model=model,
330
+ api_key=api_key,
331
+ success=success,
332
+ )
333
+ session.add(channel_stat)
334
+ await session.commit()
335
+ except Exception as e:
336
+ await session.rollback()
337
+ logger.error(f"Error updating channel stats: {str(e)}")
338
+ if is_debug:
339
+ import traceback
340
+ traceback.print_exc()
341
+ except Exception as e:
342
+ logger.error(f"Error acquiring database lock: {str(e)}")
343
+ if is_debug:
344
+ import traceback
345
+ traceback.print_exc()
346
 
347
  class LoggingStreamingResponse(Response):
348
  def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):