SpeechLM2
================================
.. note::

    The SpeechLM2 collection is still in active development and the code is likely to keep changing.

SpeechLM2 refers to a collection that augments pre-trained Large Language Models (LLMs) with speech understanding and generation capabilities.
This collection is designed to be compact and efficient, and to support easy swapping of different LLMs backed by HuggingFace AutoModel.
It has first-class support for dynamic batch sizes via Lhotse and for various model parallelism techniques (e.g., FSDP2, tensor parallel, sequence parallel) through the PyTorch DTensor API.

We currently support three main model types (a short import sketch follows the list):

* SALM (Speech-Augmented Language Model) - a simple but effective approach to augmenting pre-trained LLMs with speech understanding capabilities.
* DuplexS2SModel - a full-duplex speech-to-speech model that uses an ASR encoder and directly predicts discrete audio codes.
* DuplexS2SSpeechDecoderModel - a variant of DuplexS2SModel with a separate transformer decoder for speech generation.
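
A minimal orientation sketch, assuming all three classes are exposed under ``slm.models`` (SALM and DuplexS2SModel are used this way in the inference sections below; the path for DuplexS2SSpeechDecoderModel is inferred from the collection layout described at the end of this page):

.. code-block:: python

    import nemo.collections.speechlm2 as slm

    # The three model families described above; loading and inference are shown
    # in the sections that follow.
    print(slm.models.SALM)
    print(slm.models.DuplexS2SModel)
    print(slm.models.DuplexS2SSpeechDecoderModel)
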
Using Pretrained Models
------------------------
After :ref:`installing NeMo<installation>`, you can load and use a pretrained speechlm2 model as follows:

.. code-block:: python

    import nemo.collections.speechlm2 as slm

    # Load a pretrained SALM model
    model = slm.models.SALM.from_pretrained("model_name_or_path")

    # Set the model to evaluation mode
    model = model.eval()
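
For GPU inference, you can move the model to the device first. A minimal sketch using standard PyTorch device handling (nothing speechlm2-specific is assumed here):

.. code-block:: python

    import torch

    # Standard PyTorch device placement; the inference examples below read
    # model.device to place their input tensors accordingly.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
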
Inference with Pretrained Models
--------------------------------
SALM
****
You can run inference with a loaded pretrained SALM model as follows:

.. code-block:: python

    import torch
    import torchaudio

    import nemo.collections.speechlm2 as slm

    model = slm.models.SALM.from_pretrained("path/to/pretrained_checkpoint").eval()

    # Load the audio file
    audio_path = "path/to/audio.wav"
    audio_signal, sample_rate = torchaudio.load(audio_path)

    # Resample if needed (most models expect 16 kHz audio)
    if sample_rate != 16000:
        audio_signal = torchaudio.functional.resample(audio_signal, sample_rate, 16000)
        sample_rate = 16000

    # Move the audio to the model's device and build the length tensor
    audio_signal = audio_signal.to(model.device)
    audio_len = torch.tensor([audio_signal.shape[1]], device=model.device)

    # Create a prompt for SALM model inference.
    # The audio_locator_tag is a special token that will be replaced with audio embeddings.
    prompt = [{"role": "user", "content": f"{model.audio_locator_tag}"}]

    # Generate the response
    with torch.no_grad():
        output = model.generate(
            prompts=[prompt],
            audios=audio_signal,
            audio_lens=audio_len,
            generation_config=None,  # You can customize generation parameters here
        )

    # Detokenize the output tokens
    response = model.tokenizer.ids_to_text(output[0])
    print(f"Model response: {response}")
DuplexS2SModel
**************
You can run inference with a loaded pretrained DuplexS2SModel as follows:

.. code-block:: python

    import torch
    import torchaudio

    import nemo.collections.speechlm2 as slm

    model = slm.models.DuplexS2SModel.from_pretrained("path/to/pretrained_checkpoint").eval()

    # Load the audio file
    audio_path = "path/to/audio.wav"
    audio_signal, sample_rate = torchaudio.load(audio_path)

    # Resample if needed (most models expect 16 kHz audio)
    if sample_rate != 16000:
        audio_signal = torchaudio.functional.resample(audio_signal, sample_rate, 16000)
        sample_rate = 16000

    # Move the audio to the model's device and build the length tensor
    audio_signal = audio_signal.to(model.device)
    audio_len = torch.tensor([audio_signal.shape[1]], device=model.device)

    # Run offline inference
    results = model.offline_inference(
        input_signal=audio_signal,
        input_signal_lens=audio_len,
    )

    # Decoded text and generated audio
    transcription = results["text"][0]
    audio = results["audio"][0]
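
To listen to the generated speech, the output tensor can be written to disk. A minimal sketch, reusing ``audio`` and ``transcription`` from the example above; the sample rate below is an assumption and should be replaced with the sample rate of the model's audio codec:

.. code-block:: python

    # Assumption: ``audio`` is a 1-D waveform tensor, and 22050 Hz is only a
    # placeholder for the codec's actual sample rate.
    output_sample_rate = 22050
    torchaudio.save("response.wav", audio.detach().float().cpu().unsqueeze(0), output_sample_rate)
    print(f"Transcription: {transcription}")
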
Training a Model
----------------
This example demonstrates how to train a SALM model. The remaining models can be trained in a similar manner.

.. code-block:: python

    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import ModelParallelStrategy
    from omegaconf import OmegaConf

    import nemo.collections.speechlm2 as slm
    from nemo.collections.speechlm2.data import SALMDataset, DataModule
    from nemo.utils.exp_manager import exp_manager

    # Load the configuration
    config_path = "path/to/config.yaml"  # e.g., from examples/speechlm2/conf/salm.yaml
    cfg = OmegaConf.load(config_path)

    # Initialize the PyTorch Lightning trainer.
    # Note: devices must equal data_parallel_size * tensor_parallel_size.
    trainer = Trainer(
        max_steps=100000,
        accelerator="gpu",
        devices=2,
        precision="bf16-true",
        strategy=ModelParallelStrategy(data_parallel_size=2, tensor_parallel_size=1),
        limit_train_batches=1000,
        val_check_interval=1000,
        use_distributed_sampler=False,
        logger=False,
        enable_checkpointing=False,
    )

    # Set up the experiment manager for logging
    exp_manager(trainer, cfg.get("exp_manager", None))

    # Initialize the model with the configuration
    model = slm.models.SALM(OmegaConf.to_container(cfg.model, resolve=True))

    # Create the dataset and datamodule
    dataset = SALMDataset(tokenizer=model.tokenizer)
    datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)

    # Train the model
    trainer.fit(model, datamodule)
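
As mentioned in the introduction, the collection supports model parallelism (e.g., FSDP2, tensor parallel) through the PyTorch DTensor API. A minimal sketch of scaling the same script to more GPUs by adjusting ``ModelParallelStrategy``; the 8-GPU split below is an illustrative assumption:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import ModelParallelStrategy

    # Assumption: a single node with 8 GPUs, split into 4 data-parallel replicas
    # with 2-way tensor parallelism each (4 x 2 = 8 devices).
    trainer = Trainer(
        max_steps=100000,
        accelerator="gpu",
        devices=8,
        precision="bf16-true",
        strategy=ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=2),
        use_distributed_sampler=False,
        logger=False,
        enable_checkpointing=False,
    )
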
Example Using Command-Line Training Script
------------------------------------------
Alternatively, you can train a model using the provided training scripts in the examples directory:

.. code-block:: bash

    # Train a SALM model
    python examples/speechlm2/salm_train.py \
        --config-path=examples/speechlm2/conf \
        --config-name=salm

    # Run inference/evaluation
    python examples/speechlm2/salm_eval.py \
        pretrained_name=/path/to/checkpoint \
        inputs=/path/to/test_manifest \
        batch_size=64 \
        max_new_tokens=128 \
        output_manifest=generations.jsonl

For more detailed information on training at scale, model parallelism, and SLURM-based training, see :doc:`training and scaling <training_and_scaling>`.
Collection Structure
----------------------
The speechlm2 collection is organized into the following key components:

- **Models**: Contains implementations of DuplexS2SModel, DuplexS2SSpeechDecoderModel, and SALM
- **Modules**: Contains audio perception and speech generation modules
- **Data**: Includes dataset classes and data loading utilities
SpeechLM2 Documentation
-----------------------
For more information, see the additional sections of the SpeechLM2 documentation:

.. toctree::
    :maxdepth: 1

    models
    datasets
    configs
    training_and_scaling