Fix the bug of model matching format error
Browse files
main.py
CHANGED
@@ -21,7 +21,8 @@ from urllib.parse import urlparse
|
|
21 |
@asynccontextmanager
|
22 |
async def lifespan(app: FastAPI):
|
23 |
# 启动时的代码
|
24 |
-
|
|
|
25 |
yield
|
26 |
# 关闭时的代码
|
27 |
await app.state.client.aclose()
|
@@ -35,7 +36,20 @@ security = HTTPBearer()
|
|
35 |
def load_config():
|
36 |
try:
|
37 |
with open('api.yaml', 'r') as f:
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
except FileNotFoundError:
|
40 |
print("配置文件 'config.yaml' 未找到。请确保文件存在于正确的位置。")
|
41 |
return []
|
@@ -43,19 +57,7 @@ def load_config():
|
|
43 |
print("配置文件 'config.yaml' 格式不正确。请检查YAML格式。")
|
44 |
return []
|
45 |
|
46 |
-
config = load_config()
|
47 |
-
for index, provider in enumerate(config['providers']):
|
48 |
-
model_dict = {}
|
49 |
-
for model in provider['model']:
|
50 |
-
if type(model) == str:
|
51 |
-
model_dict[model] = model
|
52 |
-
if type(model) == dict:
|
53 |
-
model_dict.update({value: key for key, value in model.items()})
|
54 |
-
provider['model'] = model_dict
|
55 |
-
config['providers'][index] = provider
|
56 |
-
api_keys_db = config['api_keys']
|
57 |
-
api_list = [item["api"] for item in api_keys_db]
|
58 |
-
print(json.dumps(config, indent=4, ensure_ascii=False))
|
59 |
|
60 |
async def process_request(request: RequestModel, provider: Dict):
|
61 |
print("provider: ", provider['provider'])
|
@@ -102,7 +104,10 @@ class ModelRequestHandler:
|
|
102 |
if "/" in model:
|
103 |
provider_name = model.split("/")[0]
|
104 |
model = model.split("/")[1]
|
105 |
-
|
|
|
|
|
|
|
106 |
provider_rules.append(provider_name)
|
107 |
provider_list = []
|
108 |
for provider in config['providers']:
|
@@ -250,6 +255,11 @@ def generate_api_key():
|
|
250 |
api_key = "sk-" + secrets.token_urlsafe(32)
|
251 |
return {"api_key": api_key}
|
252 |
|
|
|
|
|
|
|
|
|
|
|
253 |
if __name__ == '__main__':
|
254 |
import uvicorn
|
255 |
uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
|
|
|
21 |
@asynccontextmanager
|
22 |
async def lifespan(app: FastAPI):
|
23 |
# 启动时的代码
|
24 |
+
timeout = httpx.Timeout(connect=10.0, read=30.0, write=30.0, pool=30.0)
|
25 |
+
app.state.client = httpx.AsyncClient(timeout=timeout)
|
26 |
yield
|
27 |
# 关闭时的代码
|
28 |
await app.state.client.aclose()
|
|
|
36 |
def load_config():
|
37 |
try:
|
38 |
with open('api.yaml', 'r') as f:
|
39 |
+
conf = yaml.safe_load(f)
|
40 |
+
for index, provider in enumerate(conf['providers']):
|
41 |
+
model_dict = {}
|
42 |
+
for model in provider['model']:
|
43 |
+
if type(model) == str:
|
44 |
+
model_dict[model] = model
|
45 |
+
if type(model) == dict:
|
46 |
+
model_dict.update({value: key for key, value in model.items()})
|
47 |
+
provider['model'] = model_dict
|
48 |
+
conf['providers'][index] = provider
|
49 |
+
api_keys_db = conf['api_keys']
|
50 |
+
api_list = [item["api"] for item in api_keys_db]
|
51 |
+
print(json.dumps(conf, indent=4, ensure_ascii=False))
|
52 |
+
return conf, api_keys_db, api_list
|
53 |
except FileNotFoundError:
|
54 |
print("配置文件 'config.yaml' 未找到。请确保文件存在于正确的位置。")
|
55 |
return []
|
|
|
57 |
print("配置文件 'config.yaml' 格式不正确。请检查YAML格式。")
|
58 |
return []
|
59 |
|
60 |
+
config, api_keys_db, api_list = load_config()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
async def process_request(request: RequestModel, provider: Dict):
|
63 |
print("provider: ", provider['provider'])
|
|
|
104 |
if "/" in model:
|
105 |
provider_name = model.split("/")[0]
|
106 |
model = model.split("/")[1]
|
107 |
+
for provider in config['providers']:
|
108 |
+
if provider['provider'] == provider_name:
|
109 |
+
models_list = provider['model'].keys()
|
110 |
+
if (model and model_name == model) or (model == "*" and model_name in models_list):
|
111 |
provider_rules.append(provider_name)
|
112 |
provider_list = []
|
113 |
for provider in config['providers']:
|
|
|
255 |
api_key = "sk-" + secrets.token_urlsafe(32)
|
256 |
return {"api_key": api_key}
|
257 |
|
258 |
+
async def on_fetch(request, env):
|
259 |
+
import asgi
|
260 |
+
|
261 |
+
return await asgi.fetch(app, request, env)
|
262 |
+
|
263 |
if __name__ == '__main__':
|
264 |
import uvicorn
|
265 |
uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
|