Yuchan commited on
Commit
a638654
·
verified ·
1 Parent(s): 19949b0

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +46 -70
AlphaS2S.py CHANGED
@@ -81,87 +81,61 @@ def ids_to_text(ids):
81
  return sp.decode(ids)
82
 
83
  # =======================
84
- # 2) 데이터셋 생성 함수 (기존 코드와 동일)
85
  # =======================
86
-
87
  def jsonl_stream(file_path):
88
  with open(file_path, "r", encoding="utf-8") as f:
89
  for line in f:
90
  data = json.loads(line)
91
- conversations = data.get("conversations", [])
92
- for i in range(0, len(conversations) - 1, 2):
93
- human_msg = conversations[i]
94
- gpt_msg = conversations[i + 1]
95
- if human_msg.get("from") != "human" or gpt_msg.get("from") != "gpt":
96
- continue
97
-
98
- prompt = human_msg.get("value", "").strip()
99
- response = gpt_msg.get("value", "").strip()
100
- full = f"<start> {prompt} <sep> {response} <end>"
101
- if "<sep>" not in full:
102
- continue
103
-
104
- sep_index = full.index("<sep>")
105
-
106
- # 인코더 입력은 <start> 프롬프트 <sep> 부분, 디코더 입력은 <sep> 응답 <end> 부분
107
- # (Unified Input: 인코더/디코더 입력 모두 full_input을 사용)
108
- input_text = full
109
-
110
- # 타겟 시퀀스는 응답 시작 부분부터 <end>까지이며, 입력보다 한 칸 시프트됨
111
- # 여기서 target_text는 응답 부분만 추출하여 타겟 마스킹에 사용됩니다.
112
- target_text_raw = full[sep_index + len("<sep>"):]
113
-
114
- input_ids = text_to_ids(input_text) # 전체 시퀀스
115
- target_ids_raw = text_to_ids(target_text_raw) # 응답 부분만
116
-
117
- # 길이 처리 마스킹 로직은 기존 코드를 그대로 유지
118
- full_input = input_ids[:max_len]
119
- target_ids = target_ids_raw[:max_len - len(input_ids)]
120
-
121
- available_len = max_len - len(input_ids)
122
-
123
- if available_len <= 0:
124
- input_ids = input_ids[-max_len:]
125
- target_ids = []
126
- target_mask = [0] * len(input_ids)
127
- else:
128
- target_ids = target_ids[:available_len]
129
- target_mask = [0] * len(input_ids) + [1] * len(target_ids)
130
-
131
- full_input = input_ids + target_ids
132
- pad_len = max_len - len(full_input)
133
- full_input += [pad_id] * pad_len
134
- target_mask += [0] * pad_len
135
-
136
- # 타겟 시퀀스는 입력 시퀀스보다 한 칸 시프트된 형태
137
- target_seq = full_input[1:] + [end_id]
138
- target_seq = target_seq[:max_len]
139
-
140
- # 마스킹된 타겟 생성 (프롬프트/패딩 부분은 pad_id로 대체)
141
- masked_target = [
142
- t if m == 1 else pad_id
143
- for t, m in zip(target_seq, target_mask)
144
- ]
145
-
146
- # AlphaS2S는 인코더/디코더 입력으로 같은 시퀀스를 사용
147
- # 입력 시퀀스 = full_input
148
- # 타겟 시퀀스 = masked_target
149
- yield (
150
- tf.convert_to_tensor(full_input, dtype=tf.int32),
151
- tf.convert_to_tensor(full_input, dtype=tf.int32), # 디코더 입력도 동일하게 전달
152
- tf.convert_to_tensor(masked_target, dtype=tf.int32) # 실제 타겟
153
- )
154
 
 
 
 
155
  dataset = tf.data.Dataset.from_generator(
156
  lambda: jsonl_stream(DATA_PATH),
157
  output_signature=(
158
- tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # enc_inputs
159
- tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # dec_inputs
160
- tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # target
161
- ),
162
  )
163
 
164
- # 학습을 위해 딕셔너리 형태로 맵핑
165
  def map_fn(enc_input, dec_input, dec_target):
166
  return {"enc_inputs": enc_input, "dec_inputs": dec_input}, dec_target
167
 
@@ -171,6 +145,8 @@ dataset = dataset.shuffle(1000, seed=SEED).batch(batch_size, drop_remainder=True
171
  with strategy.scope():
172
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
173
 
 
 
174
  # =======================
175
  # 3) 모델 레이어 (기존 코드 유지)
176
  # =======================
 
81
  return sp.decode(ids)
82
 
83
  # =======================
84
+ # JSONL TF Dataset 로드 (ID 레벨 특수 토큰 포함)
85
  # =======================
 
86
  def jsonl_stream(file_path):
87
  with open(file_path, "r", encoding="utf-8") as f:
88
  for line in f:
89
  data = json.loads(line)
90
+ context = data["context"]
91
+ prompt = data["prompt"]
92
+ answer = data["answer"]
93
+
94
+ # =======================
95
+ # Encoder input: ID 레벨에서 특수 토큰 명시
96
+ # =======================
97
+ enc_ids = [context_s_id] + text_to_ids(context) + [context_e_id] + \
98
+ [user_s_id] + text_to_ids(prompt) + [user_e_id]
99
+ enc_ids = enc_ids[:max_len] # max_len 제한
100
+
101
+ # =======================
102
+ # Decoder input: <sos> + answer
103
+ # =======================
104
+ dec_input_ids = [start_id] + text_to_ids(answer)
105
+ dec_input_ids = dec_input_ids[:max_len]
106
+
107
+ # =======================
108
+ # Target: answer + <eos>
109
+ # =======================
110
+ target_ids = text_to_ids(answer) + [end_id]
111
+ target_ids = target_ids[:max_len]
112
+
113
+ # =======================
114
+ # Padding
115
+ # =======================
116
+ enc_ids += [pad_id] * (max_len - len(enc_ids))
117
+ dec_input_ids += [pad_id] * (max_len - len(dec_input_ids))
118
+ target_ids += [pad_id] * (max_len - len(target_ids))
119
+
120
+ yield (
121
+ tf.convert_to_tensor(enc_ids, dtype=tf.int32),
122
+ tf.convert_to_tensor(dec_input_ids, dtype=tf.int32),
123
+ tf.convert_to_tensor(target_ids, dtype=tf.int32),
124
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ # =======================
127
+ # TF Dataset 생성
128
+ # =======================
129
  dataset = tf.data.Dataset.from_generator(
130
  lambda: jsonl_stream(DATA_PATH),
131
  output_signature=(
132
+ tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # enc_inputs
133
+ tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # dec_inputs
134
+ tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # target
135
+ )
136
  )
137
 
138
+ # 학습을 위해 딕셔너리 형태로 매핑
139
  def map_fn(enc_input, dec_input, dec_target):
140
  return {"enc_inputs": enc_input, "dec_inputs": dec_input}, dec_target
141
 
 
145
  with strategy.scope():
146
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
147
 
148
+ print("✅ ID 레벨 특수 토큰 적용 Dataset 로드 완료:", dist_dataset)
149
+
150
  # =======================
151
  # 3) 모델 레이어 (기존 코드 유지)
152
  # =======================