Spaces:
Running
Running
Merge branch 'main' of https://huggingface.co/spaces/SoybeanMilk/whisper-webui-translate
Browse files1.Thank you SoybeanMilk for assisting in the development and integration of the excellent ALMA translation model.
2.Add the missing packages for the GPTQ version of ALMA in the requirements.txt.
3.Include the 7B version of the ALMA model in addition to the 13B version in the web UI.
4.Write the ReadMe document for the Translate Model, containing a brief introduction and explanation of each translation model used in this project.
- app.py +24 -15
- config.json5 +6 -1
- docs/options.md +9 -9
- docs/translateModel.md +112 -0
- requirements-fasterWhisper.txt +6 -1
- requirements-whisper.txt +7 -1
- requirements.txt +6 -1
- src/config.py +6 -1
- src/translation/translationLangs.py +22 -2
- src/translation/translationModel.py +143 -51
- src/utils.py +1 -1
app.py
CHANGED
@@ -40,7 +40,7 @@ from src.whisper.whisperFactory import create_whisper_container
|
|
40 |
from src.translation.translationModel import TranslationModel
|
41 |
from src.translation.translationLangs import (TranslationLang,
|
42 |
_TO_LANG_CODE_WHISPER, get_lang_whisper_names, get_lang_from_whisper_name, get_lang_from_whisper_code,
|
43 |
-
get_lang_nllb_names, get_lang_from_nllb_name, get_lang_m2m100_names, get_lang_from_m2m100_name)
|
44 |
import shutil
|
45 |
import zhconv
|
46 |
import tqdm
|
@@ -773,8 +773,15 @@ class WhisperTranscriber:
|
|
773 |
self.diarization = None
|
774 |
|
775 |
def create_ui(app_config: ApplicationConfig):
|
|
|
776 |
optionsMd: str = None
|
777 |
readmeMd: str = None
|
|
|
|
|
|
|
|
|
|
|
|
|
778 |
try:
|
779 |
optionsPath = pathlib.Path("docs/options.md")
|
780 |
with open(optionsPath, "r", encoding="utf-8") as optionsFile:
|
@@ -819,16 +826,6 @@ def create_ui(app_config: ApplicationConfig):
|
|
819 |
uiDescription += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
|
820 |
|
821 |
uiArticle = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
|
822 |
-
uiArticle += "\n\nWhisper's Task 'translate' only implements the functionality of translating other languages into English. "
|
823 |
-
uiArticle += "OpenAI does not guarantee translations between arbitrary languages. In such cases, you can choose to use the Translation Model to implement the translation task. "
|
824 |
-
uiArticle += "However, it's important to note that the Translation Model runs slowly(in CPU), and the completion time may be twice as long as usual. "
|
825 |
-
uiArticle += "\n\nThe larger the parameters of the Translation model, the better its performance is expected to be. "
|
826 |
-
uiArticle += "However, it also requires higher computational resources, making it slower to operate. "
|
827 |
-
uiArticle += "On the other hand, the version converted from ct2 ([CTranslate2](https://opennmt.net/CTranslate2/guides/transformers.html)) requires lower resources and operates at a faster speed."
|
828 |
-
uiArticle += "\n\nCurrently, enabling `Highlight Words` timestamps cannot be used in conjunction with Translation Model translation "
|
829 |
-
uiArticle += "because Highlight Words will split the source text, and after translation, it becomes a non-word-level string. "
|
830 |
-
uiArticle += "\n\nThe 'mt5-zh-ja-en-trimmed' model is finetuned from Google's 'mt5-base' model. "
|
831 |
-
uiArticle += "This model has a relatively good translation speed, but it only supports three languages: Chinese, Japanese, and English. "
|
832 |
|
833 |
whisper_models = app_config.get_model_names("whisper")
|
834 |
nllb_models = app_config.get_model_names("nllb")
|
@@ -854,7 +851,7 @@ def create_ui(app_config: ApplicationConfig):
|
|
854 |
}
|
855 |
common_ALMA_inputs = lambda : {
|
856 |
gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
|
857 |
-
gr.Dropdown(label="ALMA - Language", choices=
|
858 |
}
|
859 |
|
860 |
common_translation_inputs = lambda : {
|
@@ -944,8 +941,10 @@ def create_ui(app_config: ApplicationConfig):
|
|
944 |
simpleInputDict.update(common_translation_inputs())
|
945 |
with gr.Column():
|
946 |
simpleOutput = common_output()
|
947 |
-
|
948 |
-
|
|
|
|
|
949 |
if optionsMd is not None:
|
950 |
with gr.Accordion("docs/options.md", open=False):
|
951 |
gr.Markdown(optionsMd)
|
@@ -1056,7 +1055,7 @@ def create_ui(app_config: ApplicationConfig):
|
|
1056 |
print("Queue mode enabled (concurrency count: " + str(app_config.queue_concurrency_count) + ")")
|
1057 |
else:
|
1058 |
print("Queue mode disabled - progress bars will not be shown.")
|
1059 |
-
|
1060 |
demo.launch(inbrowser=app_config.autolaunch, share=app_config.share, server_name=app_config.server_name, server_port=app_config.server_port)
|
1061 |
|
1062 |
# Clean up
|
@@ -1138,6 +1137,16 @@ if __name__ == '__main__':
|
|
1138 |
# updated_config.autolaunch = True
|
1139 |
# updated_config.auto_parallel = False
|
1140 |
# updated_config.save_downloaded_files = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1141 |
|
1142 |
if (threads := args.pop("threads")) > 0:
|
1143 |
torch.set_num_threads(threads)
|
|
|
40 |
from src.translation.translationModel import TranslationModel
|
41 |
from src.translation.translationLangs import (TranslationLang,
|
42 |
_TO_LANG_CODE_WHISPER, get_lang_whisper_names, get_lang_from_whisper_name, get_lang_from_whisper_code,
|
43 |
+
get_lang_nllb_names, get_lang_from_nllb_name, get_lang_m2m100_names, get_lang_from_m2m100_name, sort_lang_by_whisper_codes)
|
44 |
import shutil
|
45 |
import zhconv
|
46 |
import tqdm
|
|
|
773 |
self.diarization = None
|
774 |
|
775 |
def create_ui(app_config: ApplicationConfig):
|
776 |
+
translateModelMd: str = None
|
777 |
optionsMd: str = None
|
778 |
readmeMd: str = None
|
779 |
+
try:
|
780 |
+
translateModelPath = pathlib.Path("docs/translateModel.md")
|
781 |
+
with open(translateModelPath, "r", encoding="utf-8") as translateModelFile:
|
782 |
+
translateModelMd = translateModelFile.read()
|
783 |
+
except Exception as e:
|
784 |
+
print("Error occurred during read translateModel.md file: ", str(e))
|
785 |
try:
|
786 |
optionsPath = pathlib.Path("docs/options.md")
|
787 |
with open(optionsPath, "r", encoding="utf-8") as optionsFile:
|
|
|
826 |
uiDescription += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
|
827 |
|
828 |
uiArticle = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
829 |
|
830 |
whisper_models = app_config.get_model_names("whisper")
|
831 |
nllb_models = app_config.get_model_names("nllb")
|
|
|
851 |
}
|
852 |
common_ALMA_inputs = lambda : {
|
853 |
gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
|
854 |
+
gr.Dropdown(label="ALMA - Language", choices=sort_lang_by_whisper_codes(["en", "de", "cs", "is", "ru", "zh", "ja"]), elem_id="ALMALangName"),
|
855 |
}
|
856 |
|
857 |
common_translation_inputs = lambda : {
|
|
|
941 |
simpleInputDict.update(common_translation_inputs())
|
942 |
with gr.Column():
|
943 |
simpleOutput = common_output()
|
944 |
+
gr.Markdown(uiArticle)
|
945 |
+
if translateModelMd is not None:
|
946 |
+
with gr.Accordion("docs/translateModel.md", open=False):
|
947 |
+
gr.Markdown(translateModelMd)
|
948 |
if optionsMd is not None:
|
949 |
with gr.Accordion("docs/options.md", open=False):
|
950 |
gr.Markdown(optionsMd)
|
|
|
1055 |
print("Queue mode enabled (concurrency count: " + str(app_config.queue_concurrency_count) + ")")
|
1056 |
else:
|
1057 |
print("Queue mode disabled - progress bars will not be shown.")
|
1058 |
+
|
1059 |
demo.launch(inbrowser=app_config.autolaunch, share=app_config.share, server_name=app_config.server_name, server_port=app_config.server_port)
|
1060 |
|
1061 |
# Clean up
|
|
|
1137 |
# updated_config.autolaunch = True
|
1138 |
# updated_config.auto_parallel = False
|
1139 |
# updated_config.save_downloaded_files = True
|
1140 |
+
|
1141 |
+
try:
|
1142 |
+
if torch.cuda.is_available():
|
1143 |
+
deviceId = torch.cuda.current_device()
|
1144 |
+
totalVram = torch.cuda.get_device_properties(deviceId).total_memory
|
1145 |
+
if totalVram/(1024*1024*1024) <= 4: #VRAM <= 4 GB
|
1146 |
+
updated_config.vad_process_timeout = 0
|
1147 |
+
except Exception as e:
|
1148 |
+
print(traceback.format_exc())
|
1149 |
+
print("Error detect vram: " + str(e))
|
1150 |
|
1151 |
if (threads := args.pop("threads")) > 0:
|
1152 |
torch.set_num_threads(threads)
|
config.json5
CHANGED
@@ -193,10 +193,15 @@
|
|
193 |
}
|
194 |
],
|
195 |
"ALMA": [
|
|
|
|
|
|
|
|
|
|
|
196 |
{
|
197 |
"name": "ALMA-13B-GPTQ/TheBloke",
|
198 |
"url": "TheBloke/ALMA-13B-GPTQ",
|
199 |
-
"type": "huggingface"
|
200 |
},
|
201 |
]
|
202 |
},
|
|
|
193 |
}
|
194 |
],
|
195 |
"ALMA": [
|
196 |
+
{
|
197 |
+
"name": "ALMA-7B-GPTQ/TheBloke",
|
198 |
+
"url": "TheBloke/ALMA-7B-GPTQ",
|
199 |
+
"type": "huggingface"
|
200 |
+
},
|
201 |
{
|
202 |
"name": "ALMA-13B-GPTQ/TheBloke",
|
203 |
"url": "TheBloke/ALMA-13B-GPTQ",
|
204 |
+
"type": "huggingface"
|
205 |
},
|
206 |
]
|
207 |
},
|
docs/options.md
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# Standard Options
|
2 |
To transcribe or translate an audio file, you can either copy an URL from a website (all [websites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)
|
3 |
supported by YT-DLP will work, including YouTube). Otherwise, upload an audio file (choose "All Files (*.*)"
|
4 |
in the file selector to select any file type, including video files) or use the microphone.
|
@@ -154,29 +154,29 @@ The minimum number of speakers for Pyannote to detect.
|
|
154 |
The maximum number of speakers for Pyannote to detect.
|
155 |
|
156 |
## Repetition Penalty
|
157 |
-
- ctranslate2: repetition_penalty
|
158 |
This parameter only takes effect in [faster-whisper (ctranslate2)](https://github.com/SYSTRAN/faster-whisper/issues/478).
|
159 |
Penalty applied to the score of previously generated tokens (set > 1 to penalize).
|
160 |
|
161 |
## No Repeat Ngram Size
|
162 |
-
- ctranslate2: no_repeat_ngram_size
|
163 |
This parameter only takes effect in [faster-whisper (ctranslate2)](https://github.com/SYSTRAN/faster-whisper/issues/478).
|
164 |
Prevent repetitions of ngrams with this size (set 0 to disable).
|
165 |
|
166 |
## Translation - Batch Size
|
167 |
-
- transformers: batch_size
|
168 |
When the pipeline will use DataLoader (when passing a dataset, on GPU for a Pytorch model), the size of the batch to use, for inference this is not always beneficial.
|
169 |
-
- ctranslate2: max_batch_size
|
170 |
The maximum batch size.
|
171 |
|
172 |
## Translation - No Repeat Ngram Size
|
173 |
-
- transformers: no_repeat_ngram_size
|
174 |
Value that will be used by default in the generate method of the model for no_repeat_ngram_size. If set to int > 0, all ngrams of that size can only occur once.
|
175 |
-
- ctranslate2: no_repeat_ngram_size
|
176 |
Prevent repetitions of ngrams with this size (set 0 to disable).
|
177 |
|
178 |
## Translation - Num Beams
|
179 |
-
- transformers: num_beams
|
180 |
Number of beams for beam search that will be used by default in the generate method of the model. 1 means no beam search.
|
181 |
-
- ctranslate2: beam_size
|
182 |
Beam size (1 for greedy search).
|
|
|
1 |
+
# Standard Options
|
2 |
To transcribe or translate an audio file, you can either copy an URL from a website (all [websites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)
|
3 |
supported by YT-DLP will work, including YouTube). Otherwise, upload an audio file (choose "All Files (*.*)"
|
4 |
in the file selector to select any file type, including video files) or use the microphone.
|
|
|
154 |
The maximum number of speakers for Pyannote to detect.
|
155 |
|
156 |
## Repetition Penalty
|
157 |
+
- ctranslate2: repetition_penalty
|
158 |
This parameter only takes effect in [faster-whisper (ctranslate2)](https://github.com/SYSTRAN/faster-whisper/issues/478).
|
159 |
Penalty applied to the score of previously generated tokens (set > 1 to penalize).
|
160 |
|
161 |
## No Repeat Ngram Size
|
162 |
+
- ctranslate2: no_repeat_ngram_size
|
163 |
This parameter only takes effect in [faster-whisper (ctranslate2)](https://github.com/SYSTRAN/faster-whisper/issues/478).
|
164 |
Prevent repetitions of ngrams with this size (set 0 to disable).
|
165 |
|
166 |
## Translation - Batch Size
|
167 |
+
- transformers: batch_size
|
168 |
When the pipeline will use DataLoader (when passing a dataset, on GPU for a Pytorch model), the size of the batch to use, for inference this is not always beneficial.
|
169 |
+
- ctranslate2: max_batch_size
|
170 |
The maximum batch size.
|
171 |
|
172 |
## Translation - No Repeat Ngram Size
|
173 |
+
- transformers: no_repeat_ngram_size
|
174 |
Value that will be used by default in the generate method of the model for no_repeat_ngram_size. If set to int > 0, all ngrams of that size can only occur once.
|
175 |
+
- ctranslate2: no_repeat_ngram_size
|
176 |
Prevent repetitions of ngrams with this size (set 0 to disable).
|
177 |
|
178 |
## Translation - Num Beams
|
179 |
+
- transformers: num_beams
|
180 |
Number of beams for beam search that will be used by default in the generate method of the model. 1 means no beam search.
|
181 |
+
- ctranslate2: beam_size
|
182 |
Beam size (1 for greedy search).
|
docs/translateModel.md
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Describe
|
3 |
+
|
4 |
+
The `translate` task in `Whisper` only supports translating other languages `into English`. `OpenAI` does not guarantee translations between arbitrary languages. In such cases, you can opt to use the Translation Model for translation tasks. However, it's important to note that the `Translation Model runs very slowly on CPU`, and the completion time may be twice as long as usual. It is recommended to run the Translation Model on devices with `GPUs` for better performance.
|
5 |
+
|
6 |
+
The larger the parameters of the Translation model, the better its translation capability is expected. However, this also requires higher computational resources and slower running speed.
|
7 |
+
|
8 |
+
Currently, when the `Highlight Words timestamps` option is enabled in the Whisper `Word Timestamps options`, it cannot be used simultaneously with the Translation Model. This is because Highlight Words splits the source text, and after translation, it becomes a non-word-level string.
|
9 |
+
|
10 |
+
|
11 |
+
# Translation Model
|
12 |
+
|
13 |
+
The required VRAM is provided for reference and may not apply to everyone. If the model's VRAM requirement exceeds the available capacity of the system, the model will operate on the CPU, resulting in significantly longer execution times.
|
14 |
+
|
15 |
+
[CTranslate2](https://opennmt.net/CTranslate2/guides/transformers.html) is a C++ and Python library for efficient inference with Transformer models. Models converted from CTranslate2 can run with lower resources and faster speed. Encoder-decoder models currently supported: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper.
|
16 |
+
|
17 |
+
## M2M100
|
18 |
+
|
19 |
+
2M100 is a multilingual translation model introduced by Facebook AI in October 2020. It supports arbitrary translation among 101 languages. The paper is titled "`Beyond English-Centric Multilingual Machine Translation`" ([arXiv:2010.11125](https://arxiv.org/abs/2010.11125)).
|
20 |
+
|
21 |
+
| Name | Parameters | Size | type/quantize | Required VRAM |
|
22 |
+
|------|------------|------|---------------|---------------|
|
23 |
+
| [facebook/m2m100_418M](https://huggingface.co/facebook/m2m100_418M) | 480M | 1.94 GB | float32 | ≈2 GB |
|
24 |
+
| [facebook/m2m100_1.2B](https://huggingface.co/facebook/m2m100_1.2B) | 1.2B | 4.96 GB | float32 | ≈5 GB |
|
25 |
+
| [facebook/m2m100-12B-last-ckpt](https://huggingface.co/facebook/m2m100-12B-last-ckpt) | 12B | 47.2 GB | float32 | N/A |
|
26 |
+
|
27 |
+
## M2M100-CTranslate2
|
28 |
+
|
29 |
+
| Name | Parameters | Size | type/quantize | Required VRAM |
|
30 |
+
|------|------------|------|---------------|---------------|
|
31 |
+
| [michaelfeil/ct2fast-m2m100_418M](https://huggingface.co/michaelfeil/ct2fast-m2m100_418M) | 480M | 970 MB | float16 | ≈0.6 GB |
|
32 |
+
| [michaelfeil/ct2fast-m2m100_1.2B](https://huggingface.co/michaelfeil/ct2fast-m2m100_1.2B) | 1.2B | 2.48 GB | float16 | ≈1.3 GB |
|
33 |
+
| [michaelfeil/ct2fast-m2m100-12B-last-ckpt](https://huggingface.co/michaelfeil/ct2fast-m2m100-12B-last-ckpt) | 12B | 23.6 GB | float16 | N/A |
|
34 |
+
|
35 |
+
## NLLB-200
|
36 |
+
|
37 |
+
NLLB-200 is a multilingual translation model introduced by Meta AI in July 2022. It supports arbitrary translation among 202 languages. The paper is titled "`No Language Left Behind: Scaling Human-Centered Machine Translation`" ([arXiv:2207.04672](https://arxiv.org/abs/2207.04672)).
|
38 |
+
|
39 |
+
| Name | Parameters | Size | type/quantize | Required VRAM |
|
40 |
+
|------|------------|------|---------------|---------------|
|
41 |
+
| [facebook/nllb-200-distilled-600M](https://huggingface.co/facebook/nllb-200-distilled-600M) | 600M | 2.46 GB | float32 | ≈2.5 GB |
|
42 |
+
| [facebook/nllb-200-distilled-1.3B](https://huggingface.co/facebook/nllb-200-distilled-1.3B) | 1.3B | 5.48 GB | float32 | ≈5.9 GB |
|
43 |
+
| [facebook/nllb-200-1.3B](https://huggingface.co/facebook/nllb-200-1.3B) | 1.3B | 5.48 GB | float32 | 5.8 GB |
|
44 |
+
| [facebook/nllb-200-3.3B](https://huggingface.co/facebook/nllb-200-3.3B) | 3.3B | 17.58 GB | float32 | 13.4 GB |
|
45 |
+
|
46 |
+
## NLLB-200-CTranslate2
|
47 |
+
|
48 |
+
| Name | Parameters | Size | type/quantize | Required VRAM |
|
49 |
+
|------|------------|------|---------------|---------------|
|
50 |
+
| [michaelfeil/ct2fast-nllb-200-distilled-1.3B](https://huggingface.co/michaelfeil/ct2fast-nllb-200-distilled-1.3B) | 1.3B | 1.38 GB | int8_float16 | ≈1.3 GB |
|
51 |
+
| [michaelfeil/ct2fast-nllb-200-3.3B](https://huggingface.co/michaelfeil/ct2fast-nllb-200-3.3B) | 3.3B | 3.36 GB | int8_float16 | ≈3.2 GB |
|
52 |
+
| [JustFrederik/nllb-200-1.3B-ct2-int8](https://huggingface.co/JustFrederik/nllb-200-1.3B-ct2-int8) | 1.3B | 1.38 GB | int8 | ≈1.3 GB |
|
53 |
+
| [JustFrederik/nllb-200-1.3B-ct2-float16](https://huggingface.co/JustFrederik/nllb-200-1.3B-ct2-float16) | 1.3B | 2.74 GB | float16 | ≈1.3 GB |
|
54 |
+
| [JustFrederik/nllb-200-distilled-600M-ct2](https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2) | 600M | 2.46 GB | float32 | ≈0.6 GB |
|
55 |
+
| [JustFrederik/nllb-200-distilled-600M-ct2-float16](https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-float16) | 600M | 1.23 GB | float16 | ≈0.6 GB |
|
56 |
+
| [JustFrederik/nllb-200-distilled-600M-ct2-int8](https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8) | 600M | 623 MB | int8 | ≈0.6 GB |
|
57 |
+
| [JustFrederik/nllb-200-distilled-1.3B-ct2-float16](https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-float16) | 1.3B | 2.74 GB | float16 | ≈1.3 GB |
|
58 |
+
| [JustFrederik/nllb-200-distilled-1.3B-ct2-int8](https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8) | 1.3B | 1.38 GB | int8 | ≈1.3 GB |
|
59 |
+
| [JustFrederik/nllb-200-distilled-1.3B-ct2](https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2) | 1.3B | 5.49 GB | float32 | ≈1.3 GB |
|
60 |
+
| [JustFrederik/nllb-200-1.3B-ct2](https://huggingface.co/JustFrederik/nllb-200-1.3B-ct2) | 1.3B | 5.49 GB | float32 | ≈1.3 GB |
|
61 |
+
| [JustFrederik/nllb-200-3.3B-ct2-float16](https://huggingface.co/JustFrederik/nllb-200-3.3B-ct2-float16) | 3.3B | 6.69 GB | float16 | ≈3.2 GB |
|
62 |
+
|
63 |
+
## MT5
|
64 |
+
|
65 |
+
mT5 is a multilingual pre-trained Text-to-Text Transformer introduced by Google Research in October 2020. It is a multilingual variant of the T5 model, pre-trained on datasets in 101 languages. Further fine-tuning is required to transform it into a translation model. The paper is titled "`mT5: A Massively Multilingual Pre-trained Text-to-Text Transformer`" ([arXiv:2010.11934](https://arxiv.org/abs/2010.11934)).
|
66 |
+
The 'mt5-zh-ja-en-trimmed' model is finetuned from Google's 'mt5-base' model. This model has a relatively good translation speed, but it only supports three languages: Chinese, Japanese, and English.
|
67 |
+
|
68 |
+
| Name | Parameters | Size | type/quantize | Required VRAM |
|
69 |
+
|------|------------|------|---------------|---------------|
|
70 |
+
| [mt5-base](https://huggingface.co/google/mt5-base) | N/A | 2.33 GB | float32 | N/A |
|
71 |
+
| [K024/mt5-zh-ja-en-trimmed](https://huggingface.co/K024/mt5-zh-ja-en-trimmed) | N/A | 1.32 GB | float32 | ≈1.4 GB |
|
72 |
+
| [engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1](https://huggingface.co/engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1) | N/A | 1.32 GB | float32 | ≈1.4 GB |
|
73 |
+
|
74 |
+
## ALMA
|
75 |
+
|
76 |
+
ALMA is a many-to-many LLM-based translation model introduced by Haoran Xu and colleagues in September 2023. It is based on the fine-tuning of a large language model (LLaMA-2). The approach used for this model is referred to as Advanced Language Model-based trAnslator (ALMA). The paper is titled "`A Paradigm Shift in Machine Translation: Boosting Translation Performance of Large Language Models`" ([arXiv:2309.11674](https://arxiv.org/abs/2309.11674)).
|
77 |
+
The official support for ALMA currently includes 10 language directions: English↔German, English↔Czech, English↔Icelandic, English↔Chinese, and English↔Russian. However, the author hints that there might be surprises in other directions, so there are currently no restrictions on the languages that ALMA can be chosen for in the web UI.
|
78 |
+
|
79 |
+
| Name | Parameters | Size | type/quantize | Required VRAM |
|
80 |
+
|------|------------|------|---------------|---------------|
|
81 |
+
| [haoranxu/ALMA-7B](https://huggingface.co/haoranxu/ALMA-7B) | 7B | 26.95 GB | float32 | N/A |
|
82 |
+
| [haoranxu/ALMA-13B](https://huggingface.co/haoranxu/ALMA-13B) | 13B | 52.07 GB | float32 | N/A |
|
83 |
+
|
84 |
+
## ALMA-GPTQ
|
85 |
+
|
86 |
+
GPTQ is a technique used to quantize the parameters of large language models into integer formats such as int8 or int4. Although the quantization process may lead to a loss in model performance, it significantly reduces both file size and the required VRAM.
|
87 |
+
|
88 |
+
| Name | Parameters | Size | type/quantize | Required VRAM |
|
89 |
+
|------|------------|------|---------------|---------------|
|
90 |
+
| [TheBloke/ALMA-7B-GPTQ](https://huggingface.co/TheBloke/ALMA-7B-GPTQ) | 7B | 3.9 GB | 4 Bits | ≈4.3 GB |
|
91 |
+
| [TheBloke/ALMA-13B-GPTQ](https://huggingface.co/TheBloke/ALMA-13B-GPTQ) | 13B | 7.26 GB | 4 Bits | ≈8.1 |
|
92 |
+
|
93 |
+
|
94 |
+
# Options
|
95 |
+
|
96 |
+
## Translation - Batch Size
|
97 |
+
- transformers: batch_size
|
98 |
+
When the pipeline will use DataLoader (when passing a dataset, on GPU for a Pytorch model), the size of the batch to use, for inference this is not always beneficial.
|
99 |
+
- ctranslate2: max_batch_size
|
100 |
+
The maximum batch size.
|
101 |
+
|
102 |
+
## Translation - No Repeat Ngram Size
|
103 |
+
- transformers: no_repeat_ngram_size
|
104 |
+
Value that will be used by default in the generate method of the model for no_repeat_ngram_size. If set to int > 0, all ngrams of that size can only occur once.
|
105 |
+
- ctranslate2: no_repeat_ngram_size
|
106 |
+
Prevent repetitions of ngrams with this size (set 0 to disable).
|
107 |
+
|
108 |
+
## Translation - Num Beams
|
109 |
+
- transformers: num_beams
|
110 |
+
Number of beams for beam search that will be used by default in the generate method of the model. 1 means no beam search.
|
111 |
+
- ctranslate2: beam_size
|
112 |
+
Beam size (1 for greedy search).
|
requirements-fasterWhisper.txt
CHANGED
@@ -15,4 +15,9 @@ sentencepiece
|
|
15 |
intervaltree
|
16 |
srt
|
17 |
torch
|
18 |
-
https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
|
|
|
|
|
|
|
|
|
|
|
|
15 |
intervaltree
|
16 |
srt
|
17 |
torch
|
18 |
+
https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
|
19 |
+
|
20 |
+
# Needed by ALMA(GPTQ)
|
21 |
+
accelerate
|
22 |
+
auto-gptq
|
23 |
+
optimum
|
requirements-whisper.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
git+https://github.com/huggingface/transformers
|
|
|
2 |
git+https://github.com/openai/whisper.git
|
3 |
ffmpeg-python==0.2.0
|
4 |
gradio==3.50.2
|
@@ -13,4 +14,9 @@ sentencepiece
|
|
13 |
intervaltree
|
14 |
srt
|
15 |
torch
|
16 |
-
https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
|
|
|
|
|
|
|
|
|
|
|
|
1 |
git+https://github.com/huggingface/transformers
|
2 |
+
ctranslate2>=3.21.0
|
3 |
git+https://github.com/openai/whisper.git
|
4 |
ffmpeg-python==0.2.0
|
5 |
gradio==3.50.2
|
|
|
14 |
intervaltree
|
15 |
srt
|
16 |
torch
|
17 |
+
https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
|
18 |
+
|
19 |
+
# Needed by ALMA(GPTQ)
|
20 |
+
accelerate
|
21 |
+
auto-gptq
|
22 |
+
optimum
|
requirements.txt
CHANGED
@@ -15,4 +15,9 @@ sentencepiece
|
|
15 |
intervaltree
|
16 |
srt
|
17 |
torch
|
18 |
-
https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
|
|
|
|
|
|
|
|
|
|
|
|
15 |
intervaltree
|
16 |
srt
|
17 |
torch
|
18 |
+
https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
|
19 |
+
|
20 |
+
# Needed by ALMA(GPTQ)
|
21 |
+
accelerate
|
22 |
+
auto-gptq
|
23 |
+
optimum
|
src/config.py
CHANGED
@@ -5,7 +5,7 @@ from typing import List, Dict, Literal
|
|
5 |
|
6 |
|
7 |
class ModelConfig:
|
8 |
-
def __init__(self, name: str, url: str, path: str = None, type: str = "whisper", tokenizer_url: str = None):
|
9 |
"""
|
10 |
Initialize a model configuration.
|
11 |
|
@@ -13,12 +13,17 @@ class ModelConfig:
|
|
13 |
url: URL to download the model from
|
14 |
path: Path to the model file. If not set, the model will be downloaded from the URL.
|
15 |
type: Type of model. Can be whisper or huggingface.
|
|
|
|
|
|
|
|
|
16 |
"""
|
17 |
self.name = name
|
18 |
self.url = url
|
19 |
self.path = path
|
20 |
self.type = type
|
21 |
self.tokenizer_url = tokenizer_url
|
|
|
22 |
|
23 |
VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
|
24 |
|
|
|
5 |
|
6 |
|
7 |
class ModelConfig:
|
8 |
+
def __init__(self, name: str, url: str, path: str = None, type: str = "whisper", tokenizer_url: str = None, revision: str = None):
|
9 |
"""
|
10 |
Initialize a model configuration.
|
11 |
|
|
|
13 |
url: URL to download the model from
|
14 |
path: Path to the model file. If not set, the model will be downloaded from the URL.
|
15 |
type: Type of model. Can be whisper or huggingface.
|
16 |
+
revision: [by transformers] The specific model version to use.
|
17 |
+
It can be a branch name, a tag name, or a commit id,
|
18 |
+
since we use a git-based system for storing models and other artifacts on huggingface.co,
|
19 |
+
so revision can be any identifier allowed by git.
|
20 |
"""
|
21 |
self.name = name
|
22 |
self.url = url
|
23 |
self.path = path
|
24 |
self.type = type
|
25 |
self.tokenizer_url = tokenizer_url
|
26 |
+
self.revision = revision
|
27 |
|
28 |
VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
|
29 |
|
src/translation/translationLangs.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
-
|
|
|
|
|
2 |
def __init__(self, code: str, *names: str):
|
3 |
self.code = code
|
4 |
self.names = names
|
@@ -292,12 +294,30 @@ def get_lang_whisper_names():
|
|
292 |
"""Return a list of whisper language names."""
|
293 |
return list(_TO_LANG_NAME_WHISPER.keys())
|
294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
if __name__ == "__main__":
|
296 |
# Test lookup
|
297 |
print("name:Chinese (Traditional)", get_lang_from_nllb_name("Chinese (Traditional)"))
|
298 |
print("name:moldavian", get_lang_from_m2m100_name("moldavian"))
|
299 |
print("code:ja", get_lang_from_whisper_code("ja"))
|
300 |
print("name:English", get_lang_from_nllb_name('English'))
|
|
|
301 |
|
302 |
print(get_lang_m2m100_names(["en", "ja", "zh"]))
|
303 |
-
print(
|
|
|
|
1 |
+
from functools import cmp_to_key
|
2 |
+
|
3 |
+
class Lang():
|
4 |
def __init__(self, code: str, *names: str):
|
5 |
self.code = code
|
6 |
self.names = names
|
|
|
294 |
"""Return a list of whisper language names."""
|
295 |
return list(_TO_LANG_NAME_WHISPER.keys())
|
296 |
|
297 |
+
def sort_lang_by_whisper_codes(specified_order: list = []):
|
298 |
+
def sort_by_whisper_code(lang: TranslationLang, specified_order: list):
|
299 |
+
return (specified_order.index(lang.whisper.code), lang.whisper.names[0]) if lang.whisper.code in specified_order else (len(specified_order), lang.whisper.names[0])
|
300 |
+
|
301 |
+
def cmp_by_whisper_code(lang1: TranslationLang, lang2: TranslationLang):
|
302 |
+
val1 = sort_by_whisper_code(lang1, specified_order)
|
303 |
+
val2 = sort_by_whisper_code(lang2, specified_order)
|
304 |
+
if val1 > val2:
|
305 |
+
return 1
|
306 |
+
elif val1 == val2:
|
307 |
+
return 0
|
308 |
+
else: return -1
|
309 |
+
|
310 |
+
sorted_translations = sorted(_TO_LANG_NAME_WHISPER.values(), key=cmp_to_key(cmp_by_whisper_code))
|
311 |
+
return list({name.lower(): None for language in sorted_translations for name in language.whisper.names}.keys())
|
312 |
+
|
313 |
if __name__ == "__main__":
|
314 |
# Test lookup
|
315 |
print("name:Chinese (Traditional)", get_lang_from_nllb_name("Chinese (Traditional)"))
|
316 |
print("name:moldavian", get_lang_from_m2m100_name("moldavian"))
|
317 |
print("code:ja", get_lang_from_whisper_code("ja"))
|
318 |
print("name:English", get_lang_from_nllb_name('English'))
|
319 |
+
print("\n\n")
|
320 |
|
321 |
print(get_lang_m2m100_names(["en", "ja", "zh"]))
|
322 |
+
print("\n\n")
|
323 |
+
print(sort_lang_by_whisper_codes(["en", "de", "cs", "is", "ru", "zh", "ja"]))
|
src/translation/translationModel.py
CHANGED
@@ -3,11 +3,9 @@ import warnings
|
|
3 |
import huggingface_hub
|
4 |
import requests
|
5 |
import torch
|
6 |
-
|
7 |
import ctranslate2
|
8 |
import transformers
|
9 |
-
|
10 |
-
import re
|
11 |
|
12 |
from typing import Optional
|
13 |
from src.config import ModelConfig
|
@@ -85,84 +83,175 @@ class TranslationModel:
|
|
85 |
self.load_model()
|
86 |
|
87 |
def load_model(self):
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
def release_vram(self):
|
116 |
try:
|
117 |
if torch.cuda.is_available():
|
118 |
if "ct2" not in self.modelPath:
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
del self.transModel
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
print("release vram end.")
|
124 |
except Exception as e:
|
125 |
print("Error release vram: " + str(e))
|
126 |
|
127 |
|
128 |
def translation(self, text: str, max_length: int = 400):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
output = None
|
130 |
result = None
|
131 |
try:
|
132 |
if "ct2" in self.modelPath:
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
elif "mt5" in self.modelPath:
|
138 |
output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
|
139 |
result = output[0]['generated_text']
|
140 |
elif "ALMA" in self.modelPath:
|
141 |
-
output = self.transTranslator(self.ALMAPrefix + text + self.translationLang.whisper.
|
142 |
result = output[0]['generated_text']
|
143 |
-
result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result) # Remove the prompt from the result
|
144 |
-
result = re.sub(rf'^(Translate this from .* to .*:)', '', result) # Remove the translation instruction
|
145 |
-
return result.strip()
|
146 |
else: #M2M100 & NLLB
|
147 |
output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
|
148 |
result = output[0]['translation_text']
|
149 |
except Exception as e:
|
|
|
150 |
print("Error translation text: " + str(e))
|
151 |
|
152 |
return result
|
153 |
|
154 |
|
155 |
-
_MODELS = ["
|
156 |
-
"
|
157 |
-
"
|
158 |
-
"
|
159 |
-
"nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
|
160 |
-
"nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
|
161 |
-
"m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
|
162 |
-
"m2m100_1.2B", "m2m100_418M",
|
163 |
-
"mt5-zh-ja-en-trimmed",
|
164 |
-
"mt5-zh-ja-en-trimmed-fine-tuned-v1",
|
165 |
-
"ALMA-13B-GPTQ"]
|
166 |
|
167 |
def check_model_name(name):
|
168 |
return any(allowed_name in name for allowed_name in _MODELS)
|
@@ -230,6 +319,9 @@ def download_model(
|
|
230 |
"allow_patterns": allowPatterns,
|
231 |
#"tqdm_class": disabled_tqdm,
|
232 |
}
|
|
|
|
|
|
|
233 |
|
234 |
if outputDir is not None:
|
235 |
kwargs["local_dir"] = outputDir
|
|
|
3 |
import huggingface_hub
|
4 |
import requests
|
5 |
import torch
|
|
|
6 |
import ctranslate2
|
7 |
import transformers
|
8 |
+
import traceback
|
|
|
9 |
|
10 |
from typing import Optional
|
11 |
from src.config import ModelConfig
|
|
|
83 |
self.load_model()
|
84 |
|
85 |
def load_model(self):
|
86 |
+
"""
|
87 |
+
[from_pretrained]
|
88 |
+
low_cpu_mem_usage(bool, optional)
|
89 |
+
Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is an experimental feature and a subject to change at any moment.
|
90 |
+
|
91 |
+
[transformers.AutoTokenizer.from_pretrained]
|
92 |
+
use_fast (bool, optional, defaults to True):
|
93 |
+
Use a fast Rust-based tokenizer if it is supported for a given model.
|
94 |
+
If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
|
95 |
+
|
96 |
+
[transformers.AutoModelForCausalLM.from_pretrained]
|
97 |
+
device_map (str or Dict[str, Union[int, str, torch.device], optional):
|
98 |
+
Sent directly as model_kwargs (just a simpler shortcut). When accelerate library is present,
|
99 |
+
set device_map="auto" to compute the most optimized device_map automatically.
|
100 |
+
revision (str, optional, defaults to "main"):
|
101 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id,
|
102 |
+
since we use a git-based system for storing models and other artifacts on huggingface.co,
|
103 |
+
so revision can be any identifier allowed by git.
|
104 |
+
code_revision (str, optional, defaults to "main")
|
105 |
+
The specific revision to use for the code on the Hub, if the code leaves in a different repository than the rest of the model.
|
106 |
+
It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co,
|
107 |
+
so revision can be any identifier allowed by git.
|
108 |
+
trust_remote_code (bool, optional, defaults to False):
|
109 |
+
Whether or not to allow for custom models defined on the Hub in their own modeling files.
|
110 |
+
This option should only be set to True for repositories you trust and in which you have read the code,
|
111 |
+
as it will execute code present on the Hub on your local machine.
|
112 |
+
|
113 |
+
[transformers.pipeline "text-generation"]
|
114 |
+
do_sample:
|
115 |
+
if set to True, this parameter enables decoding strategies such as multinomial sampling,
|
116 |
+
beam-search multinomial sampling, Top-K sampling and Top-p sampling.
|
117 |
+
All these strategies select the next token from the probability distribution
|
118 |
+
over the entire vocabulary with various strategy-specific adjustments.
|
119 |
+
temperature (float, optional, defaults to 1.0):
|
120 |
+
The value used to modulate the next token probabilities.
|
121 |
+
top_k (int, optional, defaults to 50):
|
122 |
+
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
123 |
+
top_p (float, optional, defaults to 1.0):
|
124 |
+
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
|
125 |
+
repetition_penalty (float, optional, defaults to 1.0)
|
126 |
+
The parameter for repetition penalty. 1.0 means no penalty. See this paper for more details.
|
127 |
+
"""
|
128 |
+
try:
|
129 |
+
print('\n\nLoading model: %s\n\n' % self.modelPath)
|
130 |
+
if "ct2" in self.modelPath:
|
131 |
+
if any(name in self.modelPath for name in ["nllb", "m2m100"]):
|
132 |
+
if "nllb" in self.modelPath:
|
133 |
+
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.nllb.code)
|
134 |
+
self.targetPrefix = [self.translationLang.nllb.code]
|
135 |
+
elif "m2m100" in self.modelPath:
|
136 |
+
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
|
137 |
+
self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
|
138 |
+
self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
|
139 |
+
elif "ALMA" in self.modelPath:
|
140 |
+
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath)
|
141 |
+
self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
|
142 |
+
self.transModel = ctranslate2.Generator(self.modelPath, device=self.device)
|
143 |
+
elif "mt5" in self.modelPath:
|
144 |
+
self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
|
145 |
+
self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
|
146 |
+
self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath, low_cpu_mem_usage=True)
|
147 |
+
self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
|
148 |
+
elif "ALMA" in self.modelPath:
|
149 |
+
self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
|
150 |
+
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
|
151 |
+
self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", low_cpu_mem_usage=True, trust_remote_code=False, revision=self.modelConfig.revision)
|
152 |
+
self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, do_sample=True, temperature=0.7, top_k=40, top_p=0.95, repetition_penalty=1.1)
|
153 |
+
else:
|
154 |
+
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
|
155 |
+
self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
|
156 |
+
if "m2m100" in self.modelPath:
|
157 |
+
self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
|
158 |
+
else: #NLLB
|
159 |
+
self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)
|
160 |
+
|
161 |
+
except Exception as e:
|
162 |
+
print(traceback.format_exc())
|
163 |
+
self.release_vram()
|
164 |
|
165 |
def release_vram(self):
|
166 |
try:
|
167 |
if torch.cuda.is_available():
|
168 |
if "ct2" not in self.modelPath:
|
169 |
+
try:
|
170 |
+
device = torch.device("cpu")
|
171 |
+
self.transModel.to(device)
|
172 |
+
except Exception as e:
|
173 |
+
print(traceback.format_exc())
|
174 |
+
print("\tself.transModel.to cpu, error: " + str(e))
|
175 |
+
del self.transTranslator
|
176 |
+
del self.transTokenizer
|
177 |
del self.transModel
|
178 |
+
try:
|
179 |
+
torch.cuda.empty_cache()
|
180 |
+
except Exception as e:
|
181 |
+
print(traceback.format_exc())
|
182 |
+
print("\tcuda empty cache, error: " + str(e))
|
183 |
+
import gc
|
184 |
+
gc.collect()
|
185 |
print("release vram end.")
|
186 |
except Exception as e:
|
187 |
print("Error release vram: " + str(e))
|
188 |
|
189 |
|
190 |
def translation(self, text: str, max_length: int = 400):
|
191 |
+
"""
|
192 |
+
[ctranslate2]
|
193 |
+
max_batch_size:
|
194 |
+
The maximum batch size. If the number of inputs is greater than max_batch_size,
|
195 |
+
the inputs are sorted by length and split by chunks of max_batch_size examples
|
196 |
+
so that the number of padding positions is minimized.
|
197 |
+
no_repeat_ngram_size:
|
198 |
+
Prevent repetitions of ngrams with this size (set 0 to disable).
|
199 |
+
beam_size:
|
200 |
+
Beam size (1 for greedy search).
|
201 |
+
|
202 |
+
[ctranslate2.Generator.generate_batch]
|
203 |
+
sampling_temperature:
|
204 |
+
Sampling temperature to generate more random samples.
|
205 |
+
sampling_topk:
|
206 |
+
Randomly sample predictions from the top K candidates.
|
207 |
+
sampling_topp:
|
208 |
+
Keep the most probable tokens whose cumulative probability exceeds this value.
|
209 |
+
repetition_penalty:
|
210 |
+
Penalty applied to the score of previously generated tokens (set > 1 to penalize).
|
211 |
+
include_prompt_in_result:
|
212 |
+
Include the start_tokens in the result.
|
213 |
+
If include_prompt_in_result is True (the default), the decoding loop is constrained to generate the start tokens that are then included in the result.
|
214 |
+
If include_prompt_in_result is False, the start tokens are forwarded in the decoder at once to initialize its state (i.e. the KV cache for Transformer models).
|
215 |
+
For variable-length inputs, only the tokens up to the minimum length in the batch are forwarded at once. The remaining tokens are generated in the decoding loop with constrained decoding.
|
216 |
+
|
217 |
+
[transformers.TextGenerationPipeline.__call__]
|
218 |
+
return_full_text (bool, optional, defaults to True):
|
219 |
+
If set to False only added text is returned, otherwise the full text is returned. Only meaningful if return_text is set to True.
|
220 |
+
"""
|
221 |
output = None
|
222 |
result = None
|
223 |
try:
|
224 |
if "ct2" in self.modelPath:
|
225 |
+
if any(name in self.modelPath for name in ["nllb", "m2m100"]):
|
226 |
+
source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
|
227 |
+
output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
|
228 |
+
target = output[0].hypotheses[0][1:]
|
229 |
+
result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
|
230 |
+
elif "ALMA" in self.modelPath:
|
231 |
+
source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(self.ALMAPrefix + text + "\n" + self.translationLang.whisper.names[0] + ": "))
|
232 |
+
output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, repetition_penalty=1.1, include_prompt_in_result=False) #, sampling_topk=40
|
233 |
+
target = output[0]
|
234 |
+
result = self.transTokenizer.decode(target.sequences_ids[0])
|
235 |
elif "mt5" in self.modelPath:
|
236 |
output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
|
237 |
result = output[0]['generated_text']
|
238 |
elif "ALMA" in self.modelPath:
|
239 |
+
output = self.transTranslator(self.ALMAPrefix + text + "\n" + self.translationLang.whisper.names[0] + ": ", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams, return_full_text=False)
|
240 |
result = output[0]['generated_text']
|
|
|
|
|
|
|
241 |
else: #M2M100 & NLLB
|
242 |
output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
|
243 |
result = output[0]['translation_text']
|
244 |
except Exception as e:
|
245 |
+
print(traceback.format_exc())
|
246 |
print("Error translation text: " + str(e))
|
247 |
|
248 |
return result
|
249 |
|
250 |
|
251 |
+
_MODELS = ["nllb-200",
|
252 |
+
"m2m100",
|
253 |
+
"mt5",
|
254 |
+
"ALMA"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
256 |
def check_model_name(name):
|
257 |
return any(allowed_name in name for allowed_name in _MODELS)
|
|
|
319 |
"allow_patterns": allowPatterns,
|
320 |
#"tqdm_class": disabled_tqdm,
|
321 |
}
|
322 |
+
|
323 |
+
if modelConfig.revision is not None:
|
324 |
+
kwargs["revision"] = modelConfig.revision
|
325 |
|
326 |
if outputDir is not None:
|
327 |
kwargs["local_dir"] = outputDir
|
src/utils.py
CHANGED
@@ -130,7 +130,7 @@ def write_srt_original(transcript: Iterator[dict], file: TextIO,
|
|
130 |
flush=True,
|
131 |
)
|
132 |
|
133 |
-
if original is not None: print(f"{original}
|
134 |
file=file,
|
135 |
flush=True)
|
136 |
|
|
|
130 |
flush=True,
|
131 |
)
|
132 |
|
133 |
+
if original is not None: print(f"{original}",
|
134 |
file=file,
|
135 |
flush=True)
|
136 |
|