SpeechLM2
================================

.. note::
   The SpeechLM2 collection is still in active development and the code is likely to keep changing.

SpeechLM2 refers to a collection that augments pre-trained Large Language Models (LLMs) with speech understanding and generation capabilities.

This collection is designed to be compact and efficient, and to make it easy to swap in different LLMs loaded through HuggingFace AutoModel.
It has first-class support for dynamic batch sizes via Lhotse and for various model parallelism techniques (e.g., FSDP2, Tensor Parallel, Sequence Parallel) via the PyTorch DTensor API.

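To make "dynamic batch sizes" concrete: instead of a fixed number of examples per batch, each batch is capped by a total audio-duration budget, so many short utterances or a few long ones fit in roughly the same memory footprint. The function below is a simplified, hypothetical plain-Python sketch of that idea only; it is not Lhotse's API (Lhotse provides samplers such as ``DynamicCutSampler`` and ``DynamicBucketingSampler`` that also handle shuffling and, in the bucketing case, grouping of similar-length utterances).

.. code-block:: python

    from typing import List

    def dynamic_batches(durations: List[float], max_duration: float) -> List[List[int]]:
        """Greedily group utterance indices so each batch's total duration stays within max_duration.

        Hypothetical helper for illustration only. An utterance longer than
        max_duration is not dropped; it ends up in a batch of its own.
        """
        batches: List[List[int]] = []
        current: List[int] = []
        current_total = 0.0
        for idx, duration in enumerate(durations):
            # Close the current batch if adding this utterance would exceed the budget.
            if current and current_total + duration > max_duration:
                batches.append(current)
                current, current_total = [], 0.0
            current.append(idx)
            current_total += duration
        if current:
            batches.append(current)
        return batches

    # Six utterances (in seconds) with a 10-second budget yield batches of 3, 1 and 2 examples.
    print(dynamic_batches([2.0, 3.0, 4.0, 10.0, 1.0, 1.0], max_duration=10.0))
    # [[0, 1, 2], [3], [4, 5]]

The batch size varies from step to step while the amount of audio per step stays bounded, which is the property that lets short-utterance data use the GPU more fully.
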
We currently support three main model types:

* SALM (Speech-Augmented Language Model) - a simple but effective approach to augmenting pre-trained LLMs with speech understanding capabilities.
* DuplexS2SModel - a full-duplex speech-to-speech model with an ASR encoder that directly predicts discrete audio codes.
* DuplexS2SSpeechDecoderModel - a variant of DuplexS2SModel with a separate transformer decoder for speech generation.

Using Pretrained Models
-----------------------

After :ref:`installing NeMo<installation>`, you can load and use a pretrained speechlm2 model as follows:

.. code-block:: python

    import nemo.collections.speechlm2 as slm
    
    # Load a pretrained SALM model
    model = slm.models.SALM.from_pretrained("model_name_or_path")

    # Set model to evaluation mode
    model = model.eval()

Inference with Pretrained Models
--------------------------------

SALM
****

You can run inference using the loaded pretrained SALM model:

.. code-block:: python

    import torch
    import torchaudio
    import nemo.collections.speechlm2 as slm

    model = slm.models.SALM.from_pretrained("path/to/pretrained_checkpoint").eval()
    
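    # Load audio file (torchaudio returns a [channels, samples] tensor)
    audio_path = "path/to/audio.wav"
    audio_signal, sample_rate = torchaudio.load(audio_path)

    # Downmix multi-channel recordings to mono so the batch holds a single utterance
    if audio_signal.shape[0] > 1:
        audio_signal = audio_signal.mean(dim=0, keepdim=True)

    # Resample if needed
    if sample_rate != 16000:  # Most models expect 16kHz audio
        audio_signal = torchaudio.functional.resample(audio_signal, sample_rate, 16000)
        sample_rate = 16000

    # Prepare audio for model: a [batch=1, samples] tensor plus its length in samples
    audio_signal = audio_signal.to(model.device)
    audio_len = torch.tensor([audio_signal.shape[1]], device=model.device)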
    
    # Create a prompt for SALM model inference
    # The audio_locator_tag is a special token that will be replaced with audio embeddings
    prompt = [{"role": "user", "content": f"{model.audio_locator_tag}"}]
    
    # Generate response
    with torch.no_grad():
        output = model.generate(
            prompts=[prompt],
            audios=audio_signal,
            audio_lens=audio_len,
            generation_config=None  # You can customize generation parameters here
        )
    
    # Decode the generated token IDs of the first (and only) prompt into text
    response = model.tokenizer.ids_to_text(output[0].cpu())
    print(f"Model response: {response}")

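The comment about ``audio_locator_tag`` in the example above describes a splice: the placeholder token in the tokenized prompt is swapped for the audio encoder's output frames before the sequence reaches the LLM. The plain-Python sketch below illustrates that idea with made-up token IDs and toy two-dimensional "embeddings". The function name and data layout are hypothetical; SALM's actual implementation operates on batched tensors and may differ in detail.

.. code-block:: python

    from typing import Dict, List, Sequence

    Vector = List[float]

    def splice_audio_embeddings(
        token_ids: Sequence[int],
        placeholder_id: int,
        text_embeddings: Dict[int, Vector],
        audio_segments: Sequence[Sequence[Vector]],
    ) -> List[Vector]:
        """Replace each placeholder token with the frames of the next audio segment (hypothetical helper)."""
        num_placeholders = sum(1 for tok in token_ids if tok == placeholder_id)
        if num_placeholders != len(audio_segments):
            raise ValueError(
                f"{num_placeholders} placeholder(s) but {len(audio_segments)} audio segment(s)"
            )
        segments = iter(audio_segments)
        sequence: List[Vector] = []
        for tok in token_ids:
            if tok == placeholder_id:
                # One placeholder position expands into as many positions as the audio has frames.
                sequence.extend(list(frame) for frame in next(segments))
            else:
                # Ordinary text tokens are looked up in the embedding table.
                sequence.append(text_embeddings[tok])
        return sequence

    # Prompt "[1] <audio> [2]" with a 3-frame audio segment becomes 5 positions.
    table = {1: [0.1, 0.1], 2: [0.2, 0.2]}
    frames = [[9.0, 9.0], [8.0, 8.0], [7.0, 7.0]]
    print(len(splice_audio_embeddings([1, 99, 2], 99, table, [frames])))  # 5

Because a single placeholder expands to one position per audio frame, longer recordings consume more of the LLM's context window than the one-token tag suggests.
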
DuplexS2SModel
**************

You can run inference using the loaded pretrained DuplexS2SModel:

.. code-block:: python

    import torch
    import torchaudio
    import nemo.collections.speechlm2 as slm

    model = slm.models.DuplexS2SModel.from_pretrained("path/to/pretrained_checkpoint").eval()
    
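    # Load audio file (torchaudio returns a [channels, samples] tensor)
    audio_path = "path/to/audio.wav"
    audio_signal, sample_rate = torchaudio.load(audio_path)

    # Downmix multi-channel recordings to mono so the batch holds a single utterance
    if audio_signal.shape[0] > 1:
        audio_signal = audio_signal.mean(dim=0, keepdim=True)

    # Resample if needed
    if sample_rate != 16000:  # Most models expect 16kHz audio
        audio_signal = torchaudio.functional.resample(audio_signal, sample_rate, 16000)
        sample_rate = 16000

    # Prepare audio for model: a [batch=1, samples] tensor plus its length in samples
    audio_signal = audio_signal.to(model.device)
    audio_len = torch.tensor([audio_signal.shape[1]], device=model.device)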
    
    # Run offline inference
    results = model.offline_inference(
        input_signal=audio_signal,
        input_signal_lens=audio_len
    )

    # The results already hold decoded outputs: the model's text response and its generated waveform
    response_text = results["text"][0]
    audio = results["audio"][0]

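The ``audio`` entry holds the generated speech waveform. To listen to it, one option is to write it to a WAV file. The sketch below uses only the Python standard library; ``write_wav_pcm16`` is a hypothetical helper that assumes a mono waveform with samples in [-1.0, 1.0]. You must supply the output sample rate of your model's audio codec yourself: it is not stated on this page and is not necessarily 16 kHz.

.. code-block:: python

    import struct
    import wave
    from typing import Sequence

    def write_wav_pcm16(path: str, samples: Sequence[float], sample_rate: int) -> None:
        """Write mono float samples in [-1.0, 1.0] to a 16-bit PCM WAV file (hypothetical helper)."""
        # Clip out-of-range values, then scale to signed 16-bit integers (truncating toward zero).
        pcm = [int(max(-1.0, min(1.0, s)) * 32767) for s in samples]
        with wave.open(path, "wb") as wav_file:
            wav_file.setnchannels(1)  # mono
            wav_file.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(struct.pack(f"<{len(pcm)}h", *pcm))

For example, assuming ``audio`` is a 1-D tensor and ``output_sample_rate`` is a placeholder for your codec's rate: ``write_wav_pcm16("response.wav", audio.float().cpu().tolist(), output_sample_rate)``. If you already depend on torchaudio, ``torchaudio.save`` can write the tensor directly instead.
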
Training a Model
----------------

This example demonstrates how to train a SALM model. The remaining models can be trained in a similar manner.

.. code-block:: python

    from omegaconf import OmegaConf
    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import ModelParallelStrategy
    
    import nemo.collections.speechlm2 as slm
    from nemo.collections.speechlm2.data import SALMDataset, DataModule
    from nemo.utils.exp_manager import exp_manager
    
    # Load configuration
    config_path = "path/to/config.yaml"  # E.g., from examples/speechlm2/conf/salm.yaml
    cfg = OmegaConf.load(config_path)
    
    # Initialize PyTorch Lightning trainer
    trainer = Trainer(
        max_steps=100000,
        accelerator="gpu",
        devices=2,  # must equal data_parallel_size * tensor_parallel_size on a single node
        precision="bf16-true",
        strategy=ModelParallelStrategy(data_parallel_size=2, tensor_parallel_size=1),
        limit_train_batches=1000,
        val_check_interval=1000,
        use_distributed_sampler=False,
        logger=False,
        enable_checkpointing=False,
    )
    
    # Set up experiment manager for logging
    exp_manager(trainer, cfg.get("exp_manager", None))
    
    # Initialize model with configuration
    model = slm.models.SALM(OmegaConf.to_container(cfg.model, resolve=True))
    
    # Create dataset and datamodule
    dataset = SALMDataset(tokenizer=model.tokenizer)
    datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)
    
    # Train the model
    trainer.fit(model, datamodule)

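``ModelParallelStrategy`` arranges the GPUs into a two-dimensional device mesh, so ``data_parallel_size * tensor_parallel_size`` has to equal the total number of processes (``num_nodes * devices``). The hypothetical helper below derives a consistent data-parallel size for a given cluster shape; it is a plain-Python sketch for sanity-checking a configuration, not part of NeMo or Lightning.

.. code-block:: python

    def data_parallel_size(num_nodes: int, devices_per_node: int, tensor_parallel_size: int) -> int:
        """Return the data-parallel size implied by the cluster shape and tensor-parallel degree."""
        world_size = num_nodes * devices_per_node
        # Every GPU belongs to exactly one tensor-parallel group, so TP must divide the world size.
        if tensor_parallel_size < 1 or world_size % tensor_parallel_size != 0:
            raise ValueError(
                f"world size {world_size} is not divisible by tensor_parallel_size={tensor_parallel_size}"
            )
        return world_size // tensor_parallel_size

    print(data_parallel_size(1, 2, 1))  # 2 replicas on one node with two GPUs
    print(data_parallel_size(2, 8, 4))  # 4 replicas, each spanning a 4-GPU tensor-parallel group

Checking these numbers before launching avoids a job that fails only after every rank has started.
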
Example Using Command-Line Training Script
------------------------------------------

Alternatively, you can train a model using the provided training scripts in the examples directory:

.. code-block:: bash

    # Train a SALM model
    python examples/speechlm2/salm_train.py \
      --config-path=examples/speechlm2/conf \
      --config-name=salm

    # For inference/evaluation 
    python examples/speechlm2/salm_eval.py \
      pretrained_name=/path/to/checkpoint \
      inputs=/path/to/test_manifest \
      batch_size=64 \
      max_new_tokens=128 \
      output_manifest=generations.jsonl

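The evaluation script takes a test manifest via ``inputs``. Assuming it accepts the standard NeMo JSON-lines manifest format (one JSON object per line with ``audio_filepath``, ``duration`` in seconds, and ``text``), which you should confirm against the datasets documentation, such a manifest can be produced with the standard library alone. ``write_manifest`` is a hypothetical helper.

.. code-block:: python

    import json
    from typing import Iterable, Tuple

    def write_manifest(path: str, entries: Iterable[Tuple[str, float, str]]) -> int:
        """Write (audio_filepath, duration_seconds, text) triples as JSON lines; return how many were written."""
        count = 0
        with open(path, "w", encoding="utf-8") as f:
            for audio_filepath, duration, text in entries:
                record = {"audio_filepath": audio_filepath, "duration": duration, "text": text}
                # ensure_ascii=False keeps non-English transcripts human-readable in the file.
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
                count += 1
        return count

For example, ``write_manifest("test_manifest.json", [("/data/utt1.wav", 3.2, "hello world")])`` writes a one-line manifest whose path can then be passed as ``inputs=``.
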
For more detailed information on training at scale, model parallelism, and SLURM-based training, see :doc:`training and scaling <training_and_scaling>`.

Collection Structure
--------------------

The speechlm2 collection is organized into the following key components:

- **Models**: Contains implementations of DuplexS2SModel, DuplexS2SSpeechDecoderModel, and SALM
- **Modules**: Contains audio perception and speech generation modules
- **Data**: Includes dataset classes and data loading utilities

SpeechLM2 Documentation
-----------------------

For more information, see additional sections in the SpeechLM2 docs:

.. toctree::
   :maxdepth: 1

   models
   datasets
   configs
   training_and_scaling