zxdu20 committed
Commit 373fd6b
1 parent: e22cddf

Fix attention_mask and position_ids

Files changed (1): tokenization_chatglm.py (+23 -21)

tokenization_chatglm.py CHANGED
@@ -340,7 +340,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         token_ids_0 += [self.sp_tokenizer[self.bos_token]]
 
         if token_ids_1 is not None:
-            if token_ids_1[-1] != eop_id:
+            if not token_ids_1 or token_ids_1[-1] != eop_id:
                 token_ids_1 += [eop_id]
             token_ids_0 += token_ids_1
 
@@ -397,26 +397,28 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
 
         # Initialize attention mask if not present.
-        if return_attention_mask:
-            if bos_token_id in required_input:
-                context_length = required_input.index(bos_token_id)
-            else:
-                context_length = seq_length
-            attention_mask = np.ones((1, seq_length, seq_length))
-            attention_mask = np.tril(attention_mask)
-            attention_mask[:, :, :context_length] = 1
-            attention_mask = np.bool_(attention_mask < 0.5)
-            encoded_inputs["attention_mask"] = attention_mask
-
-            position_ids = np.arange(seq_length, dtype=np.int64)
-            mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
-            if mask_token in required_input:
-                mask_position = required_input.index(mask_token)
-                position_ids[context_length:] = mask_position
-            block_position_ids = np.concatenate(
-                [np.zeros(context_length, dtype=np.int64),
-                 np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
-            encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
+        if max_length is not None:
+            if "attention_mask" not in encoded_inputs:
+                if bos_token_id in required_input:
+                    context_length = required_input.index(bos_token_id)
+                else:
+                    context_length = seq_length
+                attention_mask = np.ones((1, seq_length, seq_length))
+                attention_mask = np.tril(attention_mask)
+                attention_mask[:, :, :context_length] = 1
+                attention_mask = np.bool_(attention_mask < 0.5)
+                encoded_inputs["attention_mask"] = attention_mask
+
+            if "position_ids" not in encoded_inputs:
+                position_ids = np.arange(seq_length, dtype=np.int64)
+                mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
+                if mask_token in required_input:
+                    mask_position = required_input.index(mask_token)
+                    position_ids[context_length:] = mask_position
+                block_position_ids = np.concatenate(
+                    [np.zeros(context_length, dtype=np.int64),
+                     np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
+                encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
 
         if needs_to_be_padded:
             difference = max_length - len(required_input)
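
Note on the first hunk: the old check indexed token_ids_1[-1] directly, which raises IndexError when token_ids_1 is an empty list; the added "not token_ids_1 or" short-circuits that case and still appends the eop token. A minimal standalone sketch of the guard's behaviour (the append_eop wrapper is hypothetical and the eop_id value is illustrative, not the real vocabulary id):

    eop_id = 150005  # hypothetical id, for illustration only

    def append_eop(token_ids_1):
        # New guard: short-circuits on an empty list instead of indexing [-1].
        if not token_ids_1 or token_ids_1[-1] != eop_id:
            token_ids_1 += [eop_id]
        return token_ids_1

    print(append_eop([]))           # [150005]  (old check raised IndexError here)
    print(append_eop([7, 8]))       # [7, 8, 150005]
    print(append_eop([7, eop_id]))  # [7, 150005]  (no duplicate eop)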
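
Note on the second hunk: the attention mask and position ids are now built only when max_length is set, and only if the caller has not already put "attention_mask" / "position_ids" into encoded_inputs, so user-supplied values are no longer overwritten. For readers unfamiliar with the layout these arrays encode, here is a simplified standalone numpy sketch of the same logic on a toy id sequence (the ids are hypothetical, only the positions of the bos and gMASK tokens matter, and it skips the MASK-vs-gMASK selection shown in the diff; this is not the tokenizer API):

    import numpy as np

    # Toy sequence: [ctx, ctx, gMASK, ctx, <bos>, gen, gen]
    required_input = [11, 12, 901, 13, 902, 21, 22]   # hypothetical ids
    bos_token_id, gmask_token_id = 902, 901
    seq_length = len(required_input)

    # Context = everything before <bos>; those columns stay fully visible.
    context_length = required_input.index(bos_token_id)

    # Prefix-LM mask: lower-triangular (causal) overall, with the context
    # columns opened up for every row. True marks entries that are masked out:
    # attention_mask[0, i, j] is True only where j > i and j >= context_length,
    # i.e. a generated token cannot see tokens generated after it.
    attention_mask = np.tril(np.ones((1, seq_length, seq_length)))
    attention_mask[:, :, :context_length] = 1
    attention_mask = np.bool_(attention_mask < 0.5)

    # 2D position ids: row 0 freezes generated tokens at the gMASK position,
    # row 1 counts 1, 2, ... inside the generated block.
    position_ids = np.arange(seq_length, dtype=np.int64)
    position_ids[context_length:] = required_input.index(gmask_token_id)
    block_position_ids = np.concatenate(
        [np.zeros(context_length, dtype=np.int64),
         np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
    print(np.stack([position_ids, block_position_ids], axis=0))
    # [[0 1 2 3 2 2 2]
    #  [0 0 0 0 1 2 3]]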