yym68686 commited on
Commit
783c658
·
1 Parent(s): b8a7df8

Fix the bug of model matching format error

Browse files
Files changed (1) hide show
  1. main.py +26 -16
main.py CHANGED
@@ -21,7 +21,8 @@ from urllib.parse import urlparse
21
  @asynccontextmanager
22
  async def lifespan(app: FastAPI):
23
  # 启动时的代码
24
- app.state.client = httpx.AsyncClient()
 
25
  yield
26
  # 关闭时的代码
27
  await app.state.client.aclose()
@@ -35,7 +36,20 @@ security = HTTPBearer()
35
  def load_config():
36
  try:
37
  with open('api.yaml', 'r') as f:
38
- return yaml.safe_load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  except FileNotFoundError:
40
  print("配置文件 'config.yaml' 未找到。请确保文件存在于正确的位置。")
41
  return []
@@ -43,19 +57,7 @@ def load_config():
43
  print("配置文件 'config.yaml' 格式不正确。请检查YAML格式。")
44
  return []
45
 
46
- config = load_config()
47
- for index, provider in enumerate(config['providers']):
48
- model_dict = {}
49
- for model in provider['model']:
50
- if type(model) == str:
51
- model_dict[model] = model
52
- if type(model) == dict:
53
- model_dict.update({value: key for key, value in model.items()})
54
- provider['model'] = model_dict
55
- config['providers'][index] = provider
56
- api_keys_db = config['api_keys']
57
- api_list = [item["api"] for item in api_keys_db]
58
- print(json.dumps(config, indent=4, ensure_ascii=False))
59
 
60
  async def process_request(request: RequestModel, provider: Dict):
61
  print("provider: ", provider['provider'])
@@ -102,7 +104,10 @@ class ModelRequestHandler:
102
  if "/" in model:
103
  provider_name = model.split("/")[0]
104
  model = model.split("/")[1]
105
- if (model and model_name == model) or (model == "*"):
 
 
 
106
  provider_rules.append(provider_name)
107
  provider_list = []
108
  for provider in config['providers']:
@@ -250,6 +255,11 @@ def generate_api_key():
250
  api_key = "sk-" + secrets.token_urlsafe(32)
251
  return {"api_key": api_key}
252
 
 
 
 
 
 
253
  if __name__ == '__main__':
254
  import uvicorn
255
  uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
 
21
  @asynccontextmanager
22
  async def lifespan(app: FastAPI):
23
  # 启动时的代码
24
+ timeout = httpx.Timeout(connect=10.0, read=30.0, write=30.0, pool=30.0)
25
+ app.state.client = httpx.AsyncClient(timeout=timeout)
26
  yield
27
  # 关闭时的代码
28
  await app.state.client.aclose()
 
36
  def load_config():
37
  try:
38
  with open('api.yaml', 'r') as f:
39
+ conf = yaml.safe_load(f)
40
+ for index, provider in enumerate(conf['providers']):
41
+ model_dict = {}
42
+ for model in provider['model']:
43
+ if type(model) == str:
44
+ model_dict[model] = model
45
+ if type(model) == dict:
46
+ model_dict.update({value: key for key, value in model.items()})
47
+ provider['model'] = model_dict
48
+ conf['providers'][index] = provider
49
+ api_keys_db = conf['api_keys']
50
+ api_list = [item["api"] for item in api_keys_db]
51
+ print(json.dumps(conf, indent=4, ensure_ascii=False))
52
+ return conf, api_keys_db, api_list
53
  except FileNotFoundError:
54
  print("配置文件 'config.yaml' 未找到。请确保文件存在于正确的位置。")
55
  return []
 
57
  print("配置文件 'config.yaml' 格式不正确。请检查YAML格式。")
58
  return []
59
 
60
+ config, api_keys_db, api_list = load_config()
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  async def process_request(request: RequestModel, provider: Dict):
63
  print("provider: ", provider['provider'])
 
104
  if "/" in model:
105
  provider_name = model.split("/")[0]
106
  model = model.split("/")[1]
107
+ for provider in config['providers']:
108
+ if provider['provider'] == provider_name:
109
+ models_list = provider['model'].keys()
110
+ if (model and model_name == model) or (model == "*" and model_name in models_list):
111
  provider_rules.append(provider_name)
112
  provider_list = []
113
  for provider in config['providers']:
 
255
  api_key = "sk-" + secrets.token_urlsafe(32)
256
  return {"api_key": api_key}
257
 
258
+ async def on_fetch(request, env):
259
+ import asgi
260
+
261
+ return await asgi.fetch(app, request, env)
262
+
263
  if __name__ == '__main__':
264
  import uvicorn
265
  uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)