M2M100 with transformers and accelerate
- README.md +102 -0
- dataset.py +8 -24
- sample_text/RADME.md +9 -0
- sample_text/en.txt +0 -0
- sample_text/en2es.translation.txt +0 -0
- sample_text/es.txt +0 -0
- supported_languages.md +104 -0
- translate.py +146 -76
- translate_troch2trt.py +0 -164
README.md
CHANGED
# Easy-Translate

Easy-Translate is a script for translating large text files on your machine using the [M2M100 models](https://arxiv.org/pdf/2010.11125.pdf) from Facebook/Meta AI.

M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for many-to-many multilingual translation.
It was introduced in this [paper](https://arxiv.org/abs/2010.11125) and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository.
The model can directly translate between the 9,900 directions of 100 languages.
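For reference, translating a single sentence directly with 🤗Transformers looks like this (a minimal sketch following the M2M100 model card; `translate.py` wraps this same API with batching and Accelerate, and the example sentence is arbitrary):

```python
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")

# M2M100 needs the source language set on the tokenizer and the target
# language forced as the first generated token.
tokenizer.src_lang = "en"
encoded = tokenizer("Translation is fun.", return_tensors="pt")
generated = model.generate(
    **encoded, forced_bos_token_id=tokenizer.get_lang_id("es")
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```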
Easy-Translate is built on top of 🤗HuggingFace's [Transformers](https://huggingface.co/docs/transformers/index) and [Accelerate](https://huggingface.co/docs/accelerate/index) libraries. We support:

* CPU / GPU / multi-GPU / TPU acceleration
* BF16 / FP16 / FP32 precision
* Automatic batch size finder: Forget CUDA OOM errors. Set an initial batch size; if it doesn't fit, we will automatically adjust it.
* Sharded Data Parallel to load huge models sharded on multiple GPUs (see: https://huggingface.co/docs/accelerate/fsdp)

Test the 🔌 Online Demo here: https://huggingface.co/spaces/Iker/Translate-100-languages

## Supported languages

See [supported_languages.md](supported_languages.md) for the full table of supported languages and their ids.

**List of supported languages:**
Afrikaans, Amharic, Arabic, Asturian, Azerbaijani, Bashkir, Belarusian, Bulgarian, Bengali, Breton, Bosnian, Catalan, Cebuano, Czech, Welsh, Danish, German, Greek, English, Spanish, Estonian, Persian, Fulah, Finnish, French, Western Frisian, Irish, Gaelic, Galician, Gujarati, Hausa, Hebrew, Hindi, Croatian, Haitian, Hungarian, Armenian, Indonesian, Igbo, Iloko, Icelandic, Italian, Japanese, Javanese, Georgian, Kazakh, Central Khmer, Kannada, Korean, Luxembourgish, Ganda, Lingala, Lao, Lithuanian, Latvian, Malagasy, Macedonian, Malayalam, Mongolian, Marathi, Malay, Burmese, Nepali, Dutch, Norwegian, Northern Sotho, Occitan, Oriya, Panjabi, Polish, Pushto, Portuguese, Romanian, Russian, Sindhi, Sinhala, Slovak, Slovenian, Somali, Albanian, Serbian, Swati, Sundanese, Swedish, Swahili, Tamil, Thai, Tagalog, Tswana, Turkish, Ukrainian, Urdu, Uzbek, Vietnamese, Wolof, Xhosa, Yiddish, Yoruba, Chinese, Zulu
## Supported Models

* **facebook/m2m100_418M**: https://huggingface.co/facebook/m2m100_418M
* **facebook/m2m100_1.2B**: https://huggingface.co/facebook/m2m100_1.2B
* **facebook/m2m100_12B**: https://huggingface.co/facebook/m2m100-12B-avg-5-ckpt
* Any other m2m100 model from HuggingFace's Hub: https://huggingface.co/models?search=m2m100
## Requirements:

```
Pytorch >= 1.10.0
See: https://pytorch.org/get-started/locally/

Accelerate >= 0.7.1
pip install --upgrade accelerate

HuggingFace Transformers
pip install --upgrade transformers
```
## Translate a file

Run `python translate.py -h` for more info.

#### Using a single CPU / GPU:

```bash
accelerate launch translate.py \
--sentences_path sample_text/en.txt \
--output_path sample_text/en2es.translation.txt \
--source_lang en \
--target_lang es \
--model_name facebook/m2m100_1.2B
```
#### Multi-GPU:

See the Accelerate documentation for more information (multi-node, TPU, sharded model, ...): https://huggingface.co/docs/accelerate/index
You can use the Accelerate CLI to configure the Accelerate environment (run `accelerate config` in your terminal) instead of using the `--multi_gpu` and `--num_processes` flags.

```bash
accelerate launch --multi_gpu --num_processes 2 --num_machines 1 translate.py \
--sentences_path sample_text/en.txt \
--output_path sample_text/en2es.translation.txt \
--source_lang en \
--target_lang es \
--model_name facebook/m2m100_1.2B
```
#### Automatic batch size finder:

We will automatically find a batch size that fits in your GPU memory.
The default initial batch size is 128 (you can set it with the `--starting_batch_size 128` flag).
If we hit an Out Of Memory error, we will automatically decrease the batch size until we find a working one.
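Under the hood this uses 🤗Accelerate's `find_executable_batch_size` decorator, the same one translate.py imports from `accelerate.memory_utils`. A minimal sketch of the pattern; the body here is a stand-in for the real translation loop:

```python
from accelerate.memory_utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=128)
def inference(batch_size):
    # If this body raises a CUDA Out Of Memory error, the decorator
    # halves batch_size and restarts the function from the top.
    print(f"Trying batch size {batch_size}")
    ...  # build the DataLoader and translate here

inference()  # the decorator injects batch_size; do not pass it yourself
```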
#### Choose precision:

Use the `--precision` flag to choose the precision of the model. You can choose between bf16, fp16 and 32.

```bash
accelerate launch translate.py \
--sentences_path sample_text/en.txt \
--output_path sample_text/en2es.translation.txt \
--source_lang en \
--target_lang es \
--model_name facebook/m2m100_1.2B \
--precision fp16
```
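As a rule of thumb, fp16 and bf16 halve the memory taken by the model weights compared to 32-bit precision. A quick illustrative sketch (using the small 418M checkpoint; the cast is the same one the `--precision fp16` branch in translate.py performs):

```python
from transformers import M2M100ForConditionalGeneration

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")

fp32_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
model = model.half()  # what --precision fp16 does to the weights
fp16_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

print(f"{fp32_bytes / 2**30:.2f} GiB -> {fp16_bytes / 2**30:.2f} GiB")  # roughly half
```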
## Evaluate translations

Work in progress...
dataset.py
CHANGED
The `collate_function` and `get_dataloader` helpers (and the `typing`, `torch` and `T_co` imports they needed) are removed: batching now lives in translate.py, built on a `DataCollatorForSeq2Seq`. `preprocess` now strips each line, warns on empty sentences, and returns un-padded, un-tensorized encodings (`padding=False`, `return_tensors=None`) so the collator can pad each batch dynamically. The resulting file:

```python
from torch.utils.data import IterableDataset


def blocks(files, size=65536):
    ...  # unchanged helper, collapsed in the diff


class DatasetReader(IterableDataset):
    def __init__(self, filename, tokenizer, max_length):
        self.filename = filename
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.current_line = 0

    def preprocess(self, text: str):
        self.current_line += 1
        text = text.rstrip().strip()
        if len(text) == 0:
            print(f"Warning: empty sentence at line {self.current_line}")
        return self.tokenizer(
            text,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )

    def __iter__(self):
        file_itr = open(self.filename, "r")
        mapped_itr = map(self.preprocess, file_itr)
        return mapped_itr
```
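Because `DatasetReader` is an `IterableDataset`, lines are read and tokenized lazily as the `DataLoader` consumes them, so the input file never needs to fit in memory. A hypothetical snippet illustrating this (the positional constructor arguments follow the call in translate.py):

```python
from transformers import M2M100Tokenizer
from dataset import DatasetReader

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
tokenizer.src_lang = "en"
reader = DatasetReader("sample_text/en.txt", tokenizer, 128)

# Pulls and tokenizes only the first line of the file.
first = next(iter(reader))
print(first["input_ids"][:10])  # a plain Python list, since return_tensors=None
```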
sample_text/RADME.md
ADDED
# Sample texts

We provide a few parallel sentences for easy debugging and testing.
Data has been extracted from the Europarl v7 corpus: [http://www.statmt.org/europarl/v7/](http://www.statmt.org/europarl/v7/).

* **en.txt**: 1000 English sentences
* **es.txt**: 1000 Spanish sentences

Sentences in both files are parallel.
sample_text/en.txt
ADDED
The diff for this file is too large to render.

sample_text/en2es.translation.txt
ADDED
The diff for this file is too large to render.

sample_text/es.txt
ADDED
The diff for this file is too large to render.
supported_languages.md
ADDED
## Supported languages

| Language | Id |
|---|---|
| Afrikaans | af |
| Amharic | am |
| Arabic | ar |
| Asturian | ast |
| Azerbaijani | az |
| Bashkir | ba |
| Belarusian | be |
| Bulgarian | bg |
| Bengali | bn |
| Breton | br |
| Bosnian | bs |
| Catalan | ca |
| Cebuano | ceb |
| Czech | cs |
| Welsh | cy |
| Danish | da |
| German | de |
| Greek | el |
| English | en |
| Spanish | es |
| Estonian | et |
| Persian | fa |
| Fulah | ff |
| Finnish | fi |
| French | fr |
| Western Frisian | fy |
| Irish | ga |
| Gaelic | gd |
| Galician | gl |
| Gujarati | gu |
| Hausa | ha |
| Hebrew | he |
| Hindi | hi |
| Croatian | hr |
| Haitian | ht |
| Hungarian | hu |
| Armenian | hy |
| Indonesian | id |
| Igbo | ig |
| Iloko | ilo |
| Icelandic | is |
| Italian | it |
| Japanese | ja |
| Javanese | jv |
| Georgian | ka |
| Kazakh | kk |
| Central Khmer | km |
| Kannada | kn |
| Korean | ko |
| Luxembourgish | lb |
| Ganda | lg |
| Lingala | ln |
| Lao | lo |
| Lithuanian | lt |
| Latvian | lv |
| Malagasy | mg |
| Macedonian | mk |
| Malayalam | ml |
| Mongolian | mn |
| Marathi | mr |
| Malay | ms |
| Burmese | my |
| Nepali | ne |
| Dutch | nl |
| Norwegian | no |
| Northern Sotho | ns |
| Occitan | oc |
| Oriya | or |
| Panjabi | pa |
| Polish | pl |
| Pushto | ps |
| Portuguese | pt |
| Romanian | ro |
| Russian | ru |
| Sindhi | sd |
| Sinhala | si |
| Slovak | sk |
| Slovenian | sl |
| Somali | so |
| Albanian | sq |
| Serbian | sr |
| Swati | ss |
| Sundanese | su |
| Swedish | sv |
| Swahili | sw |
| Tamil | ta |
| Thai | th |
| Tagalog | tl |
| Tswana | tn |
| Turkish | tr |
| Ukrainian | uk |
| Urdu | ur |
| Uzbek | uz |
| Vietnamese | vi |
| Wolof | wo |
| Xhosa | xh |
| Yiddish | yi |
| Yoruba | yo |
| Chinese | zh |
| Zulu | zu |
translate.py
CHANGED
The script is rewritten around 🤗Accelerate: the TensorRT/torch2trt path and manual device management are removed, `--batch_size` becomes `--starting_batch_size` (wrapped with Accelerate's `find_executable_batch_size` so OOM errors shrink the batch automatically), `--precision` changes from `{16, 32, 64}` to `{bf16, fp16, 32}`, and `--cache_dir`, `--max_length` and `--num_beams` options are added. The resulting file (unchanged argparse blocks collapsed):

```python
from transformers import (
    M2M100ForConditionalGeneration,
    M2M100Tokenizer,
    PreTrainedTokenizerBase,
    DataCollatorForSeq2Seq,
)
from tqdm import tqdm
import argparse
import torch
from torch.utils.data import DataLoader
from dataset import DatasetReader, count_lines
import os
from accelerate import Accelerator, DistributedType
from accelerate.memory_utils import find_executable_batch_size


def get_dataloader(
    accelerator: Accelerator,
    filename: str,
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int,
    max_length: int,
) -> DataLoader:

    dataset = DatasetReader(filename, tokenizer, max_length)
    if accelerator.distributed_type == DistributedType.TPU:
        # TPUs need static shapes, so pad every batch to max_length.
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            padding="max_length",
            max_length=max_length,
            label_pad_token_id=tokenizer.pad_token_id,
            return_tensors="pt",
        )
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            padding=True,
            label_pad_token_id=tokenizer.pad_token_id,
            # No need to set max_length here, we already truncate in the preprocess function
            pad_to_multiple_of=8,
            return_tensors="pt",
        )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=data_collator,
    )


def main(
    sentences_path: str,
    output_path: str,
    source_lang: str,
    target_lang: str,
    starting_batch_size: int,
    model_name: str = "facebook/m2m100_1.2B",
    cache_dir: str = None,
    precision: str = "32",
    max_length: int = 128,
    num_beams: int = 4,
):

    if not os.path.exists(os.path.dirname(output_path)):
        os.makedirs(os.path.dirname(output_path))

    accelerator = Accelerator(mixed_precision=precision if precision != "32" else "no")

    print("Loading tokenizer...")
    tokenizer = M2M100Tokenizer.from_pretrained(
        pretrained_model_name_or_path=model_name, cache_dir=cache_dir
    )
    print("Loading model...")
    model = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model_name_or_path=model_name, cache_dir=cache_dir
    )

    model.eval()

    print("Preparing data...\n")

    if precision == "32":
        model = model.float()
    elif precision == "fp16":
        model = model.half()
    elif precision == "bf16":
        model = model.bfloat16()
    else:
        raise ValueError("Precision not supported. Supported values: 32, fp16, bf16")

    tokenizer.src_lang = source_lang
    lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]

    gen_kwargs = {
        "max_length": max_length,
        "num_beams": num_beams,
        "num_return_sequences": 1,
    }

    total_lines: int = count_lines(sentences_path)
    print(
        f"We will translate {total_lines} lines. Initial batch size: {starting_batch_size}"
    )

    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size):
        nonlocal model, tokenizer, sentences_path, max_length, output_path, lang_code_to_idx, gen_kwargs, total_lines, precision

        print(f"Translating with batch size {batch_size}")

        data_loader = get_dataloader(
            accelerator=accelerator,
            filename=sentences_path,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_length=max_length,
        )

        model, data_loader = accelerator.prepare(model, data_loader)

        with tqdm(
            total=total_lines, desc="Dataset translation", leave=True, ascii=True
        ) as pbar, open(output_path, "w", encoding="utf-8") as output_file:
            with torch.no_grad():
                for batch in data_loader:
                    generated_tokens = accelerator.unwrap_model(model).generate(
                        **batch, forced_bos_token_id=lang_code_to_idx, **gen_kwargs
                    )

                    # Pad generations from every process to the same length
                    # so they can be gathered into a single tensor.
                    generated_tokens = accelerator.pad_across_processes(
                        generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                    )

                    generated_tokens = (
                        accelerator.gather(generated_tokens).cpu().numpy()
                    )

                    tgt_text = tokenizer.batch_decode(
                        generated_tokens, skip_special_tokens=True
                    )

                    print("\n".join(tgt_text), file=output_file)

                    pbar.update(len(tgt_text))

    inference()
    print("Translation done.\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the translation experiments")

    # ... (--sentences_path and --output_path arguments unchanged) ...

    parser.add_argument(
        "--source_lang",
        type=str,
        required=True,
        help="Source language id. See: supported_languages.md",
    )

    parser.add_argument(
        "--target_lang",
        type=str,
        required=True,
        help="Target language id. See: supported_languages.md",
    )

    parser.add_argument(
        "--starting_batch_size",
        type=int,
        default=128,
        help="Starting batch size, we will automatically reduce it if we find an OOM error.",
    )

    # ... (--model_name argument unchanged) ...

    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Cache directory from which to load the model, or None to not cache",
    )

    parser.add_argument(
        "--max_length",
        type=int,
        default=128,
        help="Maximum number of tokens in the source sentence and generated sentence. "
        "Increase this value to translate longer sentences, at the cost of increasing memory usage.",
    )

    parser.add_argument(
        "--num_beams",
        type=int,
        default=5,
        help="Number of beams for beam search; the M2M100 authors recommend 5, but it might use too much memory.",
    )

    parser.add_argument(
        "--precision",
        type=str,
        default="32",
        choices=["bf16", "fp16", "32"],
        help="Precision of the model: bf16, fp16 or 32.",
    )

    args = parser.parse_args()

    main(
        sentences_path=args.sentences_path,
        output_path=args.output_path,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        starting_batch_size=args.starting_batch_size,
        model_name=args.model_name,
        cache_dir=args.cache_dir,
        num_beams=args.num_beams,
        precision=args.precision,
    )
```
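As an aside, a minimal sketch of what the `DataCollatorForSeq2Seq` configured above does with the un-padded encodings coming out of `DatasetReader` (toy example, not part of this commit):

```python
from transformers import DataCollatorForSeq2Seq, M2M100Tokenizer

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
tokenizer.src_lang = "en"
collator = DataCollatorForSeq2Seq(
    tokenizer, padding=True, pad_to_multiple_of=8, return_tensors="pt"
)

# Two sentences of different lengths, tokenized without padding.
features = [tokenizer(t) for t in ["Hello world", "A much longer example sentence"]]
batch = collator(features)

# Both rows are padded to the same length (a multiple of 8, which helps tensor cores).
print(batch["input_ids"].shape)
```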
translate_troch2trt.py
DELETED
This torch2trt/TensorRT variant of the translation script is deleted; the Accelerate-based translate.py above supersedes it. The removed file, for reference:

```python
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from tqdm import tqdm
from typing import TextIO, List
import argparse
import torch
from dataset import get_dataloader, count_lines
import os


def main(
    sentences_path,
    output_path,
    source_lang,
    target_lang,
    batch_size,
    model_name: str = "facebook/m2m100_1.2B",
    tensorrt: bool = False,
    precision: int = 32,
    max_length: int = 128,
):

    if not os.path.exists(os.path.dirname(output_path)):
        os.makedirs(os.path.dirname(output_path))

    print("Loading tokenizer...")
    tokenizer = M2M100Tokenizer.from_pretrained(model_name)
    print("Loading model...")
    model = M2M100ForConditionalGeneration.from_pretrained(model_name)
    print(f"Model loaded.\n")

    tokenizer.src_lang = source_lang
    lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]

    model.eval()

    total_lines: int = count_lines(sentences_path)
    print(f"We will translate {total_lines} lines.")
    data_loader = get_dataloader(
        filename=sentences_path,
        tokenizer=tokenizer,
        batch_size=batch_size,
        max_length=128,
    )

    if precision == 16:
        dtype = torch.float16
    elif precision == 32:
        dtype = torch.float32
    elif precision == 64:
        dtype = torch.float64
    else:
        raise ValueError("Precision must be 16, 32 or 64.")

    if tensorrt:
        device = "cuda"
        from torch2trt import torch2trt

        model.to(device, dtype=dtype)

        model = torch2trt(
            model,
            [torch.randn((batch_size, max_length)).to(device, dtype=torch.long)],
        )

    else:
        if torch.cuda.is_available():
            device = "cuda"

        else:
            device = "cpu"
            print("CUDA not available. Using CPU. This will be slow.")
        model.to(device, dtype=dtype)

    with tqdm(total=total_lines, desc="Dataset translation") as pbar, open(
        output_path, "w+", encoding="utf-8"
    ) as output_file:
        with torch.no_grad():
            for batch in data_loader:
                batch["input_ids"] = batch["input_ids"].to(device)
                batch["attention_mask"] = batch["attention_mask"].to(device)
                generated_tokens = model.generate(
                    **batch, forced_bos_token_id=lang_code_to_idx
                )
                tgt_text = tokenizer.batch_decode(
                    generated_tokens.cpu(), skip_special_tokens=True
                )

                print("\n".join(tgt_text), file=output_file)

                pbar.update(len(tgt_text))

    print(f"Translation done.\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the translation experiments")
    parser.add_argument(
        "--sentences_path",
        type=str,
        required=True,
        help="Path to a txt file containing the sentences to translate. One sentence per line.",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to a txt file where the translated sentences will be written.",
    )

    parser.add_argument(
        "--source_lang",
        type=str,
        required=True,
        help="Source language id. See: https://huggingface.co/facebook/m2m100_1.2B",
    )

    parser.add_argument(
        "--target_lang",
        type=str,
        required=True,
        help="Target language id. See: https://huggingface.co/facebook/m2m100_1.2B",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size",
    )

    parser.add_argument(
        "--model_name",
        type=str,
        default="facebook/m2m100_1.2B",
        help="Path to the model to use. See: https://huggingface.co/models",
    )

    parser.add_argument(
        "--precision",
        type=int,
        default=32,
        choices=[16, 32, 64],
        help="Precision of the model. 16, 32 or 64.",
    )

    parser.add_argument(
        "--tensorrt",
        action="store_true",
        help="Use TensorRT to compile the model.",
    )

    args = parser.parse_args()

    main(
        sentences_path=args.sentences_path,
        output_path=args.output_path,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        batch_size=args.batch_size,
        model_name=args.model_name,
        precision=args.precision,
        tensorrt=args.tensorrt,
    )
```