Manually calculate dataloader len
translate.py  CHANGED  (+20, -6)
@@ -1,15 +1,22 @@
+import os
+import math
+import argparse
+
+import torch
+from torch.utils.data import DataLoader
+
+from tqdm import tqdm
+
 from transformers import (
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
     PreTrainedTokenizerBase,
     DataCollatorForSeq2Seq,
 )
-
-
-import torch
-from torch.utils.data import DataLoader
+
+
 from dataset import DatasetReader, count_lines
-
+
 from accelerate import Accelerator, DistributedType
 from accelerate.memory_utils import find_executable_batch_size
 
@@ -183,7 +190,14 @@ def main(
                         generated_tokens, skip_special_tokens=True
                     )
                     if accelerator.is_main_process:
-                        if step == len(data_loader) - 1:
+                        if (
+                            step
+                            == math.ceil(
+                                math.ceil(total_lines / batch_size)
+                                / accelerator.num_processes
+                            )
+                            - 1
+                        ):
                             tgt_text = tgt_text[
                                 : (total_lines * num_return_sequences) - samples_seen
                             ]
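The new condition recomputes the dataloader length by hand instead of asking the dataloader for it: total_lines examples are first grouped into batches of batch_size, and those batches are then split across accelerator.num_processes, rounding up at both steps. Below is a minimal sketch of that arithmetic, assuming each process receives the same number of batches (with the last shard padded, as Accelerate does by default); the helper name manual_dataloader_len and the example numbers are illustrative and not part of the commit:

import math


def manual_dataloader_len(total_lines: int, batch_size: int, num_processes: int) -> int:
    # Batches produced from the raw input lines (the last batch may be partial).
    total_batches = math.ceil(total_lines / batch_size)
    # Batches each process sees once those batches are split across processes.
    return math.ceil(total_batches / num_processes)


# Example: 1000 lines, batch size 8, 4 processes:
# ceil(1000 / 8) = 125 batches in total, ceil(125 / 4) = 32 steps per process,
# so the last step index is manual_dataloader_len(1000, 8, 4) - 1 = 31, which is
# the value the new `if` compares `step` against before trimming tgt_text.
assert manual_dataloader_len(1000, 8, 4) == 32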