Yuchan
committed on
Update Model.py
Browse files
Model.py
CHANGED
|
@@ -77,9 +77,11 @@ def text_to_ids(text):
|
|
| 77 |
def ids_to_text(ids):
|
| 78 |
return sp.decode(ids)
|
| 79 |
|
| 80 |
-
def txt_stream(file_path):
|
| 81 |
with open(file_path, "r", encoding="utf-8") as f:
|
| 82 |
-
for line in f:
|
|
|
|
|
|
|
| 83 |
text = line.strip()
|
| 84 |
if not text:
|
| 85 |
continue
|
|
@@ -98,15 +100,16 @@ def txt_stream(file_path):
|
|
| 98 |
tf.convert_to_tensor(target, dtype=tf.int32)
|
| 99 |
)
|
| 100 |
|
| 101 |
-
|
| 102 |
dataset = tf.data.Dataset.from_generator(
|
| 103 |
-
lambda: txt_stream(DATA_PATH),
|
| 104 |
output_signature=(
|
| 105 |
tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
|
| 106 |
tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
|
| 107 |
)
|
| 108 |
)
|
| 109 |
|
|
|
|
| 110 |
dataset = dataset.shuffle(2000, seed=SEED).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
|
| 111 |
|
| 112 |
with strategy.scope():
|
|
|
|
| 77 |
def ids_to_text(ids):
|
| 78 |
return sp.decode(ids)
|
| 79 |
|
| 80 |
+
def txt_stream(file_path, num_lines=None):
|
| 81 |
with open(file_path, "r", encoding="utf-8") as f:
|
| 82 |
+
for i, line in enumerate(f):
|
| 83 |
+
if num_lines is not None and i >= num_lines:
|
| 84 |
+
break # read only up to the specified number of lines
|
| 85 |
text = line.strip()
|
| 86 |
if not text:
|
| 87 |
continue
|
|
|
|
| 100 |
tf.convert_to_tensor(target, dtype=tf.int32)
|
| 101 |
)
|
| 102 |
|
| 103 |
+
# Create the dataset (e.g., only the first 10,000 lines)
|
| 104 |
dataset = tf.data.Dataset.from_generator(
|
| 105 |
+
lambda: txt_stream(DATA_PATH, num_lines=10000),
|
| 106 |
output_signature=(
|
| 107 |
tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
|
| 108 |
tf.TensorSpec(shape=(max_len,), dtype=tf.int32),
|
| 109 |
)
|
| 110 |
)
|
| 111 |
|
| 112 |
+
|
| 113 |
dataset = dataset.shuffle(2000, seed=SEED).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
|
| 114 |
|
| 115 |
with strategy.scope():
|