SpeechLM2
=========

.. note::

   The SpeechLM2 collection is still in active development and the code is likely to keep changing.

SpeechLM2 is a collection that augments pre-trained Large Language Models (LLMs) with speech understanding and generation capabilities.

The collection is designed to be compact and efficient, and to support easy swapping of different LLMs through HuggingFace AutoModel.
It has first-class support for dynamic batch sizes via Lhotse and for various model parallelism techniques (e.g., FSDP2, Tensor Parallel, Sequence Parallel) via the PyTorch DTensor API; see the training example below for model parallelism in use.

We currently support three main model types:

* SALM (Speech-Augmented Language Model) - a simple but effective approach to augmenting a pre-trained LLM with speech understanding capabilities.
* DuplexS2SModel - a full-duplex speech-to-speech model with an ASR encoder that directly predicts discrete audio codes.
* DuplexS2SSpeechDecoderModel - a variant of DuplexS2SModel with a separate transformer decoder for speech generation.

Using Pretrained Models
-----------------------

After :ref:`installing NeMo<installation>`, you can load and use a pretrained speechlm2 model as follows:

.. code-block:: python

    import nemo.collections.speechlm2 as slm

    # The argument may be a local checkpoint path or a pretrained model name.
    model = slm.models.SALM.from_pretrained("model_name_or_path")

    # Switch to evaluation mode for inference.
    model = model.eval()

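If a GPU is available, you can move the loaded model onto it before running inference. A minimal sketch, assuming the loaded model supports standard PyTorch device movement:

.. code-block:: python

    import torch

    # Assumption: the loaded model behaves like a regular PyTorch module.
    if torch.cuda.is_available():
        model = model.to("cuda")
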
Inference with Pretrained Models
--------------------------------

SALM
****

You can run inference with a loaded pretrained SALM model as follows:

.. code-block:: python

    import torch
    import torchaudio
    import nemo.collections.speechlm2 as slm

    model = slm.models.SALM.from_pretrained("path/to/pretrained_checkpoint").eval()

    # Load the input audio and resample it to 16 kHz if necessary.
    audio_path = "path/to/audio.wav"
    audio_signal, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        audio_signal = torchaudio.functional.resample(audio_signal, sample_rate, 16000)
        sample_rate = 16000

    audio_signal = audio_signal.to(model.device)
    audio_len = torch.tensor([audio_signal.shape[1]], device=model.device)

    # The audio locator tag marks the position in the prompt where the audio is inserted.
    prompt = [{"role": "user", "content": model.audio_locator_tag}]

    with torch.no_grad():
        output = model.generate(
            prompts=[prompt],
            audios=audio_signal,
            audio_lens=audio_len,
            generation_config=None,  # optionally pass a generation config to control decoding
        )

    # Convert the generated token IDs back into text.
    response = model.tokenizer.ids_to_text(output[0])
    print(f"Model response: {response}")

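The prompt may also mix instruction text with the audio locator tag; the tag simply marks where the audio is placed within the turn. For example (the instruction wording here is illustrative):

.. code-block:: python

    # Combine a textual instruction with the audio placeholder in a single turn.
    prompt = [{"role": "user", "content": f"Transcribe the following: {model.audio_locator_tag}"}]
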
DuplexS2SModel
**************

You can run inference with a loaded pretrained DuplexS2SModel as follows:

.. code-block:: python

    import torch
    import torchaudio
    import nemo.collections.speechlm2 as slm

    model = slm.models.DuplexS2SModel.from_pretrained("path/to/pretrained_checkpoint").eval()

    # Load the input audio and resample it to 16 kHz if necessary.
    audio_path = "path/to/audio.wav"
    audio_signal, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        audio_signal = torchaudio.functional.resample(audio_signal, sample_rate, 16000)
        sample_rate = 16000

    audio_signal = audio_signal.to(model.device)
    audio_len = torch.tensor([audio_signal.shape[1]], device=model.device)

    results = model.offline_inference(
        input_signal=audio_signal,
        input_signal_lens=audio_len,
    )

    # The results contain both the generated text and the generated audio.
    transcription = results["text"][0]
    audio = results["audio"][0]

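The generated audio can be written to disk with ``torchaudio``. Note that the output sample rate is a property of the model's audio codec rather than of the input audio; the value below is only a placeholder, so check your model's configuration for the actual rate:

.. code-block:: python

    # Placeholder sample rate; the true value depends on the model's audio codec.
    output_sample_rate = 22050
    torchaudio.save("response.wav", audio.detach().cpu().unsqueeze(0), output_sample_rate)
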
Training a Model
----------------

This example demonstrates how to train a SALM model. The remaining models can be trained in a similar manner.

.. code-block:: python

    from omegaconf import OmegaConf
    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import ModelParallelStrategy

    import nemo.collections.speechlm2 as slm
    from nemo.collections.speechlm2.data import SALMDataset, DataModule
    from nemo.utils.exp_manager import exp_manager

    config_path = "path/to/config.yaml"
    cfg = OmegaConf.load(config_path)

    # The total number of devices must equal
    # data_parallel_size * tensor_parallel_size (here, 2).
    trainer = Trainer(
        max_steps=100000,
        accelerator="gpu",
        devices=2,
        precision="bf16-true",
        strategy=ModelParallelStrategy(data_parallel_size=2, tensor_parallel_size=1),
        limit_train_batches=1000,
        val_check_interval=1000,
        use_distributed_sampler=False,  # Lhotse manages sampling and batching itself
        logger=False,
        enable_checkpointing=False,
    )

    exp_manager(trainer, cfg.get("exp_manager", None))

    model = slm.models.SALM(OmegaConf.to_container(cfg.model, resolve=True))

    dataset = SALMDataset(tokenizer=model.tokenizer)
    datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)

    trainer.fit(model, datamodule)

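You can also adjust the loaded configuration programmatically before constructing the trainer and model. A minimal sketch using OmegaConf's merge facilities; the override key shown is illustrative and must match whatever structure your YAML actually uses:

.. code-block:: python

    from omegaconf import OmegaConf

    # Hypothetical override: the key "trainer.max_steps" is only an example
    # and depends on the structure of your config.yaml.
    overrides = OmegaConf.from_dotlist(["trainer.max_steps=1000"])
    cfg = OmegaConf.merge(cfg, overrides)
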
Example Using Command-Line Training Script
------------------------------------------

Alternatively, you can train a model using the training scripts provided in the examples directory:

.. code-block:: bash

    # Train a SALM model with the provided example configuration.
    python examples/speechlm2/salm_train.py \
        --config-path=examples/speechlm2/conf \
        --config-name=salm

    # Evaluate a trained checkpoint on a test manifest.
    python examples/speechlm2/salm_eval.py \
        pretrained_name=/path/to/checkpoint \
        inputs=/path/to/test_manifest \
        batch_size=64 \
        max_new_tokens=128 \
        output_manifest=generations.jsonl

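The output manifest is a JSONL file, so each line holds one JSON record. A small sketch for inspecting the generations; the field names inside each record depend on the evaluation script's output format, so inspect a record before relying on specific keys:

.. code-block:: python

    import json

    # Print each generation record; field names vary with the script's output format.
    with open("generations.jsonl") as f:
        for line in f:
            print(json.loads(line))
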
For more detailed information on training at scale, model parallelism, and SLURM-based training, see :doc:`training and scaling <training_and_scaling>`.

Collection Structure
--------------------

The speechlm2 collection is organized into the following key components (see the import sketch after this list):

- **Models**: implementations of SALM, DuplexS2SModel, and DuplexS2SSpeechDecoderModel
- **Modules**: audio perception and speech generation modules
- **Data**: dataset classes and data loading utilities

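The import paths below illustrate where these components live. The model classes and data utilities appear in the earlier examples; the exact contents of each subpackage are covered in the pages listed under SpeechLM2 Documentation below:

.. code-block:: python

    import nemo.collections.speechlm2 as slm
    from nemo.collections.speechlm2.data import SALMDataset, DataModule

    # Model implementations are exposed under the models submodule.
    salm_cls = slm.models.SALM
    duplex_cls = slm.models.DuplexS2SModel
    decoder_variant_cls = slm.models.DuplexS2SSpeechDecoderModel
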
SpeechLM2 Documentation
-----------------------

For more information, see the additional sections in the SpeechLM2 docs:

.. toctree::
   :maxdepth: 1

   models
   datasets
   configs
   training_and_scaling