Iker committed on
Commit
b873cb9
·
1 Parent(s): 749ff6d

M2M100 with transformers and accelerate

README.md CHANGED
@@ -1 +1,103 @@
1
  # Easy-Translate
2
+
3
+ Easy-Translate is a script for translating large text files on your machine using the [M2M100 models](https://arxiv.org/pdf/2010.11125.pdf) from Facebook/Meta AI.
4
+
5
+ M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation.
6
+ It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.
7
+ The model can translate directly between the 9,900 directions of 100 languages (100 × 99 ordered language pairs).
8
+
9
+ Easy-Translate is built on top of 🤗HuggingFace's
10
+ [Transformers](https://huggingface.co/docs/transformers/index) and
11
+ 🤗HuggingFace's [Accelerate](https://huggingface.co/docs/accelerate/index) library. We support:
12
+
13
+ * CPU / GPU / multi-GPU / TPU acceleration
14
+ * BF16 / FP16 / FP32 precision.
15
+ * Automatic batch size finder: Forget CUDA OOM errors. Set an initial batch size; if it doesn't fit in memory, we will automatically reduce it.
16
+ * Sharded Data Parallel to load huge models sharded on multiple GPUs (See: https://huggingface.co/docs/accelerate/fsdp).
17
+
18
+ Test the 🔌 Online Demo here: https://huggingface.co/spaces/Iker/Translate-100-languages
19
+
20
+ ## Supported languages
21
+ See the [Supported languages table](supported_languages.md) for the supported languages and their ids.
22
+
23
+ **List of supported languages:**
24
+ Afrikaans, Amharic, Arabic, Asturian, Azerbaijani, Bashkir, Belarusian, Bulgarian, Bengali, Breton, Bosnian, Catalan, Cebuano, Czech, Welsh, Danish, German, Greek, English, Spanish, Estonian, Persian, Fulah, Finnish, French, WesternFrisian, Irish, Gaelic, Galician, Gujarati, Hausa, Hebrew, Hindi, Croatian, Haitian, Hungarian, Armenian, Indonesian, Igbo, Iloko, Icelandic, Italian, Japanese, Javanese, Georgian, Kazakh, CentralKhmer, Kannada, Korean, Luxembourgish, Ganda, Lingala, Lao, Lithuanian, Latvian, Malagasy, Macedonian, Malayalam, Mongolian, Marathi, Malay, Burmese, Nepali, Dutch, Norwegian, NorthernSotho, Occitan, Oriya, Panjabi, Polish, Pushto, Portuguese, Romanian, Russian, Sindhi, Sinhala, Slovak, Slovenian, Somali, Albanian, Serbian, Swati, Sundanese, Swedish, Swahili, Tamil, Thai, Tagalog, Tswana, Turkish, Ukrainian, Urdu, Uzbek, Vietnamese, Wolof, Xhosa, Yiddish, Yoruba, Chinese, Zulu
25
+
26
+ ## Supported Models
27
+
28
+ * **Facebook/m2m100_418M**: https://huggingface.co/facebook/m2m100_418M
29
+
30
+ * **Facebook/m2m100_1.2B**: https://huggingface.co/facebook/m2m100_1.2B
31
+
32
+ * **Facebook/m2m100_12B**: https://huggingface.co/facebook/m2m100-12B-avg-5-ckpt
33
+
34
+ * Any other m2m100 model from HuggingFace's Hub: https://huggingface.co/models?search=m2m100
35
+
36
+
37
+ ## Requirements:
38
+
39
+ ```
40
+ PyTorch >= 1.10.0
41
+ See: https://pytorch.org/get-started/locally/
42
+
43
+ Accelerate >= 0.7.1
44
+ pip install --upgrade accelerate
45
+
46
+ HuggingFace Transformers
47
+ pip install --upgrade transformers
48
+ ```
49
+
50
+ ## Translate a file
51
+
52
+ Run `python translate.py -h` for more info.
53
+
54
+ #### Using a single CPU / GPU:
55
+ ```bash
56
+ accelerate launch translate.py \
57
+ --sentences_path sample_text/en.txt \
58
+ --output_path sample_text/en2es.translation.txt \
59
+ --source_lang en \
60
+ --target_lang es \
61
+ --model_name facebook/m2m100_1.2B
62
+ ```
63
+
64
+ #### Multi-GPU:
65
+ See Accelerate documentation for more information (multi-node, TPU, Sharded model...): https://huggingface.co/docs/accelerate/index
66
+ You can use the Accelerate CLI to configure the Accelerate environment (Run
67
+ `accelerate config` in your terminal) instead of using the
68
+ `--multi_gpu` and `--num_processes` flags.
69
+
70
+ ```bash
71
+ accelerate launch --multi_gpu --num_processes 2 --num_machines 1 translate.py \
72
+ --sentences_path sample_text/en.txt \
73
+ --output_path sample_text/en2es.translation.txt \
74
+ --source_lang en \
75
+ --target_lang es \
76
+ --model_name facebook/m2m100_1.2B
77
+ ```
78
+
79
+ #### Automatic batch size finder:
80
+ We will automatically find a batch size that fits in your GPU memory.
81
+ The default initial batch size is 128 (You can set it with the `--starting_batch_size 128` flag).
82
+ If we find an Out Of Memory error, we will automatically decrease the batch size until we find a working one.
83
+
84
+
85
+
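The retry behaviour described above can be sketched in plain Python. This is a minimal illustration, not the script's actual code: `find_working_batch_size` and `fake_translate` are hypothetical names, `MemoryError` stands in for a CUDA out-of-memory error, and halving the batch size on each failure is an assumption about the reduction strategy.

```python
def find_working_batch_size(run_batch, starting_batch_size=128):
    """Call run_batch with a shrinking batch size until it stops running out of memory."""
    batch_size = starting_batch_size
    while batch_size > 0:
        try:
            return batch_size, run_batch(batch_size)
        except MemoryError:
            # Assumption: halve the batch size after every failure.
            batch_size //= 2
    raise RuntimeError("No batch size fits in memory.")


def fake_translate(batch_size):
    """Hypothetical workload that only fits in memory for batches of 24 or fewer."""
    if batch_size > 24:
        raise MemoryError
    return f"translated with batch size {batch_size}"


size, result = find_working_batch_size(fake_translate, starting_batch_size=128)
print(size, result)
```

Starting from 128, the sketch tries 128, 64 and 32 before succeeding at 16. Because every retry starts the workload again, the real script reopens the output file in write mode on each attempt.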
86
+ #### Choose precision:
87
+ Use the `--precision` flag to choose the precision of the model. You can choose between: bf16, fp16 and 32.
88
+
89
+ ```bash
90
+ accelerate launch translate.py \
91
+ --sentences_path sample_text/en.txt \
92
+ --output_path sample_text/en2es.translation.txt \
93
+ --source_lang en \
94
+ --target_lang es \
95
+ --model_name facebook/m2m100_1.2B \
96
+ --precision fp16
97
+ ```
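As a rough sketch of what the `--precision` value selects, the helper below maps each accepted choice to the mixed-precision mode and the weight dtype. `resolve_precision` is a hypothetical name used only for illustration, and the dtype names are plain strings so the example runs without PyTorch installed.

```python
def resolve_precision(precision: str):
    """Map a --precision choice to (mixed_precision mode, weight dtype name)."""
    table = {
        "32": ("no", "float32"),       # full precision, no mixed-precision mode
        "fp16": ("fp16", "float16"),
        "bf16": ("bf16", "bfloat16"),
    }
    if precision not in table:
        raise ValueError("Precision not supported. Supported values: 32, fp16, bf16")
    return table[precision]


for choice in ("32", "fp16", "bf16"):
    print(choice, resolve_precision(choice))
```

Note that `32` is passed as a string, not an integer, which is why the flag's choices are `bf16`, `fp16` and `32`.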
98
+
99
+ ## Evaluate translations
100
+
101
+ Work in progress...
102
+
103
+
dataset.py CHANGED
@@ -1,7 +1,4 @@
1
- from typing import List, TextIO, Dict, Optional
2
- import torch
3
  from torch.utils.data import IterableDataset
4
- from torch.utils.data.dataset import T_co
5
 
6
 
7
  def blocks(files, size=65536):
@@ -22,35 +19,22 @@ class DatasetReader(IterableDataset):
22
  self.filename = filename
23
  self.tokenizer = tokenizer
24
  self.max_length = max_length
 
25
 
26
  def preprocess(self, text: str):
 
 
 
 
27
  return self.tokenizer(
28
- text.rstrip().strip(),
29
- padding="max_length",
30
  truncation=True,
31
  max_length=self.max_length,
32
- return_tensors="pt",
33
  )
34
 
35
  def __iter__(self):
36
  file_itr = open(self.filename, "r")
37
  mapped_itr = map(self.preprocess, file_itr)
38
  return mapped_itr
39
-
40
-
41
- def collate_function(batch: List[T_co]) -> Dict[str, torch.Tensor]:
42
- return {
43
- "input_ids": torch.stack([item["input_ids"][0] for item in batch]),
44
- "attention_mask": torch.stack([item["attention_mask"][0] for item in batch]),
45
- }
46
-
47
-
48
- def get_dataloader(
49
- filename: str, tokenizer: str, batch_size: int, max_length: int
50
- ) -> torch.utils.data.DataLoader:
51
- dataset = DatasetReader(filename, tokenizer, max_length)
52
- return torch.utils.data.DataLoader(
53
- dataset,
54
- batch_size=batch_size,
55
- collate_fn=collate_function,
56
- )
 
 
 
1
  from torch.utils.data import IterableDataset
 
2
 
3
 
4
  def blocks(files, size=65536):
 
19
  self.filename = filename
20
  self.tokenizer = tokenizer
21
  self.max_length = max_length
22
+ self.current_line = 0
23
 
24
  def preprocess(self, text: str):
25
+ self.current_line += 1
26
+ text = text.rstrip().strip()
27
+ if len(text) == 0:
28
+ print(f"Warning: empty sentence at line {self.current_line}")
29
  return self.tokenizer(
30
+ text,
31
+ padding=False,
32
  truncation=True,
33
  max_length=self.max_length,
34
+ return_tensors=None,
35
  )
36
 
37
  def __iter__(self):
38
  file_itr = open(self.filename, "r")
39
  mapped_itr = map(self.preprocess, file_itr)
40
  return mapped_itr
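With `padding=False` and `return_tensors=None`, `preprocess` now returns unpadded token id lists, so padding has to happen later, once a batch has been assembled. The sketch below shows that idea in plain Python: each batch is padded to its longest sequence, rounded up to a multiple of 8. `pad_batch` and the pad id of `1` are illustrative assumptions, not part of this repository.

```python
def pad_batch(sequences, pad_id, multiple_of=8):
    """Pad token id lists to the longest one, rounded up to a multiple of `multiple_of`."""
    longest = max(len(seq) for seq in sequences)
    # Round up to the next multiple, e.g. 5 -> 8 and 9 -> 16.
    target = -(-longest // multiple_of) * multiple_of
    input_ids = [seq + [pad_id] * (target - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (target - len(seq)) for seq in sequences]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_batch([[5, 6, 7], [8, 9, 10, 11, 12]], pad_id=1)
print(batch["input_ids"])
print(batch["attention_mask"])
```

Padding per batch instead of to `max_length` keeps short sentences cheap, since a batch of short lines no longer carries 128 tokens per row.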
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sample_text/RADME.md ADDED
@@ -0,0 +1,9 @@
1
+ # Sample texts
2
+
3
+ We provide a few parallel sentences for easy debugging and testing.
4
+ The data was extracted from the Europarl v7 corpus: [http://www.statmt.org/europarl/v7/](http://www.statmt.org/europarl/v7/).
5
+
6
+ * **en.txt**: 1000 English sentences
7
+ * **es.txt**: 1000 Spanish sentences
8
+
9
+ Sentences in both files are parallel.
sample_text/en.txt ADDED
The diff for this file is too large to render. See raw diff
 
sample_text/en2es.translation.txt ADDED
The diff for this file is too large to render. See raw diff
 
sample_text/es.txt ADDED
The diff for this file is too large to render. See raw diff
 
supported_languages.md ADDED
@@ -0,0 +1,104 @@
1
+ ## Supported languages
2
+
3
+ | Language | Id |
4
+ |---|---|
5
+ | Afrikaans | af |
6
+ | Amharic | am |
7
+ | Arabic | ar |
8
+ | Asturian | ast |
9
+ | Azerbaijani | az |
10
+ | Bashkir | ba |
11
+ | Belarusian | be |
12
+ | Bulgarian | bg |
13
+ | Bengali | bn |
14
+ | Breton | br |
15
+ | Bosnian | bs |
16
+ | Catalan | ca |
17
+ | Cebuano | ceb |
18
+ | Czech | cs |
19
+ | Welsh | cy |
20
+ | Danish | da |
21
+ | German | de |
22
+ | Greek | el |
23
+ | English | en |
24
+ | Spanish | es |
25
+ | Estonian | et |
26
+ | Persian | fa |
27
+ | Fulah | ff |
28
+ | Finnish | fi |
29
+ | French | fr |
30
+ | WesternFrisian | fy |
31
+ | Irish | ga |
32
+ | Gaelic | gd |
33
+ | Galician | gl |
34
+ | Gujarati | gu |
35
+ | Hausa | ha |
36
+ | Hebrew | he |
37
+ | Hindi | hi |
38
+ | Croatian | hr |
39
+ | Haitian | ht |
40
+ | Hungarian | hu |
41
+ | Armenian | hy |
42
+ | Indonesian | id |
43
+ | Igbo | ig |
44
+ | Iloko | ilo |
45
+ | Icelandic | is |
46
+ | Italian | it |
47
+ | Japanese | ja |
48
+ | Javanese | jv |
49
+ | Georgian | ka |
50
+ | Kazakh | kk |
51
+ | CentralKhmer | km |
52
+ | Kannada | kn |
53
+ | Korean | ko |
54
+ | Luxembourgish | lb |
55
+ | Ganda | lg |
56
+ | Lingala | ln |
57
+ | Lao | lo |
58
+ | Lithuanian | lt |
59
+ | Latvian | lv |
60
+ | Malagasy | mg |
61
+ | Macedonian | mk |
62
+ | Malayalam | ml |
63
+ | Mongolian | mn |
64
+ | Marathi | mr |
65
+ | Malay | ms |
66
+ | Burmese | my |
67
+ | Nepali | ne |
68
+ | Dutch | nl |
69
+ | Norwegian | no |
70
+ | NorthernSotho | ns |
71
+ | Occitan | oc |
72
+ | Oriya | or |
73
+ | Panjabi | pa |
74
+ | Polish | pl |
75
+ | Pushto | ps |
76
+ | Portuguese | pt |
77
+ | Romanian | ro |
78
+ | Russian | ru |
79
+ | Sindhi | sd |
80
+ | Sinhala | si |
81
+ | Slovak | sk |
82
+ | Slovenian | sl |
83
+ | Somali | so |
84
+ | Albanian | sq |
85
+ | Serbian | sr |
86
+ | Swati | ss |
87
+ | Sundanese | su |
88
+ | Swedish | sv |
89
+ | Swahili | sw |
90
+ | Tamil | ta |
91
+ | Thai | th |
92
+ | Tagalog | tl |
93
+ | Tswana | tn |
94
+ | Turkish | tr |
95
+ | Ukrainian | uk |
96
+ | Urdu | ur |
97
+ | Uzbek | uz |
98
+ | Vietnamese | vi |
99
+ | Wolof | wo |
100
+ | Xhosa | xh |
101
+ | Yiddish | yi |
102
+ | Yoruba | yo |
103
+ | Chinese | zh |
104
+ | Zulu | zu |
translate.py CHANGED
@@ -1,99 +1,152 @@
1
- from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
 
 
 
 
 
2
  from tqdm import tqdm
3
- from typing import TextIO, List
4
  import argparse
5
  import torch
6
- from dataset import get_dataloader, count_lines
 
7
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  def main(
11
- sentences_path,
12
- output_path,
13
- source_lang,
14
- target_lang,
15
- batch_size,
16
  model_name: str = "facebook/m2m100_1.2B",
17
- tensorrt: bool = False,
18
- precision: int = 32,
19
  max_length: int = 128,
 
20
  ):
21
 
22
  if not os.path.exists(os.path.dirname(output_path)):
23
  os.makedirs(os.path.dirname(output_path))
24
 
 
 
25
  print("Loading tokenizer...")
26
- tokenizer = M2M100Tokenizer.from_pretrained(model_name)
 
 
27
  print("Loading model...")
28
- model = M2M100ForConditionalGeneration.from_pretrained(model_name)
29
- print(f"Model loaded.\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  tokenizer.src_lang = source_lang
32
  lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]
33
 
34
- model.eval()
 
 
 
 
35
 
36
  total_lines: int = count_lines(sentences_path)
37
- print(f"We will translate {total_lines} lines.")
38
- data_loader = get_dataloader(
39
- filename=sentences_path,
40
- tokenizer=tokenizer,
41
- batch_size=batch_size,
42
- max_length=128,
43
  )
44
 
45
- if precision == 16:
46
- dtype = torch.float16
47
- elif precision == 32:
48
- dtype = torch.float32
49
- elif precision == 64:
50
- dtype = torch.float64
51
- else:
52
- raise ValueError("Precision must be 16, 32 or 64.")
53
 
54
- if tensorrt:
55
- import torch_tensorrt
56
 
57
- device = "cuda"
 
 
 
 
 
 
58
 
59
- model.to(device)
60
 
61
- traced_model = torch.jit.trace(
62
- model, [torch.randn((batch_size, max_length)).to("cuda", dtype=torch.long)]
63
- )
64
- model = torch_tensorrt.compile(
65
- traced_model,
66
- inputs=[torch_tensorrt.Input((batch_size, max_length), dtype=torch.long)],
67
- enabled_precisions={dtype},
68
- )
69
- else:
70
- if torch.cuda.is_available():
71
- device = "cuda"
72
 
73
- else:
74
- device = "cpu"
75
- print("CUDA not available. Using CPU. This will be slow.")
76
- model.to(device, dtype=dtype)
77
 
78
- with tqdm(total=total_lines, desc="Dataset translation") as pbar, open(
79
- output_path, "w+", encoding="utf-8"
80
- ) as output_file:
81
- with torch.no_grad():
82
- for batch in data_loader:
83
- batch["input_ids"] = batch["input_ids"].to(device)
84
- batch["attention_mask"] = batch["attention_mask"].to(device)
85
 
86
- generated_tokens = model.generate(
87
- **batch, forced_bos_token_id=lang_code_to_idx
88
- )
89
- tgt_text = tokenizer.batch_decode(
90
- generated_tokens.cpu(), skip_special_tokens=True
91
- )
92
 
93
- print("\n".join(tgt_text), file=output_file)
 
 
94
 
95
- pbar.update(len(tgt_text))
96
 
 
 
 
97
  print(f"Translation done.\n")
98
 
99
 
@@ -117,21 +170,21 @@ if __name__ == "__main__":
117
  "--source_lang",
118
  type=str,
119
  required=True,
120
- help="Source language id. See: https://huggingface.co/facebook/m2m100_1.2B",
121
  )
122
 
123
  parser.add_argument(
124
  "--target_lang",
125
  type=str,
126
  required=True,
127
- help="Target language id. See: https://huggingface.co/facebook/m2m100_1.2B",
128
  )
129
 
130
  parser.add_argument(
131
- "--batch_size",
132
  type=int,
133
- default=8,
134
- help="Batch size",
135
  )
136
 
137
  parser.add_argument(
@@ -142,17 +195,33 @@ if __name__ == "__main__":
142
  )
143
 
144
  parser.add_argument(
145
- "--precision",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  type=int,
147
- default=32,
148
- choices=[16, 32, 64],
149
- help="Precision of the model. 16, 32 or 64.",
150
  )
151
 
152
  parser.add_argument(
153
- "--tensorrt",
154
- action="store_true",
155
- help="Use TensorRT to compile the model.",
 
 
156
  )
157
 
158
  args = parser.parse_args()
@@ -162,8 +231,9 @@ if __name__ == "__main__":
162
  output_path=args.output_path,
163
  source_lang=args.source_lang,
164
  target_lang=args.target_lang,
165
- batch_size=args.batch_size,
166
  model_name=args.model_name,
 
 
167
  precision=args.precision,
168
- tensorrt=args.tensorrt,
169
  )
 
1
+ from transformers import (
2
+ M2M100ForConditionalGeneration,
3
+ M2M100Tokenizer,
4
+ PreTrainedTokenizerBase,
5
+ DataCollatorForSeq2Seq,
6
+ )
7
  from tqdm import tqdm
 
8
  import argparse
9
  import torch
10
+ from torch.utils.data import DataLoader
11
+ from dataset import DatasetReader, count_lines
12
  import os
13
+ from accelerate import Accelerator, DistributedType
14
+ from accelerate.memory_utils import find_executable_batch_size
15
+
16
+
17
+ def get_dataloader(
18
+ accelerator: Accelerator,
19
+ filename: str,
20
+ tokenizer: PreTrainedTokenizerBase,
21
+ batch_size: int,
22
+ max_length: int,
23
+ ) -> DataLoader:
24
+
25
+ dataset = DatasetReader(filename, tokenizer, max_length)
26
+ if accelerator.distributed_type == DistributedType.TPU:
27
+ data_collator = DataCollatorForSeq2Seq(
28
+ tokenizer,
29
+ padding="max_length",
30
+ max_length=max_length,
31
+ label_pad_token_id=tokenizer.pad_token_id,
32
+ return_tensors="pt",
33
+ )
34
+ else:
35
+ data_collator = DataCollatorForSeq2Seq(
36
+ tokenizer,
37
+ padding=True,
38
+ label_pad_token_id=tokenizer.pad_token_id,
39
+ # max_length=max_length, No need to set max_length here, we already truncate in the preprocess function
40
+ pad_to_multiple_of=8,
41
+ return_tensors="pt",
42
+ )
43
+
44
+ return DataLoader(
45
+ dataset,
46
+ batch_size=batch_size,
47
+ collate_fn=data_collator,
48
+ )
49
 
50
 
51
  def main(
52
+ sentences_path: str,
53
+ output_path: str,
54
+ source_lang: str,
55
+ target_lang: str,
56
+ starting_batch_size: int,
57
  model_name: str = "facebook/m2m100_1.2B",
58
+ cache_dir: str = None,
59
+ precision: str = "32",
60
  max_length: int = 128,
61
+ num_beams: int = 4,
62
  ):
63
 
64
  if not os.path.exists(os.path.dirname(output_path)):
65
  os.makedirs(os.path.dirname(output_path))
66
 
67
+ accelerator = Accelerator(mixed_precision=precision if precision != "32" else "no")
68
+
69
  print("Loading tokenizer...")
70
+ tokenizer = M2M100Tokenizer.from_pretrained(
71
+ pretrained_model_name_or_path=model_name, cache_dir=cache_dir
72
+ )
73
  print("Loading model...")
74
+ model = M2M100ForConditionalGeneration.from_pretrained(
75
+ pretrained_model_name_or_path=model_name, cache_dir=cache_dir
76
+ )
77
+
78
+ model.eval()
79
+
80
+ print(f"Preparing data...\n")
81
+
82
+ if precision == "32":
83
+ model = model.float()
84
+ elif precision == "fp16":
85
+ model = model.half()
86
+ elif precision == "bf16":
87
+ model = model.bfloat16()
88
+ else:
89
+ raise ValueError("Precision not supported. Supported values: 32, fp16, bf16")
90
 
91
  tokenizer.src_lang = source_lang
92
  lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]
93
 
94
+ gen_kwargs = {
95
+ "max_length": max_length,
96
+ "num_beams": num_beams,
97
+ "num_return_sequences": 1,
98
+ }
99
 
100
  total_lines: int = count_lines(sentences_path)
101
+ print(
102
+ f"We will translate {total_lines} lines. Initial batch size: {starting_batch_size}"
 
 
 
 
103
  )
104
 
105
+ @find_executable_batch_size(starting_batch_size=starting_batch_size)
106
+ def inference(batch_size):
107
+ nonlocal model, tokenizer, sentences_path, max_length, output_path, lang_code_to_idx, gen_kwargs, total_lines, precision
 
 
 
 
 
108
 
109
+ print(f"Translating with batch size {batch_size}")
 
110
 
111
+ data_loader = get_dataloader(
112
+ accelerator=accelerator,
113
+ filename=sentences_path,
114
+ tokenizer=tokenizer,
115
+ batch_size=batch_size,
116
+ max_length=max_length,
117
+ )
118
 
119
+ model, data_loader = accelerator.prepare(model, data_loader)
120
 
121
+ with tqdm(
122
+ total=total_lines, desc="Dataset translation", leave=True, ascii=True
123
+ ) as pbar, open(output_path, "w", encoding="utf-8") as output_file:
124
+ with torch.no_grad():
125
+ for batch in data_loader:
126
+ batch["input_ids"] = batch["input_ids"]
127
+ batch["attention_mask"] = batch["attention_mask"]
 
 
 
 
128
 
129
+ generated_tokens = accelerator.unwrap_model(model).generate(
130
+ **batch, forced_bos_token_id=lang_code_to_idx, **gen_kwargs
131
+ )
 
132
 
133
+ generated_tokens = accelerator.pad_across_processes(
134
+ generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
135
+ )
 
 
 
 
136
 
137
+ generated_tokens = (
138
+ accelerator.gather(generated_tokens).cpu().numpy()
139
+ )
 
 
 
140
 
141
+ tgt_text = tokenizer.batch_decode(
142
+ generated_tokens, skip_special_tokens=True
143
+ )
144
 
145
+ print("\n".join(tgt_text), file=output_file)
146
 
147
+ pbar.update(len(tgt_text))
148
+
149
+ inference()
150
  print(f"Translation done.\n")
151
 
152
 
 
170
  "--source_lang",
171
  type=str,
172
  required=True,
173
+ help="Source language id. See: supported_languages.md",
174
  )
175
 
176
  parser.add_argument(
177
  "--target_lang",
178
  type=str,
179
  required=True,
180
+ help="Target language id. See: supported_languages.md",
181
  )
182
 
183
  parser.add_argument(
184
+ "--starting_batch_size",
185
  type=int,
186
+ default=128,
187
+ help="Starting batch size, we will automatically reduce it if we find an OOM error.",
188
  )
189
 
190
  parser.add_argument(
 
195
  )
196
 
197
  parser.add_argument(
198
+ "--cache_dir",
199
+ type=str,
200
+ default=None,
201
+ help="Cache directory from which to load the model, or None to not cache",
202
+ )
203
+
204
+ parser.add_argument(
205
+ "--max_length",
206
+ type=int,
207
+ default=128,
208
+ help="Maximum number of tokens in the source sentence and generated sentence. "
209
+ "Increase this value to translate longer sentences, at the cost of increasing memory usage.",
210
+ )
211
+
212
+ parser.add_argument(
213
+ "--num_beams",
214
  type=int,
215
+ default=5,
216
+ help="Number of beams for beam search. The M2M100 authors recommend 5, but it might use too much memory",
 
217
  )
218
 
219
  parser.add_argument(
220
+ "--precision",
221
+ type=str,
222
+ default="32",
223
+ choices=["bf16", "fp16", "32"],
224
+ help="Precision of the model. bf16, fp16 or 32.",
225
  )
226
 
227
  args = parser.parse_args()
 
231
  output_path=args.output_path,
232
  source_lang=args.source_lang,
233
  target_lang=args.target_lang,
234
+ starting_batch_size=args.starting_batch_size,
235
  model_name=args.model_name,
236
+ cache_dir=args.cache_dir,
237
+ num_beams=args.num_beams,
238
  precision=args.precision,
 
239
  )
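The new inference loop pads the generated tokens across processes before gathering them. A likely reason is that each process can generate sequences of a different length, and rows of unequal length cannot be combined into one array. The sketch below shows the idea with plain lists; `pad_to_common_length`, `gather_rows` and the pad id of `1` are hypothetical stand-ins, not Accelerate's API.

```python
def pad_to_common_length(per_process_batches, pad_id):
    """Right-pad every row so all processes share the same sequence length."""
    max_len = max(len(row) for batch in per_process_batches for row in batch)
    return [
        [row + [pad_id] * (max_len - len(row)) for row in batch]
        for batch in per_process_batches
    ]


def gather_rows(per_process_batches):
    """Concatenate the per-process batches into one list of rows."""
    return [row for batch in per_process_batches for row in batch]


# Two hypothetical processes that generated outputs of different lengths.
process_0 = [[4, 5, 6]]
process_1 = [[7, 8, 9, 10, 11]]

padded = pad_to_common_length([process_0, process_1], pad_id=1)
gathered = gather_rows(padded)
print(gathered)
```

The padding tokens do not reach the output file, because the script decodes with `skip_special_tokens=True`.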
translate_troch2trt.py DELETED
@@ -1,164 +0,0 @@
1
- from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
2
- from tqdm import tqdm
3
- from typing import TextIO, List
4
- import argparse
5
- import torch
6
- from dataset import get_dataloader, count_lines
7
- import os
8
-
9
-
10
- def main(
11
- sentences_path,
12
- output_path,
13
- source_lang,
14
- target_lang,
15
- batch_size,
16
- model_name: str = "facebook/m2m100_1.2B",
17
- tensorrt: bool = False,
18
- precision: int = 32,
19
- max_length: int = 128,
20
- ):
21
-
22
- if not os.path.exists(os.path.dirname(output_path)):
23
- os.makedirs(os.path.dirname(output_path))
24
-
25
- print("Loading tokenizer...")
26
- tokenizer = M2M100Tokenizer.from_pretrained(model_name)
27
- print("Loading model...")
28
- model = M2M100ForConditionalGeneration.from_pretrained(model_name)
29
- print(f"Model loaded.\n")
30
-
31
- tokenizer.src_lang = source_lang
32
- lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]
33
-
34
- model.eval()
35
-
36
- total_lines: int = count_lines(sentences_path)
37
- print(f"We will translate {total_lines} lines.")
38
- data_loader = get_dataloader(
39
- filename=sentences_path,
40
- tokenizer=tokenizer,
41
- batch_size=batch_size,
42
- max_length=128,
43
- )
44
-
45
- if precision == 16:
46
- dtype = torch.float16
47
- elif precision == 32:
48
- dtype = torch.float32
49
- elif precision == 64:
50
- dtype = torch.float64
51
- else:
52
- raise ValueError("Precision must be 16, 32 or 64.")
53
-
54
- if tensorrt:
55
- device = "cuda"
56
- from torch2trt import torch2trt
57
-
58
- model.to(device, dtype=dtype)
59
-
60
- model = torch2trt(
61
- model,
62
- [torch.randn((batch_size, max_length)).to(device, dtype=torch.long)],
63
- )
64
-
65
- else:
66
- if torch.cuda.is_available():
67
- device = "cuda"
68
-
69
- else:
70
- device = "cpu"
71
- print("CUDA not available. Using CPU. This will be slow.")
72
- model.to(device, dtype=dtype)
73
-
74
- with tqdm(total=total_lines, desc="Dataset translation") as pbar, open(
75
- output_path, "w+", encoding="utf-8"
76
- ) as output_file:
77
- with torch.no_grad():
78
- for batch in data_loader:
79
- batch["input_ids"] = batch["input_ids"].to(device)
80
- batch["attention_mask"] = batch["attention_mask"].to(device)
81
- generated_tokens = model.generate(
82
- **batch, forced_bos_token_id=lang_code_to_idx
83
- )
84
- tgt_text = tokenizer.batch_decode(
85
- generated_tokens.cpu(), skip_special_tokens=True
86
- )
87
-
88
- print("\n".join(tgt_text), file=output_file)
89
-
90
- pbar.update(len(tgt_text))
91
-
92
- print(f"Translation done.\n")
93
-
94
-
95
- if __name__ == "__main__":
96
- parser = argparse.ArgumentParser(description="Run the translation experiments")
97
- parser.add_argument(
98
- "--sentences_path",
99
- type=str,
100
- required=True,
101
- help="Path to a txt file containing the sentences to translate. One sentence per line.",
102
- )
103
-
104
- parser.add_argument(
105
- "--output_path",
106
- type=str,
107
- required=True,
108
- help="Path to a txt file where the translated sentences will be written.",
109
- )
110
-
111
- parser.add_argument(
112
- "--source_lang",
113
- type=str,
114
- required=True,
115
- help="Source language id. See: https://huggingface.co/facebook/m2m100_1.2B",
116
- )
117
-
118
- parser.add_argument(
119
- "--target_lang",
120
- type=str,
121
- required=True,
122
- help="Target language id. See: https://huggingface.co/facebook/m2m100_1.2B",
123
- )
124
-
125
- parser.add_argument(
126
- "--batch_size",
127
- type=int,
128
- default=8,
129
- help="Batch size",
130
- )
131
-
132
- parser.add_argument(
133
- "--model_name",
134
- type=str,
135
- default="facebook/m2m100_1.2B",
136
- help="Path to the model to use. See: https://huggingface.co/models",
137
- )
138
-
139
- parser.add_argument(
140
- "--precision",
141
- type=int,
142
- default=32,
143
- choices=[16, 32, 64],
144
- help="Precision of the model. 16, 32 or 64.",
145
- )
146
-
147
- parser.add_argument(
148
- "--tensorrt",
149
- action="store_true",
150
- help="Use TensorRT to compile the model.",
151
- )
152
-
153
- args = parser.parse_args()
154
-
155
- main(
156
- sentences_path=args.sentences_path,
157
- output_path=args.output_path,
158
- source_lang=args.source_lang,
159
- target_lang=args.target_lang,
160
- batch_size=args.batch_size,
161
- model_name=args.model_name,
162
- precision=args.precision,
163
- tensorrt=args.tensorrt,
164
- )