# Inference with SeamlessM4T models
Refer to the [SeamlessM4T README](../../../../../docs/m4t) for an overview of the M4T models.
Inference is run with the CLI, from the root directory of the repository.
The model is selected with the `--model_name` argument, which accepts `seamlessM4T_v2_large`, `seamlessM4T_large`, or `seamlessM4T_medium`:
**S2ST**:
```bash
m4t_predict <path_to_input_audio> --task s2st --tgt_lang <tgt_lang> --output_path <path_to_save_audio> --model_name seamlessM4T_large
```
**S2TT**:
```bash
m4t_predict <path_to_input_audio> --task s2tt --tgt_lang <tgt_lang>
```
**T2TT**:
```bash
m4t_predict <input_text> --task t2tt --tgt_lang <tgt_lang> --src_lang <src_lang>
```
**T2ST**:
```bash
m4t_predict <input_text> --task t2st --tgt_lang <tgt_lang> --src_lang <src_lang> --output_path <path_to_save_audio>
```
**ASR**:
```bash
m4t_predict <path_to_input_audio> --task asr --tgt_lang <tgt_lang>
```
Set `--ngram-filtering` to `True` to match the translation performance of the [demo](https://seamless.metademolab.com/).
The input audio currently must be sampled at 16 kHz. Here is how you can resample your audio with `torchaudio`:
```python
import torchaudio

# Load the original audio and build a resampler that converts it to 16 kHz.
resample_rate = 16000
waveform, sample_rate = torchaudio.load(<path_to_input_audio>)
resampler = torchaudio.transforms.Resample(sample_rate, resample_rate, dtype=waveform.dtype)
resampled_waveform = resampler(waveform)
# Write the 16 kHz audio back to disk for use with m4t_predict.
torchaudio.save(<path_to_resampled_audio>, resampled_waveform, resample_rate)
```
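If you are not sure whether a file already meets this requirement, you can inspect its metadata before resampling. A minimal sketch using `torchaudio.info` (the path placeholder follows the same convention as above):
```python
import torchaudio

# Read only the metadata, without decoding the full waveform.
info = torchaudio.info(<path_to_input_audio>)
if info.sample_rate != 16000:
    print(f"Resampling needed: file is sampled at {info.sample_rate} Hz")
```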
## Inference breakdown
Inference uses a `Translator` object instantiated with a multitask UnitY or UnitY2 model, chosen from:
- [`seamlessM4T_v2_large`](https://huggingface.co/facebook/seamless-m4t-v2-large)
- [`seamlessM4T_large`](https://huggingface.co/facebook/seamless-m4t-large)
- [`seamlessM4T_medium`](https://huggingface.co/facebook/seamless-m4t-medium)
and a vocoder:
- `vocoder_v2` for `seamlessM4T_v2_large`.
- `vocoder_36langs` for `seamlessM4T_large` or `seamlessM4T_medium`.
```python
import torch
import torchaudio
from seamless_communication.inference import Translator
# Initialize a Translator object with a multitask model and a vocoder on the GPU.
translator = Translator("seamlessM4T_large", "vocoder_36langs", torch.device("cuda:0"), torch.float16)
```
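The same call shape works for the other checkpoints; for example, the v2 model must be paired with its own vocoder (a sketch, assuming a CUDA device is available):
```python
import torch
from seamless_communication.inference import Translator

# seamlessM4T_v2_large is paired with vocoder_v2 (see the list above).
translator = Translator("seamlessM4T_v2_large", "vocoder_v2", torch.device("cuda:0"), torch.float16)
```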
Now `predict()` can be used to run inference as many times as needed on any of the supported tasks.
Given an input audio file `<path_to_input_audio>` or an input text `<input_text>` in `<src_lang>`,
we first set `text_generation_opts` and `unit_generation_opts`, and then translate into `<tgt_lang>` as follows:
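The snippets below assume both option objects have already been created. A minimal sketch, assuming a `SequenceGeneratorOptions` class is exposed by `seamless_communication.inference` and that the field names and values shown here (`beam_size`, `soft_max_seq_len`) match your installed version:
```python
from seamless_communication.inference import SequenceGeneratorOptions

# Assumed option fields and values; adjust to your installed version.
text_generation_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
unit_generation_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(25, 50))
```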
## S2ST and T2ST:
```python
# S2ST
text_output, speech_output = translator.predict(
    input=<path_to_input_audio>,
    task_str="S2ST",
    tgt_lang=<tgt_lang>,
    text_generation_opts=text_generation_opts,
    unit_generation_opts=unit_generation_opts,
)

# T2ST
text_output, speech_output = translator.predict(
    input=<input_text>,
    task_str="T2ST",
    tgt_lang=<tgt_lang>,
    src_lang=<src_lang>,
    text_generation_opts=text_generation_opts,
    unit_generation_opts=unit_generation_opts,
)
```
Note that `<src_lang>` must be specified for T2ST.
The generated units are synthesized and the output audio file is saved with:
```python
# Save the translated audio generation.
torchaudio.save(
    <path_to_save_audio>,
    speech_output.audio_wavs[0][0].cpu(),
    sample_rate=speech_output.sample_rate,
)
```
## S2TT, T2TT and ASR:
```python
# S2TT
text_output, _ = translator.predict(
    input=<path_to_input_audio>,
    task_str="S2TT",
    tgt_lang=<tgt_lang>,
    text_generation_opts=text_generation_opts,
    unit_generation_opts=None,
)

# ASR
# This is equivalent to S2TT with `<tgt_lang>=<src_lang>`.
text_output, _ = translator.predict(
    input=<path_to_input_audio>,
    task_str="ASR",
    tgt_lang=<src_lang>,
    text_generation_opts=text_generation_opts,
    unit_generation_opts=None,
)

# T2TT
text_output, _ = translator.predict(
    input=<input_text>,
    task_str="T2TT",
    tgt_lang=<tgt_lang>,
    src_lang=<src_lang>,
    text_generation_opts=text_generation_opts,
    unit_generation_opts=None,
)
```
Note that `<src_lang>` must be specified for T2TT.
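For the text tasks, the translated (or transcribed) text can then be read from `text_output`. A minimal sketch, assuming `predict()` returns one text entry per input:
```python
# Print the text for the first (and here only) input.
print(f"Translated text: {text_output[0]}")
```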