|
""" |
|
API密钥模型 - 处理API密钥的CRUD操作 |
|
""" |
|
import json |
|
import uuid |
|
from datetime import datetime |
|
import os |
|
import pytz |
|
import sqlite3 |
|
from utils.db import get_db_connection |
|
from config import API_KEYS_FILE, DATABASE_PATH |
|
|
|
class ApiKeyManager:
    """Manage API keys stored in the SQLite ``api_keys`` table.

    Every public method works against the connection returned by
    ``get_db_connection()``. When a ``sqlite3.Error`` is raised, the
    transaction is rolled back and the same operation is applied to the
    legacy JSON file at ``API_KEYS_FILE`` instead, so callers still get a
    result.
    """

    # Characters removed from every stored key value: single/double quotes,
    # parentheses, square brackets and ASCII spaces.
    _KEY_STRIP_TABLE = str.maketrans("", "", "'\"()[] ")

    # Initial ``return_message`` for newly created keys and for keys whose
    # value was just edited (the text means "waiting for test").
    _PENDING_MESSAGE = "等待测试"

    # Shared by add_key and bulk_add_keys so both insert identical columns.
    _INSERT_SQL = '''
        INSERT INTO api_keys
        (id, platform, name, key, created_at, updated_at, success, return_message, states, balance)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    '''

    # Ids bound per DELETE statement. SQLite's SQLITE_MAX_VARIABLE_NUMBER
    # defaulted to 999 before SQLite 3.32.0, so stay comfortably below it.
    _MAX_SQL_PARAMS = 500

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _sanitize_key(key):
        """Strip quotes, parentheses, square brackets and spaces from *key*.

        Falsy values (``None``, ``""``) are returned unchanged.
        """
        if not key:
            return key
        return key.translate(ApiKeyManager._KEY_STRIP_TABLE)

    @staticmethod
    def _now_iso():
        """Return the current Asia/Shanghai time as an ISO-8601 string."""
        return datetime.now(pytz.timezone('Asia/Shanghai')).isoformat()

    @staticmethod
    def _new_key_record(platform, name, key, timestamp):
        """Build the record for a freshly created, untested key with a new UUID id."""
        return {
            "id": str(uuid.uuid4()),
            "platform": platform,
            "name": name,
            "key": key,
            "created_at": timestamp,
            "updated_at": timestamp,
            "success": False,
            "return_message": ApiKeyManager._PENDING_MESSAGE,
            "states": "",
            "balance": 0,
        }

    @staticmethod
    def _insert_params(record):
        """Return the positional parameters for ``_INSERT_SQL`` from a key record."""
        return (
            record["id"],
            record["platform"],
            record["name"],
            record["key"],
            record["created_at"],
            record["updated_at"],
            int(record["success"]),  # stored as 0/1; converted back with bool() on read
            record["return_message"],
            record["states"],
            record["balance"],
        )

    @staticmethod
    def _row_to_dict(row):
        """Convert a DB row into a plain dict with ``success`` as a real bool."""
        key_dict = dict(row)
        key_dict['success'] = bool(key_dict['success'])
        return key_dict

    @staticmethod
    def _json_append(records):
        """Append *records* to the legacy JSON file (used when the DB write fails)."""
        api_keys_data = ApiKeyManager.load_keys()
        api_keys_data["api_keys"].extend(records)
        ApiKeyManager.save_keys(api_keys_data)

    # ------------------------------------------------------------------
    # Legacy JSON storage
    # ------------------------------------------------------------------

    @staticmethod
    def load_keys():
        """Load all API keys from the legacy JSON file.

        Creates the file containing ``{"api_keys": []}`` when it does not exist.
        Returns a fresh empty payload when the file is not valid UTF-8 JSON or
        is not a JSON object. Guarantees that the returned dict has an
        ``"api_keys"`` list.
        """
        if not os.path.exists(API_KEYS_FILE):
            ApiKeyManager.save_keys({"api_keys": []})
            return {"api_keys": []}

        try:
            with open(API_KEYS_FILE, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError).
            return {"api_keys": []}

        if not isinstance(data, dict):
            return {"api_keys": []}
        if not isinstance(data.get("api_keys"), list):
            # Every caller indexes/appends to data["api_keys"]; make that safe.
            data["api_keys"] = []
        return data

    @staticmethod
    def save_keys(data):
        """Write *data* to the legacy JSON file (UTF-8, pretty-printed)."""
        with open(API_KEYS_FILE, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    # ------------------------------------------------------------------
    # Public CRUD API
    # ------------------------------------------------------------------

    @staticmethod
    def get_all_keys():
        """Return every key as ``{"api_keys": [...]}``.

        Falls back to the JSON file when the database query fails.
        """
        conn = get_db_connection()
        try:
            cursor = conn.cursor()
            cursor.execute('SELECT * FROM api_keys')
            rows = cursor.fetchall()
            return {"api_keys": [ApiKeyManager._row_to_dict(row) for row in rows]}
        except sqlite3.Error as e:
            print(f"获取所有密钥时出错: {e}")
            return ApiKeyManager.load_keys()
        finally:
            conn.close()

    @staticmethod
    def add_key(platform, name, key):
        """Create a new API key and return its record.

        The key value is sanitized first (quotes, brackets and spaces
        removed). If the insert fails, the record is appended to the JSON
        file instead and still returned.
        """
        record = ApiKeyManager._new_key_record(
            platform, name, ApiKeyManager._sanitize_key(key), ApiKeyManager._now_iso()
        )

        conn = get_db_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(ApiKeyManager._INSERT_SQL, ApiKeyManager._insert_params(record))
            conn.commit()
            return record
        except sqlite3.Error as e:
            print(f"添加密钥时出错: {e}")
            conn.rollback()
            ApiKeyManager._json_append([record])
            return record
        finally:
            conn.close()

    @staticmethod
    def delete_key(key_id):
        """Delete the key with *key_id*.

        Returns:
            True if a key was removed, False if no key had that id.
        """
        conn = get_db_connection()
        try:
            cursor = conn.cursor()
            cursor.execute('DELETE FROM api_keys WHERE id = ?', (key_id,))
            deleted = cursor.rowcount > 0
            conn.commit()
            return deleted
        except sqlite3.Error as e:
            print(f"删除密钥时出错: {e}")
            conn.rollback()

            api_keys_data = ApiKeyManager.load_keys()
            original_count = len(api_keys_data["api_keys"])
            api_keys_data["api_keys"] = [
                k for k in api_keys_data["api_keys"] if k.get("id") != key_id
            ]
            if len(api_keys_data["api_keys"]) < original_count:
                ApiKeyManager.save_keys(api_keys_data)
                return True
            return False
        finally:
            conn.close()

    @staticmethod
    def bulk_delete(key_ids):
        """Delete several API keys at once.

        Args:
            key_ids: Iterable of key ids (list, tuple, set, ...). Duplicate ids
                are ignored.

        Returns:
            The number of keys actually removed (0 for empty/None input).
        """
        if not key_ids:
            return 0
        # dict.fromkeys keeps order while dropping duplicates, and turns sets or
        # other iterables into a list that sqlite3 accepts as parameters.
        ids = list(dict.fromkeys(key_ids))
        if not ids:
            return 0

        conn = get_db_connection()
        try:
            cursor = conn.cursor()
            deleted = 0
            step = ApiKeyManager._MAX_SQL_PARAMS
            # Chunk so huge id lists never exceed SQLite's bound-parameter limit;
            # all chunks share one transaction and are committed together.
            for start in range(0, len(ids), step):
                chunk = ids[start:start + step]
                placeholders = ','.join(['?'] * len(chunk))
                cursor.execute(f'DELETE FROM api_keys WHERE id IN ({placeholders})', chunk)
                # rowcount is exact for DELETE, unlike diffing two COUNT(*)
                # queries which races with concurrent writers.
                deleted += cursor.rowcount
            conn.commit()
            return deleted
        except sqlite3.Error as e:
            print(f"批量删除密钥时出错: {e}")
            conn.rollback()

            wanted = set(ids)
            api_keys_data = ApiKeyManager.load_keys()
            remaining = [k for k in api_keys_data["api_keys"] if k.get("id") not in wanted]
            deleted_count = len(api_keys_data["api_keys"]) - len(remaining)
            if deleted_count > 0:
                api_keys_data["api_keys"] = remaining
                ApiKeyManager.save_keys(api_keys_data)
            return deleted_count
        finally:
            conn.close()

    @staticmethod
    def bulk_add_keys(keys_data):
        """Create several API keys in a single transaction.

        Args:
            keys_data: List of dicts, each with ``platform``, ``name`` and
                ``key`` entries (missing entries become None).

        Returns:
            The list of created key records. All keys share one creation
            timestamp. If the database insert fails, the transaction is rolled
            back and the *complete* batch is written to the JSON file instead.
        """
        if not keys_data:
            return []

        now = ApiKeyManager._now_iso()
        # Build every record up front so the JSON fallback persists the whole
        # batch, not only the rows that happened to be inserted before a failure.
        added_keys = [
            ApiKeyManager._new_key_record(
                key_info.get("platform"),
                key_info.get("name"),
                ApiKeyManager._sanitize_key(key_info.get("key")),
                now,
            )
            for key_info in keys_data
        ]

        conn = get_db_connection()
        try:
            cursor = conn.cursor()
            cursor.executemany(
                ApiKeyManager._INSERT_SQL,
                [ApiKeyManager._insert_params(record) for record in added_keys],
            )
            conn.commit()
            return added_keys
        except sqlite3.Error as e:
            print(f"批量添加密钥时出错: {e}")
            conn.rollback()
            ApiKeyManager._json_append(added_keys)
            return added_keys
        finally:
            conn.close()

    @staticmethod
    def update_key(key_id, name, key):
        """Update a key's name and value and reset its test status.

        ``success`` is reset to False and ``return_message`` to the pending
        message; ``states`` and ``balance`` are left untouched.

        Returns:
            The updated record, or None when no key has *key_id*.
        """
        key = ApiKeyManager._sanitize_key(key)
        updated_at = ApiKeyManager._now_iso()

        conn = get_db_connection()
        try:
            cursor = conn.cursor()
            cursor.execute('''
                UPDATE api_keys
                SET name = ?, key = ?, updated_at = ?, success = ?, return_message = ?
                WHERE id = ?
            ''', (name, key, updated_at, 0, ApiKeyManager._PENDING_MESSAGE, key_id))

            if cursor.rowcount == 0:
                return None
            conn.commit()

            # Re-read so the caller receives every column, not just the edited ones.
            cursor.execute('SELECT * FROM api_keys WHERE id = ?', (key_id,))
            row = cursor.fetchone()
            return ApiKeyManager._row_to_dict(row) if row else None
        except sqlite3.Error as e:
            print(f"更新密钥时出错: {e}")
            conn.rollback()

            api_keys_data = ApiKeyManager.load_keys()
            updated_key = next(
                (k for k in api_keys_data["api_keys"] if k.get("id") == key_id), None
            )
            if updated_key is None:
                return None
            updated_key.update(
                name=name,
                key=key,
                updated_at=updated_at,
                success=False,
                return_message=ApiKeyManager._PENDING_MESSAGE,
            )
            ApiKeyManager.save_keys(api_keys_data)
            return updated_key
        finally:
            conn.close()
|
|