🐛 Bug: Fix the bug where the request to generate images cannot be parsed.
Browse files- .github/workflows/main.yml +1 -1
- main.py +18 -16
- models.py +54 -27
.github/workflows/main.yml
CHANGED
@@ -68,7 +68,7 @@ jobs:
|
|
68 |
git config --global user.name 'github-actions[bot]'
|
69 |
git config --global user.email 'github-actions[bot]@users.noreply.github.com'
|
70 |
git add VERSION
|
71 |
-
git commit -m "Bump version to ${{ steps.bump_version.outputs.new_version }}"
|
72 |
git push
|
73 |
|
74 |
- name: Build and push Docker image
|
|
|
68 |
git config --global user.name 'github-actions[bot]'
|
69 |
git config --global user.email 'github-actions[bot]@users.noreply.github.com'
|
70 |
git add VERSION
|
71 |
+
git commit -m "📖 Bump version to ${{ steps.bump_version.outputs.new_version }}"
|
72 |
git push
|
73 |
|
74 |
- name: Build and push Docker image
|
main.py
CHANGED
@@ -12,7 +12,7 @@ from fastapi.responses import StreamingResponse, JSONResponse
|
|
12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
13 |
from fastapi.exceptions import RequestValidationError
|
14 |
|
15 |
-
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest
|
16 |
from request import get_payload
|
17 |
from response import fetch_response, fetch_response_stream
|
18 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
@@ -164,30 +164,27 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
164 |
super().__init__(app)
|
165 |
|
166 |
async def dispatch(self, request: Request, call_next):
|
167 |
-
if request.headers.get("x-api-key"):
|
168 |
-
token = request.headers.get("x-api-key")
|
169 |
-
elif request.headers.get("Authorization"):
|
170 |
-
token = request.headers.get("Authorization").split(" ")[1]
|
171 |
-
else:
|
172 |
-
token = None
|
173 |
-
|
174 |
start_time = time()
|
175 |
|
176 |
-
request.state.parsed_body = await parse_request_body(request)
|
177 |
endpoint = f"{request.method} {request.url.path}"
|
178 |
client_ip = request.client.host
|
179 |
|
180 |
model = "unknown"
|
181 |
-
enable_moderation = False # 默认不开启道德审查
|
182 |
is_flagged = False
|
183 |
moderated_content = ""
|
|
|
184 |
|
185 |
config = app.state.config
|
186 |
-
api_list = app.state.api_list
|
187 |
-
|
188 |
# 根据token决定是否启用道德审查
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
if token:
|
190 |
try:
|
|
|
191 |
api_index = api_list.index(token)
|
192 |
enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
|
193 |
except ValueError:
|
@@ -197,11 +194,16 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
197 |
# 如果token为None,检查全局设置
|
198 |
enable_moderation = config.get('ENABLE_MODERATION', False)
|
199 |
|
200 |
-
|
|
|
201 |
try:
|
202 |
-
request_model =
|
203 |
model = request_model.model
|
204 |
-
|
|
|
|
|
|
|
|
|
205 |
|
206 |
if enable_moderation and moderated_content:
|
207 |
moderation_response = await self.moderate_content(moderated_content, token)
|
@@ -636,7 +638,7 @@ def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(sec
|
|
636 |
return token
|
637 |
|
638 |
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
639 |
-
async def request_model(request:
|
640 |
# logger.info(f"Request received: {request}")
|
641 |
return await model_handler.request_model(request, token)
|
642 |
|
|
|
12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
13 |
from fastapi.exceptions import RequestValidationError
|
14 |
|
15 |
+
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
16 |
from request import get_payload
|
17 |
from response import fetch_response, fetch_response_stream
|
18 |
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder
|
|
|
164 |
super().__init__(app)
|
165 |
|
166 |
async def dispatch(self, request: Request, call_next):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
start_time = time()
|
168 |
|
|
|
169 |
endpoint = f"{request.method} {request.url.path}"
|
170 |
client_ip = request.client.host
|
171 |
|
172 |
model = "unknown"
|
|
|
173 |
is_flagged = False
|
174 |
moderated_content = ""
|
175 |
+
enable_moderation = False # 默认不开启道德审查
|
176 |
|
177 |
config = app.state.config
|
|
|
|
|
178 |
# 根据token决定是否启用道德审查
|
179 |
+
if request.headers.get("x-api-key"):
|
180 |
+
token = request.headers.get("x-api-key")
|
181 |
+
elif request.headers.get("Authorization"):
|
182 |
+
token = request.headers.get("Authorization").split(" ")[1]
|
183 |
+
else:
|
184 |
+
token = None
|
185 |
if token:
|
186 |
try:
|
187 |
+
api_list = app.state.api_list
|
188 |
api_index = api_list.index(token)
|
189 |
enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
|
190 |
except ValueError:
|
|
|
194 |
# 如果token为None,检查全局设置
|
195 |
enable_moderation = config.get('ENABLE_MODERATION', False)
|
196 |
|
197 |
+
parsed_body = await parse_request_body(request)
|
198 |
+
if parsed_body:
|
199 |
try:
|
200 |
+
request_model = UnifiedRequest.model_validate(parsed_body).data
|
201 |
model = request_model.model
|
202 |
+
|
203 |
+
if request_model.request_type == "chat":
|
204 |
+
moderated_content = request_model.get_last_text_message()
|
205 |
+
elif request_model.request_type == "image":
|
206 |
+
moderated_content = request_model.prompt
|
207 |
|
208 |
if enable_moderation and moderated_content:
|
209 |
moderation_response = await self.moderate_content(moderated_content, token)
|
|
|
638 |
return token
|
639 |
|
640 |
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
641 |
+
async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
|
642 |
# logger.info(f"Request received: {request}")
|
643 |
return await model_handler.request_model(request, token)
|
644 |
|
models.py
CHANGED
@@ -1,30 +1,6 @@
|
|
1 |
from io import IOBase
|
2 |
-
from pydantic import BaseModel, Field
|
3 |
-
from typing import List, Dict, Optional, Union, Tuple
|
4 |
-
|
5 |
-
class ImageGenerationRequest(BaseModel):
|
6 |
-
model: str
|
7 |
-
prompt: str
|
8 |
-
n: int
|
9 |
-
size: str
|
10 |
-
stream: bool = False
|
11 |
-
|
12 |
-
class AudioTranscriptionRequest(BaseModel):
|
13 |
-
file: Tuple[str, IOBase, str]
|
14 |
-
model: str
|
15 |
-
language: Optional[str] = None
|
16 |
-
prompt: Optional[str] = None
|
17 |
-
response_format: Optional[str] = None
|
18 |
-
temperature: Optional[float] = None
|
19 |
-
stream: bool = False
|
20 |
-
|
21 |
-
class Config:
|
22 |
-
arbitrary_types_allowed = True
|
23 |
-
|
24 |
-
class ModerationRequest(BaseModel):
|
25 |
-
input: str
|
26 |
-
model: Optional[str] = "text-moderation-latest"
|
27 |
-
stream: bool = False
|
28 |
|
29 |
class FunctionParameter(BaseModel):
|
30 |
type: str
|
@@ -82,6 +58,7 @@ class ToolChoice(BaseModel):
|
|
82 |
function: Optional[FunctionChoice] = None
|
83 |
|
84 |
class RequestModel(BaseModel):
|
|
|
85 |
model: str
|
86 |
messages: List[Message]
|
87 |
logprobs: Optional[bool] = None
|
@@ -107,4 +84,54 @@ class RequestModel(BaseModel):
|
|
107 |
for item in reversed(message.content):
|
108 |
if item.type == "text" and item.text:
|
109 |
return item.text
|
110 |
-
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from io import IOBase
|
2 |
+
from pydantic import BaseModel, Field, model_validator
|
3 |
+
from typing import List, Dict, Optional, Union, Tuple, Literal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
class FunctionParameter(BaseModel):
|
6 |
type: str
|
|
|
58 |
function: Optional[FunctionChoice] = None
|
59 |
|
60 |
class RequestModel(BaseModel):
|
61 |
+
request_type: Literal["chat"] = "chat"
|
62 |
model: str
|
63 |
messages: List[Message]
|
64 |
logprobs: Optional[bool] = None
|
|
|
84 |
for item in reversed(message.content):
|
85 |
if item.type == "text" and item.text:
|
86 |
return item.text
|
87 |
+
return ""
|
88 |
+
|
89 |
+
class ImageGenerationRequest(BaseModel):
|
90 |
+
request_type: Literal["image"] = "image"
|
91 |
+
prompt: str
|
92 |
+
model: Optional[str] = "dall-e-3"
|
93 |
+
n: Optional[int] = 1
|
94 |
+
size: Optional[str] = "1024x1024"
|
95 |
+
stream: bool = False
|
96 |
+
|
97 |
+
class AudioTranscriptionRequest(BaseModel):
|
98 |
+
request_type: Literal["audio"] = "audio"
|
99 |
+
file: Tuple[str, IOBase, str]
|
100 |
+
model: str
|
101 |
+
language: Optional[str] = None
|
102 |
+
prompt: Optional[str] = None
|
103 |
+
response_format: Optional[str] = None
|
104 |
+
temperature: Optional[float] = None
|
105 |
+
stream: bool = False
|
106 |
+
|
107 |
+
class Config:
|
108 |
+
arbitrary_types_allowed = True
|
109 |
+
|
110 |
+
class ModerationRequest(BaseModel):
|
111 |
+
request_type: Literal["moderation"] = "moderation"
|
112 |
+
input: str
|
113 |
+
model: Optional[str] = "text-moderation-latest"
|
114 |
+
stream: bool = False
|
115 |
+
|
116 |
+
class UnifiedRequest(BaseModel):
|
117 |
+
data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest] = Field(..., discriminator="request_type")
|
118 |
+
|
119 |
+
@model_validator(mode='before')
|
120 |
+
@classmethod
|
121 |
+
def set_request_type(cls, values):
|
122 |
+
if isinstance(values, dict):
|
123 |
+
if "messages" in values:
|
124 |
+
values["request_type"] = "chat"
|
125 |
+
values["data"] = RequestModel(**values)
|
126 |
+
elif "prompt" in values:
|
127 |
+
values["request_type"] = "image"
|
128 |
+
values["data"] = ImageGenerationRequest(**values)
|
129 |
+
elif "file" in values:
|
130 |
+
values["request_type"] = "audio"
|
131 |
+
values["data"] = AudioTranscriptionRequest(**values)
|
132 |
+
elif "input" in values:
|
133 |
+
values["request_type"] = "moderation"
|
134 |
+
values["data"] = ModerationRequest(**values)
|
135 |
+
else:
|
136 |
+
raise ValueError("无法确定请求类型")
|
137 |
+
return values
|