Iker committed on
Commit
5bd7f14
·
1 Parent(s): 011cb1f

Manually calculate dataloader len

Browse files
Files changed (1) hide show
  1. translate.py +20 -6
translate.py CHANGED
@@ -1,15 +1,22 @@
 
 
 
 
 
 
 
 
 
1
  from transformers import (
2
  AutoModelForSeq2SeqLM,
3
  AutoTokenizer,
4
  PreTrainedTokenizerBase,
5
  DataCollatorForSeq2Seq,
6
  )
7
- from tqdm import tqdm
8
- import argparse
9
- import torch
10
- from torch.utils.data import DataLoader
11
  from dataset import DatasetReader, count_lines
12
- import os
13
  from accelerate import Accelerator, DistributedType
14
  from accelerate.memory_utils import find_executable_batch_size
15
 
@@ -183,7 +190,14 @@ def main(
183
  generated_tokens, skip_special_tokens=True
184
  )
185
  if accelerator.is_main_process:
186
- if step == len(data_loader) - 1:
 
 
 
 
 
 
 
187
  tgt_text = tgt_text[
188
  : (total_lines * num_return_sequences) - samples_seen
189
  ]
 
1
+ import os
2
+ import math
3
+ import argparse
4
+
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+
8
+ from tqdm import tqdm
9
+
10
  from transformers import (
11
  AutoModelForSeq2SeqLM,
12
  AutoTokenizer,
13
  PreTrainedTokenizerBase,
14
  DataCollatorForSeq2Seq,
15
  )
16
+
17
+
 
 
18
  from dataset import DatasetReader, count_lines
19
+
20
  from accelerate import Accelerator, DistributedType
21
  from accelerate.memory_utils import find_executable_batch_size
22
 
 
190
  generated_tokens, skip_special_tokens=True
191
  )
192
  if accelerator.is_main_process:
193
+ if (
194
+ step
195
+ == math.ceil(
196
+ math.ceil(total_lines / batch_size)
197
+ / accelerator.num_processes
198
+ )
199
+ - 1
200
+ ):
201
  tgt_text = tgt_text[
202
  : (total_lines * num_return_sequences) - samples_seen
203
  ]