yym68686 committed
Commit c34a2a5 · 1 Parent(s): 0712d85

🐛 Bug: Fix the bug where the request to generate images cannot be parsed.

Files changed (3)
  1. .github/workflows/main.yml +1 -1
  2. main.py +18 -16
  3. models.py +54 -27
.github/workflows/main.yml CHANGED
@@ -68,7 +68,7 @@ jobs:
         git config --global user.name 'github-actions[bot]'
         git config --global user.email 'github-actions[bot]@users.noreply.github.com'
         git add VERSION
-        git commit -m "Bump version to ${{ steps.bump_version.outputs.new_version }}"
+        git commit -m "📖 Bump version to ${{ steps.bump_version.outputs.new_version }}"
         git push
 
     - name: Build and push Docker image
main.py CHANGED
@@ -12,7 +12,7 @@ from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi.exceptions import RequestValidationError
 
-from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest
+from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
 from request import get_payload
 from response import fetch_response, fetch_response_stream
 from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
@@ -164,30 +164,27 @@ class StatsMiddleware(BaseHTTPMiddleware):
         super().__init__(app)
 
     async def dispatch(self, request: Request, call_next):
-        if request.headers.get("x-api-key"):
-            token = request.headers.get("x-api-key")
-        elif request.headers.get("Authorization"):
-            token = request.headers.get("Authorization").split(" ")[1]
-        else:
-            token = None
-
         start_time = time()
 
-        request.state.parsed_body = await parse_request_body(request)
         endpoint = f"{request.method} {request.url.path}"
         client_ip = request.client.host
 
         model = "unknown"
-        enable_moderation = False  # moderation is disabled by default
         is_flagged = False
         moderated_content = ""
+        enable_moderation = False  # moderation is disabled by default
 
         config = app.state.config
-        api_list = app.state.api_list
-
         # decide whether to enable moderation based on the token
+        if request.headers.get("x-api-key"):
+            token = request.headers.get("x-api-key")
+        elif request.headers.get("Authorization"):
+            token = request.headers.get("Authorization").split(" ")[1]
+        else:
+            token = None
         if token:
             try:
+                api_list = app.state.api_list
                 api_index = api_list.index(token)
                 enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
             except ValueError:
@@ -197,11 +194,16 @@ class StatsMiddleware(BaseHTTPMiddleware):
            # if the token is None, check the global setting
            enable_moderation = config.get('ENABLE_MODERATION', False)
 
-        if request.state.parsed_body:
+        parsed_body = await parse_request_body(request)
+        if parsed_body:
             try:
-                request_model = RequestModel(**request.state.parsed_body)
+                request_model = UnifiedRequest.model_validate(parsed_body).data
                 model = request_model.model
-                moderated_content = request_model.get_last_text_message()
+
+                if request_model.request_type == "chat":
+                    moderated_content = request_model.get_last_text_message()
+                elif request_model.request_type == "image":
+                    moderated_content = request_model.prompt
 
                 if enable_moderation and moderated_content:
                     moderation_response = await self.moderate_content(moderated_content, token)
@@ -636,7 +638,7 @@ def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(sec
     return token
 
 @app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
-async def request_model(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str = Depends(verify_api_key)):
+async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
     # logger.info(f"Request received: {request}")
     return await model_handler.request_model(request, token)
 
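For reference, here is a minimal sketch (not part of the commit) of the header-precedence rule the middleware now applies when picking a token inside dispatch: x-api-key wins over a Bearer Authorization header, and neither header means no token. A plain dict stands in for Starlette's case-insensitive request.headers.

# Sketch only: mirrors the token-extraction order added to StatsMiddleware.dispatch.
# A plain dict stands in for request.headers, so lookups here are case-sensitive.
from typing import Optional

def extract_token(headers: dict) -> Optional[str]:
    if headers.get("x-api-key"):
        return headers["x-api-key"]
    if headers.get("Authorization"):
        # "Bearer sk-..." -> "sk-..."
        return headers["Authorization"].split(" ")[1]
    return None

assert extract_token({"x-api-key": "sk-a"}) == "sk-a"
assert extract_token({"Authorization": "Bearer sk-b"}) == "sk-b"
assert extract_token({}) is None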
models.py CHANGED
@@ -1,30 +1,6 @@
 from io import IOBase
-from pydantic import BaseModel, Field
-from typing import List, Dict, Optional, Union, Tuple
-
-class ImageGenerationRequest(BaseModel):
-    model: str
-    prompt: str
-    n: int
-    size: str
-    stream: bool = False
-
-class AudioTranscriptionRequest(BaseModel):
-    file: Tuple[str, IOBase, str]
-    model: str
-    language: Optional[str] = None
-    prompt: Optional[str] = None
-    response_format: Optional[str] = None
-    temperature: Optional[float] = None
-    stream: bool = False
-
-    class Config:
-        arbitrary_types_allowed = True
-
-class ModerationRequest(BaseModel):
-    input: str
-    model: Optional[str] = "text-moderation-latest"
-    stream: bool = False
+from pydantic import BaseModel, Field, model_validator
+from typing import List, Dict, Optional, Union, Tuple, Literal
 
 class FunctionParameter(BaseModel):
     type: str
@@ -82,6 +58,7 @@ class ToolChoice(BaseModel):
     function: Optional[FunctionChoice] = None
 
 class RequestModel(BaseModel):
+    request_type: Literal["chat"] = "chat"
     model: str
     messages: List[Message]
     logprobs: Optional[bool] = None
@@ -107,4 +84,54 @@ class RequestModel(BaseModel):
             for item in reversed(message.content):
                 if item.type == "text" and item.text:
                     return item.text
-        return ""
+        return ""
+
+class ImageGenerationRequest(BaseModel):
+    request_type: Literal["image"] = "image"
+    prompt: str
+    model: Optional[str] = "dall-e-3"
+    n: Optional[int] = 1
+    size: Optional[str] = "1024x1024"
+    stream: bool = False
+
+class AudioTranscriptionRequest(BaseModel):
+    request_type: Literal["audio"] = "audio"
+    file: Tuple[str, IOBase, str]
+    model: str
+    language: Optional[str] = None
+    prompt: Optional[str] = None
+    response_format: Optional[str] = None
+    temperature: Optional[float] = None
+    stream: bool = False
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class ModerationRequest(BaseModel):
+    request_type: Literal["moderation"] = "moderation"
+    input: str
+    model: Optional[str] = "text-moderation-latest"
+    stream: bool = False
+
+class UnifiedRequest(BaseModel):
+    data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest] = Field(..., discriminator="request_type")
+
+    @model_validator(mode='before')
+    @classmethod
+    def set_request_type(cls, values):
+        if isinstance(values, dict):
+            if "messages" in values:
+                values["request_type"] = "chat"
+                values["data"] = RequestModel(**values)
+            elif "prompt" in values:
+                values["request_type"] = "image"
+                values["data"] = ImageGenerationRequest(**values)
+            elif "file" in values:
+                values["request_type"] = "audio"
+                values["data"] = AudioTranscriptionRequest(**values)
+            elif "input" in values:
+                values["request_type"] = "moderation"
+                values["data"] = ModerationRequest(**values)
+            else:
+                raise ValueError("Unable to determine the request type")
+        return values
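As a quick sanity check of the new parsing path, here is a sketch of how UnifiedRequest resolves the request type that the middleware then uses for moderation. It is not part of the commit; the payloads are illustrative, and the chat body assumes the usual role/content message shape accepted by Message.

# Sketch: UnifiedRequest.model_validate(parsed_body).data, as called in
# StatsMiddleware.dispatch, returns the concrete request model selected
# by the "before" validator above. Example payloads are illustrative.
from models import ImageGenerationRequest, RequestModel, UnifiedRequest

# Image-generation body: a "prompt" key (and no "messages") selects ImageGenerationRequest.
image_body = {"model": "dall-e-3", "prompt": "a watercolor fox", "n": 1, "size": "1024x1024"}
image_request = UnifiedRequest.model_validate(image_body).data
assert isinstance(image_request, ImageGenerationRequest)
assert image_request.request_type == "image"   # middleware moderates image_request.prompt

# Chat body: a "messages" key selects RequestModel.
chat_body = {"model": "gpt-4o", "messages": [{"role": "user", "content": "hello"}]}
chat_request = UnifiedRequest.model_validate(chat_body).data
assert isinstance(chat_request, RequestModel)
assert chat_request.request_type == "chat"     # middleware moderates get_last_text_message()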