Iker committed on
Commit 1e19e28
1 parent: dbb5f39

Initial commit

Files changed (3)
  1. .gitignore +5 -0
  2. dataset.py +56 -0
  3. translate.py +160 -0
.gitignore CHANGED
@@ -122,3 +122,8 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# For IntelliJ
+.idea/
+
+debug/
dataset.py ADDED
@@ -0,0 +1,56 @@
from typing import Dict, List

import torch
from torch.utils.data import IterableDataset


def blocks(files, size=65536):
    # Read the file in fixed-size chunks so line counting does not
    # load the whole file into memory.
    while True:
        b = files.read(size)
        if not b:
            break
        yield b


def count_lines(input_path: str) -> int:
    with open(input_path, "r", encoding="utf8") as f:
        return sum(bl.count("\n") for bl in blocks(f))


class DatasetReader(IterableDataset):
    def __init__(self, filename, tokenizer, max_length=128):
        self.filename = filename
        self.tokenizer = tokenizer
        self.max_length = max_length

    def preprocess(self, text: str):
        # Tokenize one sentence, padded/truncated to a fixed length.
        return self.tokenizer(
            text.strip(),
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __iter__(self):
        file_itr = open(self.filename, "r", encoding="utf8")
        mapped_itr = map(self.preprocess, file_itr)
        return mapped_itr


def collate_function(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # Each item carries a leading batch dimension of 1 (from return_tensors="pt"),
    # so take [0] and stack the items into a single batch tensor.
    return {
        "input_ids": torch.stack([item["input_ids"][0] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"][0] for item in batch]),
    }


def get_dataloader(
    filename: str, tokenizer, batch_size: int, max_length: int
) -> torch.utils.data.DataLoader:
    dataset = DatasetReader(filename, tokenizer, max_length)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_function,
    )
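For reference, a minimal usage sketch of this dataloader on its own. The file name and batch size are hypothetical placeholders, and the tokenizer just mirrors the M2M-100 family used by translate.py below:

# Hypothetical usage sketch: "sentences.txt" and the batch size are
# placeholders, not part of this commit.
from transformers import M2M100Tokenizer

from dataset import get_dataloader

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
dataloader = get_dataloader(
    filename="sentences.txt",  # one sentence per line
    tokenizer=tokenizer,
    batch_size=8,
    max_length=128,
)
first_batch = next(iter(dataloader))
print(first_batch["input_ids"].shape)  # torch.Size([8, 128]) if the file has >= 8 lines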
translate.py ADDED
@@ -0,0 +1,160 @@
import argparse
import os

import torch
from tqdm import tqdm
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

from dataset import count_lines, get_dataloader


def main(
    sentences_path,
    output_path,
    source_lang,
    target_lang,
    batch_size,
    model_name: str = "facebook/m2m100_1.2B",
    tensorrt: bool = False,
    precision: int = 32,
    max_length: int = 128,
):

    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    print("Loading tokenizer...")
    tokenizer = M2M100Tokenizer.from_pretrained(model_name)
    print("Loading model...")
    model = M2M100ForConditionalGeneration.from_pretrained(model_name)
    print("Model loaded.\n")

    tokenizer.src_lang = source_lang
    # Token id that forces the decoder to generate in the target language.
    lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]

    model.eval()

    total_lines: int = count_lines(sentences_path)
    print(f"We will translate {total_lines} lines.")
    data_loader = get_dataloader(
        filename=sentences_path,
        tokenizer=tokenizer,
        batch_size=batch_size,
        max_length=max_length,
    )

    if precision == 16:
        dtype = torch.float16
    elif precision == 32:
        dtype = torch.float32
    elif precision == 64:
        dtype = torch.float64
    else:
        raise ValueError("Precision must be 16, 32 or 64.")

    if tensorrt:
        import torch_tensorrt

        device = "cuda"  # torch_tensorrt requires a GPU
        traced_model = torch.jit.trace(
            model, [torch.randn((batch_size, max_length)).to("cuda")]
        )
        model = torch_tensorrt.compile(
            traced_model,
            inputs=[torch_tensorrt.Input((batch_size, max_length), dtype=dtype)],
            enabled_precisions={dtype},
        )
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device, dtype=dtype)
        if device == "cpu":
            print("CUDA not available. Using CPU. This will be slow.")

    with tqdm(total=total_lines, desc="Dataset translation") as pbar, open(
        output_path, "w+", encoding="utf-8"
    ) as output_file:
        with torch.no_grad():
            for batch in data_loader:
                # Move the inputs to the same device as the model before generating.
                batch = {key: tensor.to(device) for key, tensor in batch.items()}
                generated_tokens = model.generate(
                    **batch, forced_bos_token_id=lang_code_to_idx
                )
                tgt_text = tokenizer.batch_decode(
                    generated_tokens.cpu(), skip_special_tokens=True
                )

                print("\n".join(tgt_text), file=output_file)

                pbar.update(len(tgt_text))

    print("Translation done.\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the translation experiments")
    parser.add_argument(
        "--sentences_path",
        type=str,
        required=True,
        help="Path to a txt file containing the sentences to translate. One sentence per line.",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to a txt file where the translated sentences will be written.",
    )

    parser.add_argument(
        "--source_lang",
        type=str,
        required=True,
        help="Source language id. See: https://huggingface.co/facebook/m2m100_1.2B",
    )

    parser.add_argument(
        "--target_lang",
        type=str,
        required=True,
        help="Target language id. See: https://huggingface.co/facebook/m2m100_1.2B",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size",
    )

    parser.add_argument(
        "--model_name",
        type=str,
        default="facebook/m2m100_1.2B",
        help="Path to the model to use. See: https://huggingface.co/models",
    )

    parser.add_argument(
        "--precision",
        type=int,
        default=32,
        choices=[16, 32, 64],
        help="Precision of the model. 16, 32 or 64.",
    )

    parser.add_argument(
        "--tensorrt",
        action="store_true",
        help="Use TensorRT to compile the model.",
    )

    args = parser.parse_args()

    main(
        sentences_path=args.sentences_path,
        output_path=args.output_path,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        batch_size=args.batch_size,
        model_name=args.model_name,
        precision=args.precision,
        tensorrt=args.tensorrt,
    )
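For reference, a minimal sketch of driving main() directly instead of through the CLI. The paths are hypothetical placeholders; "en" and "es" are M2M-100 language ids (see the model card linked in the help strings above):

# Hypothetical invocation: the paths are placeholders, and "en"/"es"
# are M2M-100 language ids.
from translate import main

main(
    sentences_path="data/sentences.en.txt",
    output_path="data/sentences.es.txt",
    source_lang="en",
    target_lang="es",
    batch_size=8,
)

This is equivalent to passing the same values via the --sentences_path, --output_path, --source_lang, --target_lang and --batch_size flags.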