Implement SeamlessM4T
Changed files:

- README.md (+32 -12)
- model.py (+35 -40)
- sample_text/en2es.seamless-m4t-large.json (+60 -0)
- sample_text/en2es.seamless-m4t-medium.json (+60 -0)
- sample_text/en2es.translation.seamless-m4t-large.txt (+0 -0)
- sample_text/en2es.translation.seamless-m4t-medium.txt (+0 -0)
- tests/__init__.py (+0 -0)
- tests/test_translation.py (+548 -0)
- translate.py (+119 -21)
README.md CHANGED

````diff
@@ -1,4 +1,3 @@
-
 <p align="center">
     <br>
     <img src="images/title.png" width="900"/>
@@ -29,23 +28,19 @@ We currently support:
 - BF16 / FP16 / FP32 / 8 Bits / 4 Bits precision.
 - Automatic batch size finder: Forget CUDA OOM errors. Set an initial batch size, if it doesn't fit, we will automatically adjust it.
 - Multiple decoding strategies: Greedy Search, Beam Search, Top-K Sampling, Top-p (nucleus) sampling, etc. See [Decoding Strategies](#decodingsampling-strategies) for more information.
+- Load huge models in a single GPU with 8-bits / 4-bits quantization and support for splitting the model between GPU and CPU. See [Loading Huge Models](#loading-huge-models) for more information.
+- LoRA models support
+- Support for any Seq2SeqLM or CausalLM model from HuggingFace's Hub.
+- Prompt support! See [Prompting](#prompting) for more information.
+- :new: Add support for [SeamlessM4T](https://huggingface.co/docs/transformers/main/en/model_doc/seamless_m4t)!
 
 >Test the 🔌 Online Demo here: <https://huggingface.co/spaces/Iker/Translate-100-languages>
 
-## Supported languages
-
-See the [Supported languages table](supported_languages.md) for a table of the supported languages and their ids.
-
 ## Supported Models
 
 💥 EasyTranslate now supports any Seq2SeqLM (m2m100, nllb200, small100, mbart, MarianMT, T5, FlanT5, etc.) and any CausalLM (GPT2, LLaMA, Vicuna, Falcon) model from 🤗 Hugging Face's Hub!!
-We still recommend you to use M2M100 or
+We still recommend you to use M2M100, NLLB200 or SeamlessM4T for the best results, but you can experiment with any other MT model, as well as prompting LLMs to generate translations (See [Prompting Section](#prompting) for more details).
 You can also see [the examples folder](examples) for examples of how to use EasyTranslate with different models.
 
 ### M2M100
@@ -73,13 +68,23 @@
 
 - **facebook/nllb-200-distilled-600M**: <https://huggingface.co/facebook/nllb-200-distilled-600M>
 
+### SeamlessM4T
+
+**SeamlessM4T** is a collection of models designed to provide high quality translation, allowing people from different linguistic communities to communicate effortlessly through speech and text. It was introduced in this [paper](https://dl.fbaipublicfiles.com/seamless/seamless_m4t_paper.pdf) and first released in [this](https://github.com/facebookresearch/seamless_communication) repository.
+
+>SeamlessM4T can directly translate between 196 languages for text input/output.
+
+- **facebook/hf-seamless-m4t-medium**: <https://huggingface.co/facebook/hf-seamless-m4t-medium> (Requires transformers 4.35.0)
+- **facebook/hf-seamless-m4t-large**: <https://huggingface.co/facebook/hf-seamless-m4t-large> (Requires transformers 4.35.0)
+
 ### Other MT Models supported
 We support every MT model in the 🤗 Hugging Face's Hub. If you find a model that doesn't work, please open an issue for us to fix it or a PR with the fix. This includes, among many others:
 - **Small100**: <https://huggingface.co/alirezamsh/small100>
 - **Mbart many-to-many / many-to-one**: <https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt>
 - **Opus MT**: <https://huggingface.co/Helsinki-NLP/opus-mt-es-en>
 
+See the [Supported languages table](supported_languages.md) for a table of the supported languages and their ids.
 
 ## Citation
 If you use this software please cite
@@ -110,6 +115,7 @@
 
 HuggingFace Transformers
 If you plan to use NLLB200, please use >= 4.28.0, as an important bug was fixed in this version.
+If you plan to use SeamlessM4T, please use >= 4.35.0.
 pip install --upgrade transformers
 
 BitsAndBytes (Optional, required for 8-bits / 4-bits quantization)
@@ -135,6 +141,20 @@
 --model_name facebook/m2m100_1.2B
 ```
 
+If you want to translate all the files in a directory, use the `--sentences_dir` flag instead of `--sentences_path`.
+
+```bash
+# We use --files_extension txt to translate only files with this extension.
+# Use empty string to translate all files in the directory
+
+python3 translate.py \
+--sentences_dir sample_text/ \
+--output_path sample_text/translations \
+--files_extension txt \
+--source_lang en \
+--target_lang es \
+--model_name facebook/m2m100_1.2B
+```
+
 #### Multi-GPU
 
 See Accelerate documentation for more information (multi-node, TPU, Sharded model...): <https://huggingface.co/docs/accelerate/index>
````
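Translating with the newly supported SeamlessM4T checkpoints follows the same CLI pattern documented above. A minimal sketch, assuming an input file `sample_text/en.txt` exists in the repo; note that SeamlessM4T uses its own language codes (`eng`, `spa`, as used in the test suite) rather than `en`/`es`:

```bash
# Hypothetical invocation based on the flags documented in this README.
# Requires transformers >= 4.35.0.
python3 translate.py \
--sentences_path sample_text/en.txt \
--output_path sample_text/en2es.translation.seamless-m4t-medium.txt \
--source_lang eng \
--target_lang spa \
--model_name facebook/hf-seamless-m4t-medium
```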
model.py CHANGED

```diff
@@ -14,8 +14,6 @@ from transformers.models.auto.modeling_auto import (
 
 from typing import Optional, Tuple
 
-import os
-
 import torch
 
 import json
@@ -27,6 +25,7 @@ def load_model_for_inference(
     lora_weights_name_or_path: Optional[str] = None,
     torch_dtype: Optional[str] = None,
     force_auto_device_map: bool = False,
+    trust_remote_code: bool = False,
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
     """
     Load any Decoder model for inference.
@@ -50,6 +49,8 @@
             Whether to force the use of the auto device map. If set to True, the model will be split across
             GPUs and CPU to fit the model in memory. If set to False, a full copy of the model will be loaded
             into each GPU. Defaults to False.
+        trust_remote_code (`bool`, optional):
+            Trust the remote code from HuggingFace model hub. Defaults to False.
 
     Returns:
         `Tuple[PreTrainedModel, PreTrainedTokenizerBase]`:
@@ -64,19 +65,8 @@
 
     print(f"Loading model from {weights_path}")
 
-    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
-        {
-            "mpt": "MPTForCausalLM",
-            "RefinedWebModel": "RWForCausalLM",
-            "RefinedWeb": "RWForCausalLM",
-        }
-    )  # MPT and Falcon are not in transformers yet
-
     config = AutoConfig.from_pretrained(
-        weights_path,
-        trust_remote_code=True
-        if ("mpt" in weights_path or "falcon" in weights_path)
-        else False,
+        weights_path, trust_remote_code=trust_remote_code
     )
 
     torch_dtype = (
@@ -84,20 +74,40 @@
     )
 
     if "small100" in weights_path:
+        import transformers
+
+        if transformers.__version__ > "4.34.0":
+            raise ValueError(
+                "Small100 tokenizer is not supported in transformers > 4.34.0. Please "
+                "use transformers <= 4.34.0 if you want to use small100"
+            )
+
         print(f"Loading custom small100 tokenizer for utils.tokenization_small100")
         from utils.tokenization_small100 import SMALL100Tokenizer as AutoTokenizer
     else:
         from transformers import AutoTokenizer
 
     tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
-        weights_path,
-        add_eos_token=True,
-        trust_remote_code=True
-        if ("mpt" in weights_path or "falcon" in weights_path)
-        else False,
+        weights_path, add_eos_token=True, trust_remote_code=trust_remote_code
     )
 
+    if tokenizer.pad_token_id is None:
+        if "<|padding|>" in tokenizer.get_vocab():
+            # StabilityLM specific fix
+            tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
+        elif tokenizer.unk_token is not None:
+            print(
+                "Tokenizer does not have a pad token, we will use the unk token as pad token."
+            )
+            tokenizer.pad_token_id = tokenizer.unk_token_id
+        else:
+            print(
+                "Tokenizer does not have a pad token. We will use the eos token as pad token."
+            )
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+
     quant_args = {}
+
     if quantization is not None:
         quant_args = (
             {"load_in_4bit": True} if quantization == 4 else {"load_in_8bit": True}
@@ -107,16 +117,17 @@
                 load_in_4bit=True,
                 bnb_4bit_use_double_quant=True,
                 bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=torch.bfloat16
+                bnb_4bit_compute_dtype=torch.bfloat16
+                if torch_dtype in ["auto", None]
+                else torch_dtype,
             )
-            torch_dtype = torch.bfloat16
 
         else:
             bnb_config = BitsAndBytesConfig(
                 load_in_8bit=True,
             )
         print(
-            f"Bits and Bytes config: {json.dumps(bnb_config.to_dict(),indent=4,ensure_ascii=False)}"
+            f"Bits and Bytes config: {json.dumps(bnb_config.to_dict(), indent=4, ensure_ascii=False)}"
         )
     else:
         print(f"Loading model with dtype: {torch_dtype}")
@@ -131,6 +142,7 @@
             device_map="auto" if force_auto_device_map else None,
             torch_dtype=torch_dtype,
             quantization_config=bnb_config,
+            trust_remote_code=trust_remote_code,
             **quant_args,
         )
 
@@ -142,9 +154,7 @@
             pretrained_model_name_or_path=weights_path,
             device_map="auto" if force_auto_device_map else None,
             torch_dtype=torch_dtype,
-            trust_remote_code=True
-            if ("mpt" in weights_path or "falcon" in weights_path)
-            else False,
+            trust_remote_code=trust_remote_code,
             quantization_config=bnb_config,
             **quant_args,
         )
@@ -159,21 +169,6 @@
         f"CausalLM: {MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}\n"
     )
 
-    if tokenizer.pad_token_id is None:
-        if "<|padding|>" in tokenizer.get_vocab():
-            # StableLM specific fix
-            tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
-        elif tokenizer.unk_token is not None:
-            print(
-                "Model does not have a pad token, we will use the unk token as pad token."
-            )
-            tokenizer.pad_token_id = tokenizer.unk_token_id
-        else:
-            print(
-                "Model does not have a pad token. We will use the eos token as pad token."
-            )
-            tokenizer.pad_token_id = tokenizer.eos_token_id
-
     if lora_weights_name_or_path:
         from peft import PeftModel
```
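With the explicit `trust_remote_code` parameter, callers now opt in per load instead of relying on the old name-based heuristic for MPT/Falcon. A minimal sketch of the updated call, assuming (as the function body suggests) that the signature begins with `weights_path` and `quantization`:

```python
# Minimal usage sketch of the new signature; keyword names follow the
# diff above, the leading positional parameters are an assumption.
from model import load_model_for_inference

model, tokenizer = load_model_for_inference(
    weights_path="facebook/hf-seamless-m4t-medium",
    quantization=None,  # or 4 / 8 to enable BitsAndBytes quantization
    lora_weights_name_or_path=None,
    torch_dtype=None,
    force_auto_device_map=False,
    trust_remote_code=False,  # now explicit instead of inferred from the model name
)
```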
sample_text/en2es.seamless-m4t-large.json ADDED

```json
{
    "path": "sample_text/en2es.translation.seamless-m4t-large.txt",
    "sacrebleu": {
        "score": 36.315142112223896,
        "counts": [
            20334,
            12742,
            8758,
            6156
        ],
        "totals": [
            31021,
            30021,
            29021,
            28021
        ],
        "precisions": [
            65.54914412817124,
            42.44362279737517,
            30.178146859170944,
            21.969237357696013
        ],
        "bp": 0.9854077938820913,
        "sys_len": 31021,
        "ref_len": 31477
    },
    "rouge": {
        "rouge1": 0.6330701226501922,
        "rouge2": 0.4284215608900075,
        "rougeL": 0.5852948888167713,
        "rougeLsum": 0.5852893813466102
    },
    "bleu": {
        "bleu": 0.36315142112223897,
        "precisions": [
            0.6554914412817124,
            0.4244362279737517,
            0.30178146859170946,
            0.21969237357696014
        ],
        "brevity_penalty": 0.9854077938820913,
        "length_ratio": 0.9855132318835975,
        "translation_length": 31021,
        "reference_length": 31477
    },
    "meteor": {
        "meteor": 0.5988659867679048
    },
    "ter": {
        "score": 53.42233524051706,
        "num_edits": 15126,
        "ref_length": 28314.0
    },
    "bert_score": {
        "precision": 0.8355873214006424,
        "recall": 0.8343284497857094,
        "f1": 0.8346186644434929,
        "hashcode": "microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.35.2)_fast-tokenizer"
    }
}
```
sample_text/en2es.seamless-m4t-medium.json ADDED

```json
{
    "path": "sample_text/en2es.translation.seamless-m4t-medium.txt",
    "sacrebleu": {
        "score": 32.86110838375764,
        "counts": [
            19564,
            11721,
            7752,
            5264
        ],
        "totals": [
            30811,
            29811,
            28811,
            27812
        ],
        "precisions": [
            63.49680308980559,
            39.31770151957331,
            26.90638992051647,
            18.92708183517906
        ],
        "bp": 0.978616287348328,
        "sys_len": 30811,
        "ref_len": 31477
    },
    "rouge": {
        "rouge1": 0.609193205717968,
        "rouge2": 0.3944070815557623,
        "rougeL": 0.558841464797821,
        "rougeLsum": 0.5594046328281417
    },
    "bleu": {
        "bleu": 0.3286110838375765,
        "precisions": [
            0.6349680308980559,
            0.3931770151957331,
            0.2690638992051647,
            0.1892708183517906
        ],
        "brevity_penalty": 0.978616287348328,
        "length_ratio": 0.9788416939352543,
        "translation_length": 30811,
        "reference_length": 31477
    },
    "meteor": {
        "meteor": 0.5707261528520716
    },
    "ter": {
        "score": 55.88754679663771,
        "num_edits": 15824,
        "ref_length": 28314.0
    },
    "bert_score": {
        "precision": 0.8278114783763886,
        "recall": 0.824702616840601,
        "f1": 0.8259151731133461,
        "hashcode": "microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.35.2)_fast-tokenizer"
    }
}
```
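The two evaluation files share the same schema, so comparing the checkpoints programmatically is straightforward; on this en2es sample the large checkpoint scores about 3.5 BLEU higher than the medium one. A minimal sketch, using only the paths and keys present above:

```python
# Compare the sacreBLEU scores stored in the two evaluation files above.
import json

for size in ("large", "medium"):
    with open(f"sample_text/en2es.seamless-m4t-{size}.json", encoding="utf8") as f:
        metrics = json.load(f)
    # Prints 36.32 for large and 32.86 for medium.
    print(size, round(metrics["sacrebleu"]["score"], 2))
```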
sample_text/en2es.translation.seamless-m4t-large.txt ADDED

The diff for this file is too large to render. See raw diff.

sample_text/en2es.translation.seamless-m4t-medium.txt ADDED

The diff for this file is too large to render. See raw diff.

tests/__init__.py ADDED

File without changes (empty file).
tests/test_translation.py ADDED

```python
# Run with 'python -m unittest tests.test_translation'

import unittest
import tempfile
import os
from translate import main
import transformers


class Inputs(unittest.TestCase):
    def test_m2m100_inputs(self):
        # Create a temporary directory
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Create a temporary file

            input_path = os.path.join(tmpdirname, "source.txt")
            output_path = os.path.join(tmpdirname, "target.txt")

            with open(
                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
            ) as f:
                print("Hello, world, my name is Iker!", file=f)

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang="en",
                target_lang="es",
                starting_batch_size=32,
                model_name="facebook/m2m100_418M",
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision=None,
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

            main(
                sentences_path=None,
                sentences_dir=tmpdirname,
                files_extension="txt",
                output_path=os.path.join(tmpdirname, "target"),
                source_lang="en",
                target_lang="es",
                starting_batch_size=32,
                model_name="facebook/m2m100_418M",
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision=None,
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )


class Translations(unittest.TestCase):
    def test_m2m100(self):
        # Create a temporary directory
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Create a temporary file

            input_path = os.path.join(tmpdirname, "source.txt")
            output_path = os.path.join(tmpdirname, "target.txt")

            with open(
                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
            ) as f:
                print("Hello, world, my name is Iker!", file=f)

            model_name = "facebook/m2m100_418M"
            src_lang = "en"
            tgt_lang = "es"

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=src_lang,
                target_lang=tgt_lang,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision="bf16",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=src_lang,
                target_lang=tgt_lang,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision="4",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

    def test_nllb200(self):
        # Create a temporary directory
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Create a temporary file

            input_path = os.path.join(tmpdirname, "source.txt")
            output_path = os.path.join(tmpdirname, "target.txt")

            with open(
                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
            ) as f:
                print("Hello, world, my name is Iker!", file=f)

            model_name = "facebook/nllb-200-distilled-600M"
            src_lang = "eng_Latn"
            tgt_lang = "spa_Latn"

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=src_lang,
                target_lang=tgt_lang,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision="bf16",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=src_lang,
                target_lang=tgt_lang,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision="4",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

    def test_mbart(self):
        # Create a temporary directory
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Create a temporary file

            input_path = os.path.join(tmpdirname, "source.txt")
            output_path = os.path.join(tmpdirname, "target.txt")

            with open(
                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
            ) as f:
                print("Hello, world, my name is Iker!", file=f)

            model_name = "facebook/mbart-large-50"
            src_lang = "en_XX"
            tgt_lang = "es_XX"

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=src_lang,
                target_lang=tgt_lang,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision="bf16",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=src_lang,
                target_lang=tgt_lang,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision="4",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

    def test_opus(self):
        # Create a temporary directory
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Create a temporary file

            input_path = os.path.join(tmpdirname, "source.txt")
            output_path = os.path.join(tmpdirname, "target.txt")

            with open(
                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
            ) as f:
                print("Hello, world, my name is Iker!", file=f)

            model_name = "Helsinki-NLP/opus-mt-en-es"
            src_lang = None
            tgt_lang = None

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=src_lang,
                target_lang=tgt_lang,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=False,
                precision="bf16",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=src_lang,
                target_lang=tgt_lang,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=False,
                precision="4",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

    @unittest.skipIf(
        transformers.__version__ > "4.34.0",
        "Small100 tokenizer is not supported in transformers > 4.34.0. Please use transformers <= 4.34.0 if you want to use small100",
    )
    def test_small100(self):
        # Create a temporary directory
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Create a temporary file

            input_path = os.path.join(tmpdirname, "source.txt")
            output_path = os.path.join(tmpdirname, "target.txt")

            with open(
                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
            ) as f:
                print("Hello, world, my name is Iker!", file=f)

            model_name = "alirezamsh/small100"
            src_lang = None
            tgt_lang = "es"

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=src_lang,
                target_lang=tgt_lang,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision="bf16",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=src_lang,
                target_lang=tgt_lang,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision="4",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

    def test_seamless(self):
        # Create a temporary directory
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Create a temporary file

            input_path = os.path.join(tmpdirname, "source.txt")
            output_path = os.path.join(tmpdirname, "target.txt")

            with open(
                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
            ) as f:
                print("Hello, world, my name is Iker!", file=f)

            model_name = "facebook/hf-seamless-m4t-medium"
            src_lang = "eng"
            tgt_lang = "spa"

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=src_lang,
                target_lang=tgt_lang,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision="bf16",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=src_lang,
                target_lang=tgt_lang,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision="4",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )


class Prompting(unittest.TestCase):
    def test_llama(self):
        # Create a temporary directory
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Create a temporary file

            input_path = os.path.join(tmpdirname, "source.txt")
            output_path = os.path.join(tmpdirname, "target.txt")

            with open(
                os.path.join(tmpdirname, "source.txt"), "w", encoding="utf8"
            ) as f:
                print("Hello, world, my name is Iker!", file=f)

            model_name = "stas/tiny-random-llama-2"
            prompt = "Translate English to Spanish: %%SENTENCE%%"

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=None,
                target_lang=None,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision="bf16",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=True,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=prompt,
            )

            main(
                sentences_path=input_path,
                sentences_dir=None,
                files_extension="txt",
                output_path=output_path,
                source_lang=None,
                target_lang=None,
                starting_batch_size=32,
                model_name=model_name,
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision="4",
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=True,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=prompt,
            )
```
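As the file's header comment notes, the suite runs with the standard unittest runner from the repo root: `python -m unittest tests.test_translation`. Each test exercises both a `bf16` and a 4-bit quantized load, so a CUDA GPU with BitsAndBytes installed is needed for the quantized halves.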
translate.py CHANGED

```diff
@@ -1,6 +1,7 @@
 import os
 import math
 import argparse
+import glob
 
 import torch
 from torch.utils.data import DataLoader
@@ -18,6 +19,8 @@ from dataset import DatasetReader, count_lines
 
 from accelerate import Accelerator, DistributedType, find_executable_batch_size
 
+from typing import Optional
+
 
 def encode_string(text):
     return text.replace("\r", r"\r").replace("\n", r"\n").replace("\t", r"\t")
@@ -31,7 +34,12 @@ def get_dataloader(
     max_length: int,
     prompt: str,
 ) -> DataLoader:
-    dataset = DatasetReader(
+    dataset = DatasetReader(
+        filename=filename,
+        tokenizer=tokenizer,
+        max_length=max_length,
+        prompt=prompt,
+    )
     if accelerator.distributed_type == DistributedType.TPU:
         data_collator = DataCollatorForSeq2Seq(
             tokenizer,
@@ -59,16 +67,18 @@
 
 
 def main(
-    sentences_path: str,
+    sentences_path: Optional[str],
+    sentences_dir: Optional[str],
+    files_extension: str,
     output_path: str,
-    source_lang: str,
-    target_lang: str,
+    source_lang: Optional[str],
+    target_lang: Optional[str],
     starting_batch_size: int,
     model_name: str = "facebook/m2m100_1.2B",
     lora_weights_name_or_path: str = None,
    force_auto_device_map: bool = False,
     precision: str = None,
-    max_length: int =
+    max_length: int = 256,
     num_beams: int = 4,
     num_return_sequences: int = 1,
     do_sample: bool = False,
@@ -79,9 +89,8 @@
     keep_tokenization_spaces: bool = False,
     repetition_penalty: float = None,
     prompt: str = None,
+    trust_remote_code: bool = False,
 ):
-    os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
-
     accelerator = Accelerator()
 
     if force_auto_device_map and starting_batch_size >= 64:
@@ -92,6 +101,16 @@
             f"inference. You should consider using a smaller batch size, i.e '--starting_batch_size 8'"
         )
 
+    if sentences_path is None and sentences_dir is None:
+        raise ValueError(
+            "You must specify either --sentences_path or --sentences_dir. Use --help for more details."
+        )
+
+    if sentences_path is not None and sentences_dir is not None:
+        raise ValueError(
+            "You must specify either --sentences_path or --sentences_dir, not both. Use --help for more details."
+        )
+
     if precision is None:
         quantization = None
         dtype = None
@@ -118,11 +137,17 @@
         lora_weights_name_or_path=lora_weights_name_or_path,
         torch_dtype=dtype,
         force_auto_device_map=force_auto_device_map,
+        trust_remote_code=trust_remote_code,
     )
 
     is_translation_model = hasattr(tokenizer, "lang_code_to_id")
+    lang_code_to_idx = None
 
-    if
+    if (
+        is_translation_model
+        and (source_lang is None or target_lang is None)
+        and "small100" not in model_name
+    ):
         raise ValueError(
             f"The model you are using requires a source and target language. "
             f"Please specify them with --source-lang and --target-lang. "
@@ -169,8 +194,32 @@
         # We don't need to force the BOS token, so we set is_translation_model to False
         is_translation_model = False
 
+    if model.config.model_type == "seamless_m4t":
+        # Loading a seamless_m4t model, we need to set a few things to ensure compatibility
+
+        supported_langs = tokenizer.additional_special_tokens
+        supported_langs = [lang.replace("__", "") for lang in supported_langs]
+
+        if source_lang is None or target_lang is None:
+            raise ValueError(
+                f"The model you are using requires a source and target language. "
+                f"Please specify them with --source-lang and --target-lang. "
+                f"The supported languages are: {supported_langs}"
+            )
+
+        if source_lang not in supported_langs:
+            raise ValueError(
+                f"Language {source_lang} not found in tokenizer. Available languages: {supported_langs}"
+            )
+        if target_lang not in supported_langs:
+            raise ValueError(
+                f"Language {target_lang} not found in tokenizer. Available languages: {supported_langs}"
+            )
+
+        tokenizer.src_lang = source_lang
+
     gen_kwargs = {
-        "
+        "max_new_tokens": max_length,
         "num_beams": num_beams,
         "num_return_sequences": num_return_sequences,
         "do_sample": do_sample,
@@ -182,12 +231,17 @@
     if repetition_penalty is not None:
         gen_kwargs["repetition_penalty"] = repetition_penalty
 
+    if is_translation_model:
+        gen_kwargs["forced_bos_token_id"] = lang_code_to_idx
+
+    if model.config.model_type == "seamless_m4t":
+        gen_kwargs["tgt_lang"] = target_lang
 
     if accelerator.is_main_process:
         print(
             f"** Translation **\n"
             f"Input file: {sentences_path}\n"
+            f"Sentences dir: {sentences_dir}\n"
             f"Output file: {output_path}\n"
             f"Source language: {source_lang}\n"
             f"Target language: {target_lang}\n"
@@ -211,10 +265,12 @@
         print("\n")
 
     @find_executable_batch_size(starting_batch_size=starting_batch_size)
-    def inference(batch_size):
-        nonlocal model, tokenizer,
+    def inference(batch_size, sentences_path, output_path):
+        nonlocal model, tokenizer, max_length, gen_kwargs, precision, prompt, is_translation_model
+
+        print(f"Translating {sentences_path} with batch size {batch_size}")
 
+        total_lines: int = count_lines(sentences_path)
 
         data_loader = get_dataloader(
             accelerator=accelerator,
@@ -243,9 +299,6 @@
 
             generated_tokens = accelerator.unwrap_model(model).generate(
                 **batch,
-                forced_bos_token_id=lang_code_to_idx
-                if is_translation_model
-                else None,
                 **gen_kwargs,
             )
@@ -286,24 +339,60 @@
 
             pbar.update(len(tgt_text) // gen_kwargs["num_return_sequences"])
 
+        print(f"Translation done. Output written to {output_path}\n")
+
+    if sentences_path is not None:
+        os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
+        inference(sentences_path=sentences_path, output_path=output_path)
+
+    if sentences_dir is not None:
+        print(
+            f"Translating all files in {sentences_dir}, with extension {files_extension}"
+        )
+        os.makedirs(os.path.abspath(output_path), exist_ok=True)
+        for filename in glob.glob(
+            os.path.join(
+                sentences_dir, f"*.{files_extension}" if files_extension else "*"
+            )
+        ):
+            output_filename = os.path.join(output_path, os.path.basename(filename))
+            inference(sentences_path=filename, output_path=output_filename)
+
     print(f"Translation done.\n")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the translation experiments")
-    parser.add_argument(
+    input_group = parser.add_mutually_exclusive_group(required=True)
+    input_group.add_argument(
         "--sentences_path",
+        default=None,
         type=str,
-        required=True,
         help="Path to a txt file containing the sentences to translate. One sentence per line.",
     )
 
+    input_group.add_argument(
+        "--sentences_dir",
+        type=str,
+        default=None,
+        help="Path to a directory containing the sentences to translate. "
+        "Sentences must be in .txt files containing one sentence per line.",
+    )
+
+    parser.add_argument(
+        "--files_extension",
+        type=str,
+        default="txt",
+        help="If sentences_dir is specified, extension of the files to translate. Defaults to txt. "
+        "If set to an empty string, we will translate all files in the directory.",
+    )
+
     parser.add_argument(
         "--output_path",
         type=str,
         required=True,
-        help="Path to a txt file where the translated sentences will be written."
+        help="Path to a txt file where the translated sentences will be written. If the input is a directory, "
+        "the output will be a directory with the same structure.",
     )
@@ -355,7 +444,7 @@
     parser.add_argument(
         "--max_length",
         type=int,
-        default=
+        default=256,
        help="Maximum number of tokens in the source sentence and generated sentence. "
         "Increase this value to translate longer sentences, at the cost of increasing memory usage.",
     )
@@ -438,10 +527,18 @@
         "It must include the special token %%SENTENCE%% which will be replaced by the sentence to translate.",
     )
 
+    parser.add_argument(
+        "--trust_remote_code",
+        action="store_true",
+        help="If set we will trust remote code in HuggingFace models. This is required for some models.",
+    )
+
     args = parser.parse_args()
 
     main(
         sentences_path=args.sentences_path,
+        sentences_dir=args.sentences_dir,
+        files_extension=args.files_extension,
         output_path=args.output_path,
         source_lang=args.source_lang,
         target_lang=args.target_lang,
@@ -459,4 +556,5 @@
         keep_tokenization_spaces=args.keep_tokenization_spaces,
         repetition_penalty=args.repetition_penalty,
         prompt=args.prompt,
+        trust_remote_code=args.trust_remote_code,
     )
```
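The `seamless_m4t` branch above exists because, unlike M2M100-style models, SeamlessM4T selects the output language with a `tgt_lang` argument to `generate` rather than a forced BOS token. A minimal standalone sketch of the underlying transformers API (>= 4.35.0) that translate.py wraps, using the same language codes as the test suite:

```python
# Standalone sketch of the transformers API the seamless_m4t branch relies on.
from transformers import AutoTokenizer, SeamlessM4TForTextToText

tokenizer = AutoTokenizer.from_pretrained("facebook/hf-seamless-m4t-medium")
model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-medium")

tokenizer.src_lang = "eng"  # source language, as set in main() above
inputs = tokenizer("Hello, world, my name is Iker!", return_tensors="pt")

# The target language is passed to generate(), not encoded as a forced BOS token.
output_tokens = model.generate(**inputs, tgt_lang="spa", num_beams=2, max_new_tokens=64)
print(tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0])
```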