fix: decoder
modules/ai_model.py CHANGED (+10 -4)
@@ -151,7 +151,7 @@ class AIModel:
         if input_type == "image" and isinstance(formatted_input, Image.Image):
             image_token = getattr(self.processor.tokenizer, 'image_token', '<image>')
             if image_token not in prompt:
-                prompt = f"{image_token}
+                prompt = f"{image_token}\n{prompt}"
         inputs = self.processor(
             text=prompt,
             images=formatted_input,

@@ -163,7 +163,13 @@ class AIModel:
             return_tensors="pt"
         ).to(self.model.device, dtype=torch.bfloat16)

-
+        if hasattr(inputs, 'input_ids') and inputs.input_ids.shape[-1] > 512:
+            log.warning(f"⚠️ Truncating over-long input: {inputs.input_ids.shape[-1]} -> 512")
+            inputs.input_ids = inputs.input_ids[:, :512]
+            if hasattr(inputs, 'attention_mask'):
+                inputs.attention_mask = inputs.attention_mask[:, :512]
+
+
         with torch.inference_mode():
             generation_args = {
                 "max_new_tokens": 512,

@@ -187,7 +193,8 @@ class AIModel:
             **inputs,
             **generation_args
         )
-
+        input_length = inputs.input_ids.shape[-1]
+        generated_tokens = outputs[0][input_length:]
         decoded = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

         return decoded if decoded else "I understand your question; please tell me more specific details."

@@ -205,7 +212,6 @@ class AIModel:

         if not self.is_available():
             log.error("Model not ready; cannot run chat_completion")
-            # For cases that require JSON output, return a valid JSON string describing the error
             if kwargs.get("response_format", {}).get("type") == "json_object":
                 return '{"error": "Model not available"}'
             return "Sorry, the AI model is currently unavailable."
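For context on the commit: decoder-only models return the prompt tokens followed by the newly generated tokens from generate(), so decoding outputs[0] directly echoes the prompt back into the reply; the fix slices the prompt off before decoding. Below is a minimal sketch of that decode path in isolation, assuming a transformers-style model/processor pair like the one in modules/ai_model.py. The function name, the max_input_tokens parameter, and the encode-time truncation kwargs are illustrative assumptions, not part of the commit:

    import torch

    def decode_new_tokens(model, processor, prompt, max_input_tokens=512):
        # Hypothetical helper, not from the commit. Truncate at encode time
        # (truncation/max_length are standard tokenizer kwargs, assuming the
        # processor forwards them to its tokenizer) so input_ids and
        # attention_mask stay consistent without slicing tensors afterwards.
        inputs = processor(
            text=prompt,
            truncation=True,
            max_length=max_input_tokens,
            return_tensors="pt",
        ).to(model.device)

        with torch.inference_mode():
            outputs = model.generate(**inputs, max_new_tokens=512)

        # generate() returns prompt + completion for decoder-only models,
        # so drop the prompt tokens before decoding (the "fix: decoder" bug).
        input_length = inputs.input_ids.shape[-1]
        generated_tokens = outputs[0][input_length:]
        return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

Both encode-time truncation and the commit's post-hoc slice inputs.input_ids[:, :512] keep the first 512 tokens; the encode-time form just avoids the extra hasattr checks, at the cost of assuming the processor accepts tokenizer kwargs.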