Eliot0110 committed on
Commit
ce08446
·
1 Parent(s): 7324283

fix: decoder

Browse files
Files changed (1) hide show
  1. modules/ai_model.py +10 -4
modules/ai_model.py CHANGED
@@ -151,7 +151,7 @@ class AIModel:
151
  if input_type == "image" and isinstance(formatted_input, Image.Image):
152
  image_token = getattr(self.processor.tokenizer, 'image_token', '<image>')
153
  if image_token not in prompt:
154
- prompt = f"{image_token}\\n{prompt}"
155
  inputs = self.processor(
156
  text=prompt,
157
  images=formatted_input,
@@ -163,7 +163,13 @@ class AIModel:
163
  return_tensors="pt"
164
  ).to(self.model.device, dtype=torch.bfloat16)
165
 
166
- input_len = inputs.input_ids.shape[-1]
 
 
 
 
 
 
167
  with torch.inference_mode():
168
  generation_args = {
169
  "max_new_tokens": 512,
@@ -187,7 +193,8 @@ class AIModel:
187
  **inputs,
188
  **generation_args
189
  )
190
- generated_tokens = outputs[0][input_len:]
 
191
  decoded = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
192
 
193
  return decoded if decoded else "我理解了您的问题,请告诉我更多具体信息。"
@@ -205,7 +212,6 @@ class AIModel:
205
 
206
  if not self.is_available():
207
  log.error("模型未就绪,无法执行 chat_completion")
208
- # 对于需要JSON输出的场景,返回一个表示错误的有效JSON字符串
209
  if kwargs.get("response_format", {}).get("type") == "json_object":
210
  return '{"error": "Model not available"}'
211
  return "抱歉,AI 模型当前不可用。"
 
151
  if input_type == "image" and isinstance(formatted_input, Image.Image):
152
  image_token = getattr(self.processor.tokenizer, 'image_token', '<image>')
153
  if image_token not in prompt:
154
+ prompt = f"{image_token}\n{prompt}"
155
  inputs = self.processor(
156
  text=prompt,
157
  images=formatted_input,
 
163
  return_tensors="pt"
164
  ).to(self.model.device, dtype=torch.bfloat16)
165
 
166
+ if hasattr(inputs, 'input_ids') and inputs.input_ids.shape[-1] > 512:
167
+ log.warning(f"⚠️ 截断过长输入: {inputs.input_ids.shape[-1]} -> 512")
168
+ inputs.input_ids = inputs.input_ids[:, :512]
169
+ if hasattr(inputs, 'attention_mask'):
170
+ inputs.attention_mask = inputs.attention_mask[:, :512]
171
+
172
+
173
  with torch.inference_mode():
174
  generation_args = {
175
  "max_new_tokens": 512,
 
193
  **inputs,
194
  **generation_args
195
  )
196
+ input_length = inputs.input_ids.shape[-1]
197
+ generated_tokens = outputs[0][input_length:]
198
  decoded = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
199
 
200
  return decoded if decoded else "我理解了您的问题,请告诉我更多具体信息。"
 
212
 
213
  if not self.is_available():
214
  log.error("模型未就绪,无法执行 chat_completion")
 
215
  if kwargs.get("response_format", {}).get("type") == "json_object":
216
  return '{"error": "Model not available"}'
217
  return "抱歉,AI 模型当前不可用。"