Yuchan
commited on
Update AlphaS2S.py
Browse files- AlphaS2S.py +2 -2
AlphaS2S.py
CHANGED
|
@@ -13,7 +13,7 @@ tf.get_logger().setLevel("ERROR")
|
|
| 13 |
SEED = 42
|
| 14 |
tf.random.set_seed(SEED)
|
| 15 |
np.random.seed(SEED)
|
| 16 |
-
max_len =
|
| 17 |
batch_size = 128
|
| 18 |
|
| 19 |
# TPU 초기화 (기존 코드와 동일)
|
|
@@ -419,7 +419,7 @@ print("\n✅ 모델 가중치 저장 완료!")
|
|
| 419 |
# 6) 추론 함수 (기존 코드 유지)
|
| 420 |
# =======================
|
| 421 |
|
| 422 |
-
def generate_text_topp(model, prompt, max_len=
|
| 423 |
# 인코더 입력은 <start> Prompt <sep> 만 사용
|
| 424 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
| 425 |
model_input = model_input[:max_len]
|
|
|
|
| 13 |
SEED = 42
|
| 14 |
tf.random.set_seed(SEED)
|
| 15 |
np.random.seed(SEED)
|
| 16 |
+
max_len = 150 # 기존 코드에서 200으로 설정됨
|
| 17 |
batch_size = 128
|
| 18 |
|
| 19 |
# TPU 초기화 (기존 코드와 동일)
|
|
|
|
| 419 |
# 6) 추론 함수 (기존 코드 유지)
|
| 420 |
# =======================
|
| 421 |
|
| 422 |
+
def generate_text_topp(model, prompt, max_len=150, max_gen=100, p=0.9, temperature=0.8, min_len=20):
|
| 423 |
# 인코더 입력은 <start> Prompt <sep> 만 사용
|
| 424 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
| 425 |
model_input = model_input[:max_len]
|