Gong Baitao committed
Commit 8ea0ba9
1 Parent(s): 02b2bad

Update tokenization_cpmbee.py

Files changed (1)
  1. tokenization_cpmbee.py +130 -0
tokenization_cpmbee.py CHANGED
@@ -18,6 +18,7 @@ import os
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
+from numpy.typing import NDArray
 from typing_extensions import TypedDict
 
 from transformers.tokenization_utils import PaddingStrategy, PreTrainedTokenizer, TensorType
@@ -866,3 +867,132 @@ class CpmBeeTokenizer(PreTrainedTokenizer):
         )
 
         return batch_outputs
+
+    def prepare_for_finetune(
+        self,
+        data_list: List[Dict],
+        max_length: int = 2048
+    ):
+        _inputs: List[NDArray[np.int32]] = []
+        _inputs_sub: List[NDArray[np.int32]] = []
+        _context: List[NDArray[np.int8]] = []
+        _sample_ids: List[NDArray[np.int32]] = []
+        _segments: List[NDArray[np.int32]] = []
+        _num_segments: List[NDArray[np.int32]] = []
+        _segment_rel_offset: List[NDArray[np.int32]] = []
+        _segment_rel: List[NDArray[np.int32]] = []
+        _spans: List[List[int]] = []
+        _raw_data: List[List[Any]] = []
+
+        raw_data = {}
+        for data in data_list:
+            (
+                input_ids,
+                input_id_subs,
+                context,
+                segment_ids,
+                segment_rel,
+                n_segments,
+                _
+            ) = self.convert_data_to_id(data)
+
+            input_ids = input_ids[: max_length]
+            context = context[: max_length]
+            segment_ids = segment_ids[: max_length]
+            raw_data["input"] = data
+            raw_data["samples"] = []
+
+            sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
+            segment_rel_offset = np.zeros(input_ids.shape, dtype=np.int32)
+            num_segments = np.full(input_ids.shape, n_segments, dtype=np.int32)
+
+            _inputs.append(input_ids)
+            _inputs_sub.append(input_id_subs)
+            _context.append(context)
+            _sample_ids.append(sample_ids)
+            _segments.append(segment_ids)
+            _num_segments.append(num_segments)
+            _segment_rel_offset.append(segment_rel_offset)
+            _segment_rel.append(segment_rel)
+            _spans.append([input_ids.shape[0]])
+            _raw_data.append([raw_data])
+
+        batch_size = len(_inputs)
+        inputs = np.zeros((batch_size, max_length), dtype=np.int32)
+        inputs_sub = np.zeros((batch_size, max_length), dtype=np.int32)
+        context = np.zeros((batch_size, max_length), dtype=np.int8)
+        sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
+        segments = np.zeros((batch_size, max_length), dtype=np.int32)
+        num_segments = np.zeros((batch_size, max_length), dtype=np.int32)
+        segment_rel_offset = np.zeros((batch_size, max_length), dtype=np.int32)
+        tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
+
+        max_rel = 0
+        for i in range(batch_size):
+            max_rel = max(max_rel, _segment_rel[i].shape[0])
+        segment_rel = np.zeros((batch_size, max_rel), dtype=np.int32)
+        spans = np.zeros((batch_size, max_length), dtype=np.int32)
+        length = np.zeros((batch_size,), dtype=np.int32)
+
+        batch_ext_table_map: Dict[Tuple[int, int], int] = {}
+        batch_ext_table_ids: List[int] = []
+        batch_ext_table_sub: List[int] = []
+        raw_data_list: List[Any] = []
+
+        for i in range(batch_size):
+            instance_length = _inputs[i].shape[0]
+            rel_size = _segment_rel[i].shape[0]
+            inputs[i, :instance_length] = _inputs[i]
+            inputs_sub[i, :instance_length] = _inputs_sub[i]
+            context[i, :instance_length] = _context[i]
+            sample_ids[i, :instance_length] = _sample_ids[i]
+            segments[i, :instance_length] = _segments[i]
+            num_segments[i, :instance_length] = _num_segments[i]
+            segment_rel_offset[i, :instance_length] = _segment_rel_offset[i]
+            segment_rel[i, :rel_size] = _segment_rel[i]
+
+            span_begin = 0
+            for span_id, span_end in enumerate(_spans[i]):
+                spans[i, span_begin:span_end] = span_id
+                span_begin = span_end
+            length[i] = instance_length
+            raw_data_list.extend(_raw_data[i])
+
+            for j in range(instance_length):
+                idx, idx_sub = _inputs[i][j], _inputs_sub[i][j]
+                tgt_idx = idx
+                if idx_sub > 0:
+                    # need to be in ext table
+                    if (idx, idx_sub) not in batch_ext_table_map:
+                        batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
+                        batch_ext_table_ids.append(idx)
+                        batch_ext_table_sub.append(idx_sub)
+                    tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.vocab_size
+                if j > 1 and context[i, j - 1] == 0:
+                    if idx != self.bos_token_id:
+                        tgt[i, j - 1] = tgt_idx
+                    else:
+                        tgt[i, j - 1] = self.eos_token_id
+            if context[i, instance_length - 1] == 0:
+                tgt[i, instance_length - 1] = self.eos_token_id
+
+        if len(batch_ext_table_map) == 0:
+            # placeholder
+            batch_ext_table_ids.append(0)
+            batch_ext_table_sub.append(1)
+
+        return BatchEncoding({
+            "input_ids": inputs,
+            "input_id_sub": inputs_sub,
+            "length": length,
+            "context": context > 0,
+            "sample_ids": sample_ids,
+            "num_segments": num_segments,
+            "segment": segments,
+            "segment_rel_offset": segment_rel_offset,
+            "segment_rel": segment_rel,
+            "span": spans,
+            "labels": tgt,
+            "ext_table_ids": np.array(batch_ext_table_ids, dtype=np.int32),
+            "ext_table_sub": np.array(batch_ext_table_sub, dtype=np.int32)
+        }, tensor_type="pt")
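For reference, a minimal usage sketch of the new method, not part of the commit: it loads the tokenizer with AutoTokenizer and trust_remote_code, and assumes the checkpoint name "openbmb/cpm-bee-10b" and the structured-dict sample format ({"input": ..., "<ans>": ...}) used in CPM-Bee's fine-tuning examples; both names are illustrative assumptions, not taken from this diff.

from transformers import AutoTokenizer

# Load the CpmBeeTokenizer from a CPM-Bee checkpoint (checkpoint name is an assumption).
tokenizer = AutoTokenizer.from_pretrained("openbmb/cpm-bee-10b", trust_remote_code=True)

# Hypothetical fine-tuning samples in CPM-Bee's structured-dict format.
data_list = [
    {"input": "Today is a fine day.", "<ans>": "positive"},
    {"input": "The service was terrible.", "<ans>": "negative"},
]

# Each dict goes through convert_data_to_id, is truncated to max_length,
# then padded into fixed-shape (batch_size, max_length) arrays and
# returned as PyTorch tensors via BatchEncoding(..., tensor_type="pt").
batch = tokenizer.prepare_for_finetune(data_list, max_length=2048)

print(batch["input_ids"].shape)  # torch.Size([2, 2048])
print(batch["labels"].shape)     # torch.Size([2, 2048]); -100 marks positions excluded from the loss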