calculating commited on
Commit
824afbf
·
1 Parent(s): 35d94d0

committing...

Browse files
app.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch as T
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torchaudio
6
+ import matplotlib.pyplot as plt
7
+ from utils import load_ckpt, print_colored
8
+ from tokenizer import make_tokenizer
9
+ from model import get_hertz_dev_config
10
+ from typing import Tuple
11
+ import numpy as np
12
+ import os
13
+
14
+ # Global variables for model and tokenizer
15
+ global_generator = None
16
+ global_tokenizer = None
17
+ default_audio_path = "testingtesting.wav" # Your default audio file
18
+
19
+ def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
20
+ """Initialize the model and tokenizer"""
21
+ global global_generator, global_tokenizer
22
+
23
+ if global_generator is not None and global_tokenizer is not None:
24
+ return global_generator, global_tokenizer
25
+
26
+ device = 'cuda' if T.cuda.is_available() else 'cpu'
27
+ T.cuda.set_device(0) if device == 'cuda' else None
28
+
29
+ print_colored("Initializing model and tokenizer...", "blue")
30
+ global_tokenizer = make_tokenizer(device)
31
+ model_config = get_hertz_dev_config(is_split=False, use_pure_audio_ablation=use_pure_audio_ablation)
32
+
33
+ global_generator = model_config()
34
+ global_generator = global_generator.eval().to(T.bfloat16).to(device)
35
+ print_colored("Model initialization complete!", "green")
36
+
37
+ return global_generator, global_tokenizer
38
+
39
+ def process_audio(audio_path: str, sr: int) -> T.Tensor:
40
+ """Load and preprocess audio file"""
41
+ audio_tensor, sr = torchaudio.load(audio_path)
42
+
43
+
44
+ if audio_tensor.shape[0] == 2:
45
+ audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)
46
+
47
+ if sr != 16000:
48
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
49
+ audio_tensor = resampler(audio_tensor)
50
+
51
+ max_samples = 16000 * 60 * 5 # 5 minutes
52
+ if audio_tensor.shape[1] > max_samples:
53
+ audio_tensor = audio_tensor[:, :max_samples]
54
+
55
+ return audio_tensor.unsqueeze(0)
56
+
57
+ def generate_completion(
58
+ audio_file,
59
+ prompt_len_seconds: float = 3.0,
60
+ num_completions: int = 5,
61
+ generation_seconds: float = 20.0,
62
+ token_temp: float = 0.8,
63
+ categorical_temp: float = 0.5,
64
+ gaussian_temp: float = 0.1,
65
+ progress=gr.Progress(track_tqdm=True)
66
+ ) -> list:
67
+ """Generate audio completions from the input audio"""
68
+ device = 'cuda' if T.cuda.is_available() else 'cpu'
69
+
70
+ # Use existing model and tokenizer
71
+ generator, audio_tokenizer = global_generator, global_tokenizer
72
+
73
+ progress(0, desc="Processing input audio...")
74
+ # Process input audio
75
+ prompt_audio = process_audio(audio_file, sr=16000)
76
+ prompt_len = int(prompt_len_seconds * 8)
77
+
78
+ progress(0.2, desc="Encoding prompt...")
79
+ # Encode prompt
80
+ with T.autocast(device_type='cuda', dtype=T.bfloat16):
81
+ encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))
82
+
83
+ completions = []
84
+ for i in range(num_completions):
85
+ progress((i + 1) / num_completions, desc=f"Generating completion {i+1}/{num_completions}")
86
+
87
+ # Generate completion
88
+ encoded_prompt = encoded_prompt_audio[:, :prompt_len]
89
+ with T.autocast(device_type='cuda', dtype=T.bfloat16):
90
+ completed_audio_batch = generator.completion(
91
+ encoded_prompt,
92
+ temps=(token_temp, (categorical_temp, gaussian_temp)),
93
+ use_cache=True,
94
+ gen_len=int(generation_seconds * 8)
95
+ )
96
+
97
+ decoded_completion = audio_tokenizer.data_from_latent(completed_audio_batch.bfloat16())
98
+
99
+ # Process audio for output
100
+ audio_tensor = decoded_completion.cpu().squeeze()
101
+ if audio_tensor.ndim == 1:
102
+ audio_tensor = audio_tensor.unsqueeze(0)
103
+ audio_tensor = audio_tensor.float()
104
+
105
+ if audio_tensor.abs().max() > 1:
106
+ audio_tensor = audio_tensor / audio_tensor.abs().max()
107
+
108
+ # Trim to include only the generated portion
109
+ output_audio = audio_tensor[:, max(prompt_len*2000 - 16000, 0):]
110
+ completions.append((16000, output_audio.numpy().T))
111
+
112
+ progress(1.0, desc="Generation complete!")
113
+ return completions
114
+
115
+ def create_interface():
116
+ # Initialize model at startup
117
+ init_model()
118
+
119
+ with gr.Blocks(title="Audio Completion Generator") as app:
120
+ gr.Markdown("""
121
+ # Audio Completion Generator
122
+ Upload an audio file (or use the default) and generate AI completions based on the prompt.
123
+ """)
124
+
125
+ with gr.Row():
126
+ with gr.Column():
127
+ # Load the default audio if it exists
128
+ default_value = default_audio_path if os.path.exists(default_audio_path) else None
129
+
130
+ audio_input = gr.Audio(
131
+ label="Input Audio",
132
+ type="filepath",
133
+ sources=["microphone", "upload"],
134
+ value=default_value
135
+ )
136
+
137
+ with gr.Row():
138
+ prompt_len = gr.Slider(
139
+ minimum=1,
140
+ maximum=10,
141
+ value=3,
142
+ step=0.5,
143
+ label="Prompt Length (seconds)"
144
+ )
145
+ default_num_completions = 5
146
+ num_completions = gr.Slider(
147
+ minimum=1,
148
+ maximum=10,
149
+ value=default_num_completions,
150
+ step=1,
151
+ label="Number of Completions"
152
+ )
153
+ gen_length = gr.Slider(
154
+ minimum=5,
155
+ maximum=60,
156
+ value=20,
157
+ step=5,
158
+ label="Generation Length (seconds)"
159
+ )
160
+
161
+ with gr.Row():
162
+ token_temp = gr.Slider(
163
+ minimum=0.1,
164
+ maximum=1.0,
165
+ value=0.8,
166
+ step=0.1,
167
+ label="Token Temperature"
168
+ )
169
+ cat_temp = gr.Slider(
170
+ minimum=0.1,
171
+ maximum=1.0,
172
+ value=0.5,
173
+ step=0.1,
174
+ label="Categorical Temperature"
175
+ )
176
+ gauss_temp = gr.Slider(
177
+ minimum=0.1,
178
+ maximum=1.0,
179
+ value=0.1,
180
+ step=0.1,
181
+ label="Gaussian Temperature"
182
+ )
183
+
184
+ generate_btn = gr.Button("Generate Completions")
185
+ status_text = gr.Markdown("Ready")
186
+
187
+ with gr.Column():
188
+ output_audios = []
189
+ for i in range(10): # Create 10 audio components
190
+ output_audios.append(gr.Audio(
191
+ label=f"Generated Completion {i+1}",
192
+ type="numpy",
193
+ visible=False
194
+ ))
195
+
196
+ def update_visibility(num):
197
+ return [gr.update(visible=(i < num)) for i in range(10)]
198
+
199
+ def generate_with_status(*args):
200
+ status_text.value = "Processing input audio..."
201
+ completions = generate_completion(*args)
202
+ status_text.value = "Generation complete!"
203
+
204
+ # Prepare outputs for all audio components
205
+ outputs = []
206
+ for i in range(10):
207
+ if i < len(completions):
208
+ outputs.append(completions[i])
209
+ else:
210
+ outputs.append(None)
211
+ return outputs
212
+
213
+ # Set initial visibility on load
214
+ app.load(
215
+ fn=update_visibility,
216
+ inputs=[num_completions],
217
+ outputs=output_audios
218
+ )
219
+
220
+ # Update visibility when slider changes
221
+ num_completions.change(
222
+ fn=update_visibility,
223
+ inputs=[num_completions],
224
+ outputs=output_audios
225
+ )
226
+
227
+ generate_btn.click(
228
+ fn=generate_with_status,
229
+ inputs=[
230
+ audio_input,
231
+ prompt_len,
232
+ num_completions,
233
+ gen_length,
234
+ token_temp,
235
+ cat_temp,
236
+ gauss_temp
237
+ ],
238
+ outputs=output_audios
239
+ )
240
+
241
+ return app
242
+
243
+ if __name__ == "__main__":
244
+ app = create_interface()
245
+ app.launch(share=True)
ioblocks.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from functools import partial
3
+ from contextlib import nullcontext
4
+ from typing import List, Tuple
5
+ from math import ceil
6
+
7
+ import torch as T
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.distributed as dist
11
+ from torch import Tensor, int32
12
+ from torch.amp import autocast
13
+
14
+ from einops import rearrange, pack, unpack
15
+
16
+
17
+ from utils import si_module, exists, default, maybe
18
+
19
+
20
+ @si_module
21
+ class GaussianMixtureIOLayer(nn.Module):
22
+ class Config:
23
+ latent_dim: int
24
+ dim: int
25
+ num_components: int
26
+
27
+ def __init__(self, c: Config):
28
+ super().__init__()
29
+ self.latent_dim = c.latent_dim
30
+ self.num_components = c.num_components
31
+ self.input_projection = nn.Linear(c.latent_dim, c.dim)
32
+
33
+ self.fc_loc = nn.Linear(c.dim, c.num_components * c.latent_dim)
34
+ self.fc_scale = nn.Linear(c.dim, c.num_components * c.latent_dim)
35
+ self.fc_weight = nn.Linear(c.dim, c.num_components)
36
+
37
+ def _square_plus(self, x):
38
+ return (x + T.sqrt(T.square(x) + 4)) / 2
39
+
40
+ def input(self, sampled_latents: T.Tensor) -> T.Tensor:
41
+ """Pre-sampled latents T.Tensor (B, L, Z) -> float tensor (B, L, D)"""
42
+ hidden = self.input_projection(sampled_latents)
43
+ return hidden
44
+
45
+ def output(self, h: T.Tensor) -> Tuple[T.Tensor, T.Tensor, T.Tensor]:
46
+ """float tensor (B, L, D) -> Tuple of locs, scales, and weights"""
47
+ batch_size, seq_len, _ = h.shape
48
+
49
+ locs = self.fc_loc(h).view(batch_size, seq_len, self.num_components, self.latent_dim)
50
+ scales = T.clamp(self._square_plus(self.fc_scale(h)), min=1e-6).view(batch_size, seq_len, self.num_components, self.latent_dim)
51
+ weights = self.fc_weight(h).view(batch_size, seq_len, self.num_components)
52
+
53
+ return (locs, scales, weights)
54
+
55
+ def loss(self, data, dataHat):
56
+ locs, scales, weights = dataHat
57
+ log_probs = -0.5 * T.sum(
58
+ (data.unsqueeze(-2) - locs).pow(2) / scales.pow(2) +
59
+ 2 * T.log(scales) +
60
+ T.log(T.tensor(2 * T.pi)),
61
+ dim=-1
62
+ )
63
+ log_weights = F.log_softmax(weights, dim=-1)
64
+ return -T.logsumexp(log_weights + log_probs, dim=-1)
65
+
66
+
67
+ def temp_sample(self, orig_pdist, temp):
68
+ locs, scales, weights = orig_pdist
69
+ if temp is None:
70
+ component_samples = locs + scales * T.randn_like(scales)
71
+ mixture_samples = F.gumbel_softmax(weights, hard=True)
72
+ sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
73
+ elif isinstance(temp, tuple):
74
+ assert len(temp) == 2
75
+ categorical_temp, gaussian_temp = temp
76
+ component_samples = locs + scales * gaussian_temp * T.randn_like(scales)
77
+ mixture_samples = F.gumbel_softmax(weights / categorical_temp, hard=True)
78
+ sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
79
+ else:
80
+ component_samples = locs + scales * temp * T.randn_like(scales)
81
+ mixture_samples = F.gumbel_softmax(weights / temp, hard=True)
82
+ sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
83
+ return sampled
84
+
85
+
86
+ class GPTOutput(nn.Module):
87
+ def __init__(self, dim, vocab_size):
88
+ super().__init__()
89
+ self.output = nn.Linear(dim, vocab_size, bias=False)
90
+
91
+ def forward(self, x):
92
+ return self.output(x)
93
+
94
+
95
+ # helper functions
96
+
97
+ def pack_one(t, pattern):
98
+ return pack([t], pattern)
99
+
100
+ def unpack_one(t, ps, pattern):
101
+ return unpack(t, ps, pattern)[0]
102
+
103
+ def first(l):
104
+ return l[0]
105
+
106
+ def round_up_multiple(num, mult):
107
+ return ceil(num / mult) * mult
108
+
109
+ def get_code_utilization(codes, codebook_size, get_global=False):
110
+ if get_global and dist.is_initialized():
111
+ world_size = dist.get_world_size()
112
+ else:
113
+ world_size = 1
114
+
115
+ if world_size > 1:
116
+ gathered_tokens = [T.zeros_like(codes) for _ in range(world_size)]
117
+ dist.all_gather(gathered_tokens, codes)
118
+ gathered_tokens = T.cat(gathered_tokens, dim=0)
119
+ else:
120
+ gathered_tokens = codes
121
+ unique_tokens = len(T.unique(gathered_tokens))
122
+ code_utilization = unique_tokens / min(gathered_tokens.numel(), codebook_size)
123
+ return code_utilization
124
+
125
+ # tensor helpers
126
+
127
+ def round_ste(z: Tensor) -> Tensor:
128
+ """Round with straight through gradients."""
129
+ zhat = z.round()
130
+ return z + (zhat - z).detach()
131
+
132
+ # main class
133
+ # lucidrains fsq
134
+ @si_module
135
+ class FSQ(nn.Module):
136
+ @property
137
+ def needs_float32_params(self):
138
+ return True
139
+
140
+ class Config:
141
+ levels: List[int]
142
+ dim: int | None = None
143
+ num_codebooks: int = 1
144
+ keep_num_codebooks_dim: bool | None = None
145
+ scale: float | None = None
146
+ allowed_dtypes: Tuple[str, ...] = ('float32', 'float64')
147
+ channel_first: bool = False
148
+ projection_has_bias: bool = True
149
+ return_indices: bool = True
150
+ force_quantization_f32: bool = True
151
+ use_rms: bool = False
152
+
153
+ def __init__(self, c: Config):
154
+ super().__init__()
155
+ _levels = T.tensor(c.levels, dtype=int32)
156
+ self.register_buffer("_levels", _levels, persistent = False)
157
+
158
+ _basis = T.cumprod(T.tensor([1] + c.levels[:-1]), dim=0, dtype=int32)
159
+ self.register_buffer("_basis", _basis, persistent = False)
160
+
161
+ self.scale = c.scale
162
+
163
+ codebook_dim = len(c.levels)
164
+ self.codebook_dim = codebook_dim
165
+
166
+ effective_codebook_dim = codebook_dim * c.num_codebooks
167
+ self.num_codebooks = c.num_codebooks
168
+
169
+ self.allowed_dtypes = []
170
+ for dtype_str in c.allowed_dtypes:
171
+ if hasattr(T, dtype_str):
172
+ self.allowed_dtypes.append(getattr(T, dtype_str))
173
+ else:
174
+ raise ValueError(f"Invalid dtype string: {dtype_str}")
175
+
176
+ self.effective_codebook_dim = effective_codebook_dim
177
+
178
+ keep_num_codebooks_dim = default(c.keep_num_codebooks_dim, c.num_codebooks > 1)
179
+ assert not (c.num_codebooks > 1 and not keep_num_codebooks_dim)
180
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
181
+
182
+ self.dim = default(c.dim, len(_levels) * c.num_codebooks)
183
+
184
+ self.channel_first = c.channel_first
185
+
186
+ has_projections = self.dim != effective_codebook_dim
187
+ self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = c.projection_has_bias) if has_projections else nn.Identity()
188
+ self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = c.projection_has_bias) if has_projections else nn.Identity()
189
+
190
+ self.has_projections = has_projections
191
+
192
+ self.return_indices = c.return_indices
193
+ if c.return_indices:
194
+ self.codebook_size = self._levels.prod().item()
195
+ implicit_codebook = self._indices_to_codes(T.arange(self.codebook_size))
196
+ self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
197
+
198
+ self.allowed_dtypes = c.allowed_dtypes
199
+ self.force_quantization_f32 = c.force_quantization_f32
200
+
201
+ self.latent_loss = None
202
+
203
+ def latent_metric(self, codes, get_global=False):
204
+ return {'code_util_estimate': get_code_utilization(codes, self.codebook_size, get_global)}
205
+
206
+ def repr_from_latent(self, latent):
207
+ return self.indices_to_codes(latent)
208
+
209
+ def bound(self, z, eps: float = 1e-3):
210
+ """ Bound `z`, an array of shape (..., d). """
211
+ half_l = (self._levels - 1) * (1 + eps) / 2
212
+ offset = T.where(self._levels % 2 == 0, 0.5, 0.0)
213
+ shift = (offset / half_l).atanh()
214
+ return (z + shift).tanh() * half_l - offset
215
+
216
+ def quantize(self, z):
217
+ """ Quantizes z, returns quantized zhat, same shape as z. """
218
+ quantized = round_ste(self.bound(z))
219
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
220
+ return quantized / half_width
221
+
222
+ def _scale_and_shift(self, zhat_normalized):
223
+ half_width = self._levels // 2
224
+ return (zhat_normalized * half_width) + half_width
225
+
226
+ def _scale_and_shift_inverse(self, zhat):
227
+ half_width = self._levels // 2
228
+ return (zhat - half_width) / half_width
229
+
230
+ def _indices_to_codes(self, indices):
231
+ level_indices = self.indices_to_level_indices(indices)
232
+ codes = self._scale_and_shift_inverse(level_indices)
233
+ return codes
234
+
235
+ def codes_to_indices(self, zhat):
236
+ """ Converts a `code` to an index in the codebook. """
237
+ assert zhat.shape[-1] == self.codebook_dim
238
+ zhat = self._scale_and_shift(zhat)
239
+ return (zhat * self._basis).sum(dim=-1).to(int32)
240
+
241
+ def indices_to_level_indices(self, indices):
242
+ """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
243
+ indices = rearrange(indices, '... -> ... 1')
244
+ codes_non_centered = (indices // self._basis) % self._levels
245
+ return codes_non_centered
246
+
247
+ def indices_to_codes(self, indices):
248
+ """ Inverse of `codes_to_indices`. """
249
+ assert exists(indices)
250
+
251
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
252
+
253
+ codes = self._indices_to_codes(indices)
254
+
255
+ if self.keep_num_codebooks_dim:
256
+ codes = rearrange(codes, '... c d -> ... (c d)')
257
+
258
+ codes = self.project_out(codes)
259
+
260
+ if is_img_or_video or self.channel_first:
261
+ codes = rearrange(codes, 'b ... d -> b d ...')
262
+
263
+ return codes
264
+
265
+ # @autocast(device_type='cuda', enabled = False)
266
+ def forward(self, z, return_codes=False):
267
+ """
268
+ einstein notation
269
+ b - batch
270
+ n - sequence (or flattened spatial dimensions)
271
+ d - feature dimension
272
+ c - number of codebook dim
273
+ """
274
+
275
+ is_img_or_video = z.ndim >= 4
276
+ need_move_channel_last = is_img_or_video or self.channel_first
277
+
278
+ # standardize image or video into (batch, seq, dimension)
279
+
280
+ if need_move_channel_last:
281
+ z = rearrange(z, 'b d ... -> b ... d')
282
+ z, ps = pack_one(z, 'b * d')
283
+
284
+ assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
285
+
286
+ z = self.project_in(z)
287
+
288
+ z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
289
+
290
+ # whether to force quantization step to be full precision or not
291
+
292
+ force_f32 = self.force_quantization_f32
293
+ quantization_context = partial(autocast, device_type='cuda', enabled = False) if force_f32 else nullcontext
294
+
295
+ with quantization_context():
296
+ orig_dtype = z.dtype
297
+
298
+ if force_f32 and orig_dtype not in self.allowed_dtypes:
299
+ z = z.float()
300
+
301
+ codes = self.quantize(z)
302
+
303
+ # returning indices could be optional
304
+
305
+ indices = None
306
+
307
+ if self.return_indices:
308
+ indices = self.codes_to_indices(codes)
309
+
310
+ codes = rearrange(codes, 'b n c d -> b n (c d)')
311
+
312
+ codes = codes.type(orig_dtype)
313
+
314
+ # project out
315
+ if return_codes:
316
+ return codes, indices
317
+
318
+ out = self.project_out(codes)
319
+
320
+ # reconstitute image or video dimensions
321
+
322
+ if need_move_channel_last:
323
+ out = unpack_one(out, ps, 'b * d')
324
+ out = rearrange(out, 'b ... d -> b d ...')
325
+
326
+ indices = maybe(unpack_one)(indices, ps, 'b * c')
327
+
328
+ if not self.keep_num_codebooks_dim and self.return_indices:
329
+ indices = maybe(rearrange)(indices, '... 1 -> ...')
330
+
331
+ # return quantized output and indices
332
+
333
+ return out, indices
model.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import torch as T
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from ioblocks import GaussianMixtureIOLayer, FSQ
8
+
9
+ from transformer import Stack, ShapeRotator, Block as PerfBlock, GPTOutput, CACHE_FILL_VALUE, FFNN, Norm
10
+ from tokenizer import make_tokenizer
11
+
12
+
13
+ from utils import si_module, exists, isnt, tqdm0, print0, default, print0_colored
14
+ from utils import load_ckpt
15
+
16
+
17
+ @si_module
18
+ class LatentQuantizer(nn.Module):
19
+ class Config:
20
+ compressor_config: Optional[FSQ.Config] = None
21
+
22
+ dim: Optional[int] = None
23
+ ff_dim: Optional[int] = None
24
+ input_dim: int = None
25
+
26
+ from_pretrained: Optional[Tuple[str, str]] = None
27
+
28
+ def __init__(self, c: Config):
29
+ super().__init__()
30
+
31
+ if exists(c.from_pretrained):
32
+ checkpoint = load_ckpt(*c.from_pretrained)
33
+ else:
34
+ assert exists(c.compressor_config), f'hmm {c}'
35
+
36
+ self.compressor = c.compressor_config()
37
+ self.ffnn = FFNN(c.dim, c.ff_dim)
38
+ self.input = nn.Linear(c.input_dim, c.dim) if exists(c.input_dim) else nn.Identity()
39
+
40
+ if exists(c.from_pretrained):
41
+ self.load_state_dict(checkpoint)
42
+
43
+ @T.no_grad()
44
+ def forward(self, x, return_latent=False, known_latent=None):
45
+ """
46
+ x: (B, S, D)
47
+ """
48
+ if exists(known_latent):
49
+ return self.compressor.indices_to_codes(known_latent)
50
+
51
+ x = self.input(x)
52
+ x = self.ffnn(x)
53
+ x, tokens = self.compressor(x)
54
+
55
+ if return_latent:
56
+ return x, tokens
57
+ return x
58
+
59
+
60
+ @si_module
61
+ class TransformerVAE(nn.Module):
62
+ class Config:
63
+ io_config: Optional[GaussianMixtureIOLayer.Config] = None
64
+ stack_config: Optional[Stack.Config] = None
65
+ quantizer_config: Optional[LatentQuantizer.Config] = None
66
+
67
+ plex_layer: int = None
68
+ plex_roll: int = 1
69
+ split: bool = True
70
+
71
+ from_pretrained: Optional[Tuple[str, str]] = None
72
+
73
+ def __init__(self, c: Config):
74
+ super().__init__()
75
+
76
+ if exists(c.from_pretrained):
77
+ checkpoint = load_ckpt(*c.from_pretrained)
78
+ else:
79
+ assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'hmm {c}'
80
+
81
+ self.io = c.io_config()
82
+ self.stack = c.stack_config()
83
+
84
+ self.plex_layer = c.stack_config.layers//2
85
+ self.plex_roll = c.plex_roll
86
+ self.plex_dim = c.quantizer_config.dim
87
+
88
+ assert self.plex_dim is not None and c.stack_config.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}'
89
+ self.plex_projection = nn.Linear(self.plex_dim, c.stack_config.dim)
90
+ self.out_norm = Norm(c.stack_config.dim)
91
+
92
+ if c.split:
93
+ self.io2 = c.io_config()
94
+ self.plex_projection2 = nn.Linear(self.plex_dim, c.stack_config.dim)
95
+
96
+ self.io2.fc_loc = None
97
+ self.io2.fc_scale = None
98
+ self.io2.fc_weight = None
99
+
100
+ kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
101
+ head_dim = c.stack_config.dim // c.stack_config.n_head
102
+ self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0)
103
+ cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim]
104
+ self.cache_shape = cache_shape
105
+ self.cache = [None] * self.cache_num_layers
106
+
107
+ if exists(c.from_pretrained):
108
+ result = self.load_state_dict(checkpoint, strict=False)
109
+ print0_colored(result, 'yellow')
110
+
111
+ self.quantizer = c.quantizer_config().eval()
112
+ self.quantizer.requires_grad = False
113
+
114
+ @T.no_grad()
115
+ def quantize(self, x):
116
+ if self.c.split:
117
+ x1, x2 = x.chunk(2, dim=-1)
118
+ with T.autocast(device_type='cuda', dtype=T.bfloat16):
119
+ quantized1 = self.quantizer(x1)
120
+ quantized2 = self.quantizer(x2)
121
+ return quantized1, quantized2
122
+ else:
123
+ with T.autocast(device_type='cuda', dtype=T.bfloat16):
124
+ return self.quantizer(x)
125
+
126
+ @T.no_grad()
127
+ def untokenize(self, token_data):
128
+ return self.quantizer(None, known_latent=token_data)
129
+
130
+ def init_cache(self, bsize, device, dtype, length:int=None):
131
+ cache_shape = self.cache_shape.copy()
132
+ cache_shape[1] = length or cache_shape[1]
133
+ self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
134
+
135
+ def deinit_cache(self):
136
+ self.cache = [None] * self.cache_num_layers
137
+
138
+ @T.no_grad()
139
+ def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None):
140
+ if self.c.split:
141
+ x1, x2 = data.chunk(2, dim=-1)
142
+ x = self.io.input(x1) + self.io2.input(x2)
143
+ else:
144
+ x = self.io.input(data)
145
+
146
+ cache_idx = 0
147
+ for l, layer in enumerate(self.stack.layers):
148
+ if l == self.plex_layer:
149
+ if self.c.split:
150
+ plex1, plex2 = self.quantize(data)
151
+ plex1 = T.roll(plex1, -self.c.plex_roll, dims=1)
152
+ plex2 = T.roll(plex2, -self.c.plex_roll, dims=1)
153
+ if exists(next_tokens):
154
+ plex1[:, -1:] = self.untokenize(next_tokens[0])
155
+ plex2[:, -1:] = self.untokenize(next_tokens[1])
156
+ x1 = x + self.plex_projection(plex1)
157
+ x2 = x + self.plex_projection2(plex2)
158
+ else:
159
+ plex = self.quantize(data)
160
+ plex = T.roll(plex, -self.c.plex_roll, dims=1)
161
+ if exists(next_tokens):
162
+ plex[:, -1:] = self.untokenize(next_tokens)
163
+ x = x + self.plex_projection(plex)
164
+
165
+ if l < self.plex_layer:
166
+ x = layer(x, kv=self.cache[l])
167
+ else:
168
+ if self.c.split:
169
+ x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx])
170
+ cache_idx += 1
171
+ x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx])
172
+ cache_idx += 1
173
+ else:
174
+ x = layer(x, kv=self.cache[l])
175
+
176
+ with T.autocast(device_type='cuda', dtype=T.bfloat16):
177
+ if self.c.split:
178
+ x1, x2 = self.out_norm(x1), self.out_norm(x2)
179
+ out1, out2 = self.io.output(x1), self.io.output(x2)
180
+ else:
181
+ x = self.out_norm(x)
182
+ out = self.io.output(x)
183
+
184
+ if isnt(temps):
185
+ if self.c.split:
186
+ return out1, out2
187
+ else:
188
+ return out
189
+ else:
190
+ if self.c.split:
191
+ next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :]
192
+ next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :]
193
+ next_data = T.cat([next_data1, next_data2], dim=-1)
194
+ return next_data
195
+ else:
196
+ next_data = self.io.temp_sample(out, temps)[:, -1:, :]
197
+ return next_data
198
+
199
+ @si_module
200
+ class HertzDevModel(nn.Module):
201
+ class Config:
202
+ dim: int
203
+ vocab_size: int
204
+ stack_config: Optional[Stack.Config] = None
205
+ latent_size: int = 32
206
+
207
+ split: bool = True
208
+
209
+ quantizer_config: Optional[LatentQuantizer.Config] = None
210
+ resynthesizer_config: Optional[TransformerVAE.Config] = None
211
+
212
+ from_pretrained: Optional[Tuple[str, str]] = None
213
+
214
+ def __init__(self, c: Config):
215
+ super().__init__()
216
+
217
+ if exists(c.from_pretrained):
218
+ checkpoint = load_ckpt(*c.from_pretrained)
219
+ else:
220
+ assert (exists(c.stack_config)), f'hmm {c}'
221
+
222
+ self.input = nn.Linear(c.latent_size, c.dim)
223
+ if self.c.split:
224
+ self.input2 = nn.Linear(c.latent_size, c.dim)
225
+
226
+ self.shape_rotator = ShapeRotator(c.stack_config.dim//c.stack_config.n_head, c.stack_config.seq_len, theta=c.stack_config.theta)
227
+
228
+ self.layers = nn.ModuleList([
229
+ PerfBlock(
230
+ dim=c.stack_config.dim,
231
+ layer_id=l,
232
+ n_head=c.stack_config.n_head,
233
+ kv_heads=c.stack_config.kv_heads,
234
+ ff_dim=c.stack_config.ff_dim,
235
+ eps=c.stack_config.eps,
236
+ shape_rotator=self.shape_rotator,
237
+ ) for l in range(c.stack_config.layers)
238
+ ])
239
+
240
+ self.output = GPTOutput(c.dim, c.vocab_size)
241
+ if self.c.split:
242
+ self.output2 = GPTOutput(c.dim, c.vocab_size)
243
+
244
+ self.cache = [None] * c.stack_config.layers
245
+ self.kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
246
+ self.head_dim = c.stack_config.dim // c.stack_config.n_head
247
+
248
+ if exists(c.from_pretrained):
249
+ result = self.load_state_dict(checkpoint, strict=False)
250
+ print0_colored(result, 'yellow')
251
+
252
+ self.resynthesizer = c.resynthesizer_config().eval()
253
+ self.resynthesizer.requires_grad = False
254
+
255
+ self.audio_tokenizer = make_tokenizer(device='cpu')
256
+ self.audio_cache = None
257
+ self.audio_latent_cache = None
258
+ self.use_audio_cache = False
259
+
260
+ @T.no_grad()
261
+ def tokenize(self, audio_data):
262
+ orig_audio_shape = audio_data.shape
263
+ if exists(self.audio_cache):
264
+ audio_data = T.cat([self.audio_cache, audio_data], dim=-1)
265
+ self.audio_cache = audio_data[..., -(6*16_000):]
266
+ elif self.use_audio_cache:
267
+ self.audio_cache = audio_data[..., -(6*16_000):]
268
+
269
+ if audio_data.shape[1] == 2:
270
+ enc_ch1 = self.audio_tokenizer.latent_from_data(audio_data[:, 0:1])
271
+ enc_ch2 = self.audio_tokenizer.latent_from_data(audio_data[:, 1:2])
272
+ return T.cat([enc_ch1, enc_ch2], dim=-1)[:, -(orig_audio_shape[-1]//2000):]
273
+ else:
274
+ return self.audio_tokenizer.latent_from_data(audio_data)[:, -(orig_audio_shape[-1]//2000):]
275
+
276
+ @T.no_grad()
277
+ def untokenize(self, token_data):
278
+ if exists(self.audio_latent_cache):
279
+ token_data = T.cat([self.audio_latent_cache, token_data], dim=1)
280
+ self.audio_latent_cache = token_data[:, -(6*8):]
281
+ elif self.use_audio_cache:
282
+ self.audio_latent_cache = token_data[:, -(6*8):]
283
+
284
+ if token_data.shape[-1] == 2*self.c.latent_size:
285
+ dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[:, :self.c.latent_size])
286
+ dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[:, self.c.latent_size:])
287
+ return T.cat([dec_ch1, dec_ch2], dim=1)[..., -(token_data.shape[1]*2000):]
288
+ else:
289
+ return self.audio_tokenizer.data_from_latent(token_data)[..., -(token_data.shape[1]*2000):]
290
+
291
+ def init_cache(self, bsize, device, dtype, length:int=None):
292
+ cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim]
293
+ self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
294
+ self.resynthesizer.init_cache(bsize, device, dtype, length)
295
+ self.use_audio_cache = True
296
+
297
+ def deinit_cache(self):
298
+ self.cache = [None] * len(self.layers)
299
+ self.resynthesizer.deinit_cache()
300
+ self.audio_cache = None
301
+ self.audio_latent_cache = None
302
+ self.use_audio_cache = False
303
+
304
+ @T.no_grad()
305
+ def forward(self, data):
306
+ if self.c.split:
307
+ x1, x2 = data.chunk(2, dim=-1)
308
+ x = self.input(x1) + self.input2(x2)
309
+ else:
310
+ x = self.input(data)
311
+
312
+ for l, layer in enumerate(self.layers):
313
+ x = layer(x, kv=self.cache[l])
314
+
315
+ if self.c.split:
316
+ return self.output(x), self.output2(x)
317
+ else:
318
+ return self.output(x)
319
+
320
+ @T.no_grad()
321
+ def next_audio_from_audio(self, audio_data: T.Tensor, temps=(0.8, (0.5, 0.1))):
322
+ latents_in = self.tokenize(audio_data)
323
+ next_latents = self.next_latent(latents_in, temps)
324
+ next_model_latent = next_latents[..., self.c.latent_size:]
325
+ audio_decoded = self.untokenize(next_model_latent)[..., -2000:]
326
+ return audio_decoded
327
+
328
+
329
+ @T.no_grad()
330
+ def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))):
331
+
332
+ if self.c.split:
333
+ logits1, logits2 = self.forward(model_input)
334
+ next_logits1 = logits1[:, -1]
335
+ next_logits2 = logits2[:, -1]
336
+ next_token1 = F.softmax(next_logits1 / temps[0], dim=-1).multinomial(1)
337
+ next_token2 = F.softmax(next_logits2 / temps[0], dim=-1).multinomial(1)
338
+
339
+ next_input = self.resynthesizer(model_input, next_tokens=(next_token1, next_token2), temps=temps[1])
340
+ else:
341
+ logits = self.forward(model_input)
342
+ next_logits = logits[:, -1]
343
+ next_token = F.softmax(next_logits / temps[0], dim=-1).multinomial(1)
344
+
345
+ next_input = self.resynthesizer(model_input, next_tokens=next_token, temps=temps[1])
346
+
347
+ return next_input
348
+
349
+
350
+ @T.no_grad()
351
+ def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor:
352
+ """
353
+ only accepts latent-space data.
354
+ """
355
+ if use_cache:
356
+ self.init_cache(data.shape[0], data.device, T.bfloat16)
357
+
358
+ next_input = generated = data
359
+
360
+ target_len = min(data.shape[1] + default(gen_len, data.shape[1]), self.c.stack_config.seq_len)
361
+
362
+ for _ in tqdm0(range(data.shape[1], target_len)):
363
+ model_input = next_input if use_cache else generated
364
+
365
+ next_input = self.next_latent(model_input, temps)
366
+
367
+ generated = T.cat([generated, next_input], dim=1)
368
+
369
+ if use_cache:
370
+ self.deinit_cache()
371
+ return generated
372
+
373
+
374
+
375
+ def get_hertz_dev_config(is_split=True, use_pure_audio_ablation=False):
376
+ if is_split:
377
+ checkpoints = [('inference_care_50000', 'e4ff4fe5c7e9f066410d2a5673b7a935'), ('inference_scion_54000', 'cb8bc484423922747b277ebc2933af5d')]
378
+ elif not use_pure_audio_ablation:
379
+ checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_caraway_112000', 'fcb8368ef8ebf7712f3e31e6856da580')]
380
+ else:
381
+ checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_syrup_110000', '353c48f553f1706824c11f3bb6a049e9')]
382
+
383
+ quantizer_config=LatentQuantizer.Config(
384
+ from_pretrained=('inference_volcano_3', 'd42bf674022c5f84b051d5d7794f6169'),
385
+ compressor_config=FSQ.Config(
386
+ levels=[8,8,8,8,8],
387
+ dim=2048,
388
+ num_codebooks=1,
389
+ keep_num_codebooks_dim=None,
390
+ scale=None,
391
+ allowed_dtypes=['float32', 'float64', 'bfloat16'],
392
+ channel_first=False,
393
+ projection_has_bias=True,
394
+ return_indices=True,
395
+ force_quantization_f32=True,
396
+ use_rms=False
397
+ ),
398
+ dim=2048,
399
+ ff_dim=8192,
400
+ input_dim=32
401
+ )
402
+
403
+ resynthesizer_config=TransformerVAE.Config(
404
+ io_config=GaussianMixtureIOLayer.Config(
405
+ latent_dim=32,
406
+ dim=4096,
407
+ num_components=8,
408
+ ),
409
+ stack_config=Stack.Config(
410
+ layers=8,
411
+ dim=4096,
412
+ seq_len=8192,
413
+ n_head=16,
414
+ ff_dim=11008,
415
+ kv_heads=16,
416
+ eps=1e-5,
417
+ theta=10_000
418
+ ),
419
+ quantizer_config=quantizer_config,
420
+ plex_layer=None,
421
+ plex_roll=1,
422
+ split=is_split,
423
+ from_pretrained=checkpoints[0],
424
+ )
425
+
426
+ return HertzDevModel.Config(
427
+ dim=4096,
428
+ vocab_size=32_768,
429
+ stack_config=Stack.Config(
430
+ layers=32,
431
+ dim=4096,
432
+ seq_len=2048,
433
+ n_head=32,
434
+ ff_dim=None,
435
+ kv_heads=None,
436
+ eps=1e-5,
437
+ theta=10_000,
438
+ ),
439
+ quantizer_config=quantizer_config,
440
+ resynthesizer_config=resynthesizer_config,
441
+ split=is_split,
442
+ from_pretrained=checkpoints[1],
443
+ )
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.5.1
2
+ torchaudio==2.5.1
3
+ einops==0.8.0
4
+ tqdm==4.66.6
5
+ ipython==8.29.0
6
+ numpy==1.26.3
7
+ soundfile==0.12.1
8
+ websockets==13.1
9
+ requests==2.32.3
10
+ sounddevice==0.5.1
11
+ matplotlib==3.9.2
12
+ fastapi==0.115.4
13
+ uvicorn==0.32.0
14
+ gradio==5.5.0
tokenizer.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Union, Tuple, Literal
4
+
5
+ import torch as T
6
+ import torch.nn as nn
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+
9
+ from utils import load_ckpt
10
+ from utils.interp import print_colored
11
+ from utils import si_module, get_activation
12
+
13
+
14
+
15
+ # Adapted from https://github.com/facebookresearch/AudioDec
16
+
17
+ def Conv1d1x1(in_channels, out_channels, bias=True):
18
+ return nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias)
19
+
20
+
21
+ class NonCausalConv1d(nn.Module):
22
+ """1D noncausal convolution w/ 2-sides padding."""
23
+
24
+ def __init__(
25
+ self,
26
+ in_channels,
27
+ out_channels,
28
+ kernel_size,
29
+ stride=1,
30
+ padding=-1,
31
+ dilation=1,
32
+ groups=1,
33
+ bias=True):
34
+ super().__init__()
35
+ self.in_channels = in_channels
36
+ self.out_channels = out_channels
37
+ self.kernel_size = kernel_size
38
+ if padding < 0:
39
+ padding = (kernel_size - 1) // 2 * dilation
40
+ self.dilation = dilation
41
+ self.conv = nn.Conv1d(
42
+ in_channels=in_channels,
43
+ out_channels=out_channels,
44
+ kernel_size=kernel_size,
45
+ stride=stride,
46
+ padding=padding,
47
+ dilation=dilation,
48
+ groups=groups,
49
+ bias=bias,
50
+ )
51
+
52
+ def forward(self, x):
53
+ """
54
+ Args:
55
+ x (Tensor): Float tensor variable with the shape (B, C, T).
56
+ Returns:
57
+ Tensor: Float tensor variable with the shape (B, C, T).
58
+ """
59
+ x = self.conv(x)
60
+ return x
61
+
62
+
63
+ class NonCausalConvTranspose1d(nn.Module):
64
+ """1D noncausal transpose convolution."""
65
+
66
+ def __init__(
67
+ self,
68
+ in_channels,
69
+ out_channels,
70
+ kernel_size,
71
+ stride,
72
+ padding=-1,
73
+ output_padding=-1,
74
+ groups=1,
75
+ bias=True,
76
+ ):
77
+ super().__init__()
78
+ if padding < 0:
79
+ padding = (stride+1) // 2
80
+ if output_padding < 0:
81
+ output_padding = 1 if stride % 2 else 0
82
+ self.deconv = nn.ConvTranspose1d(
83
+ in_channels=in_channels,
84
+ out_channels=out_channels,
85
+ kernel_size=kernel_size,
86
+ stride=stride,
87
+ padding=padding,
88
+ output_padding=output_padding,
89
+ groups=groups,
90
+ bias=bias,
91
+ )
92
+
93
+ def forward(self, x):
94
+ """
95
+ Args:
96
+ x (Tensor): Float tensor variable with the shape (B, C, T).
97
+ Returns:
98
+ Tensor: Float tensor variable with the shape (B, C', T').
99
+ """
100
+ x = self.deconv(x)
101
+ return x
102
+
103
+
104
+ class CausalConv1d(NonCausalConv1d):
105
+ def __init__(
106
+ self,
107
+ in_channels,
108
+ out_channels,
109
+ kernel_size,
110
+ stride=1,
111
+ dilation=1,
112
+ groups=1,
113
+ bias=True
114
+ ):
115
+ super(CausalConv1d, self).__init__(
116
+ in_channels=in_channels,
117
+ out_channels=out_channels,
118
+ kernel_size=kernel_size,
119
+ stride=stride,
120
+ padding=0,
121
+ dilation=dilation,
122
+ groups=groups,
123
+ bias=bias,
124
+ )
125
+ self.stride = stride
126
+ self.pad_length = (kernel_size - 1) * dilation
127
+ def forward(self, x):
128
+ pad = nn.ConstantPad1d((self.pad_length, 0), 0.0)
129
+ x = pad(x)
130
+ return self.conv(x)
131
+
132
+
133
+ class CausalConvTranspose1d(NonCausalConvTranspose1d):
134
+ def __init__(
135
+ self,
136
+ in_channels,
137
+ out_channels,
138
+ kernel_size,
139
+ stride,
140
+ bias=True,
141
+ pad_buffer=None,
142
+ ):
143
+ super(CausalConvTranspose1d, self).__init__(
144
+ in_channels=in_channels,
145
+ out_channels=out_channels,
146
+ kernel_size=kernel_size,
147
+ stride=stride,
148
+ padding=0,
149
+ output_padding=0,
150
+ bias=bias,
151
+ )
152
+ self.stride = stride
153
+ self.pad_length = (math.ceil(kernel_size/stride) - 1)
154
+ if pad_buffer is None:
155
+ pad_buffer = T.zeros(1, in_channels, self.pad_length)
156
+ self.register_buffer("pad_buffer", pad_buffer)
157
+
158
+ def forward(self, x):
159
+ pad = nn.ReplicationPad1d((self.pad_length, 0))
160
+ x = pad(x)
161
+ return self.deconv(x)[:, :, self.stride : -self.stride]
162
+
163
+ def inference(self, x):
164
+ x = T.cat((self.pad_buffer, x), -1)
165
+ self.pad_buffer = x[:, :, -self.pad_length:]
166
+ return self.deconv(x)[:, :, self.stride : -self.stride]
167
+
168
+ def reset_buffer(self):
169
+ self.pad_buffer.zero_()
170
+
171
+
172
+ class NonCausalResUnit(nn.Module):
173
+ def __init__(
174
+ self,
175
+ in_channels,
176
+ out_channels,
177
+ kernel_size=7,
178
+ dilation=1,
179
+ bias=False,
180
+ ):
181
+ super().__init__()
182
+ self.activation = nn.ELU()
183
+ self.conv1 = NonCausalConv1d(
184
+ in_channels=in_channels,
185
+ out_channels=out_channels,
186
+ kernel_size=kernel_size,
187
+ stride=1,
188
+ dilation=dilation,
189
+ bias=bias,
190
+ )
191
+ self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
192
+
193
+ def forward(self, x):
194
+ y = self.conv1(self.activation(x))
195
+ y = self.conv2(self.activation(y))
196
+ return x + y
197
+
198
+
199
+ class CausalResUnit(NonCausalResUnit):
200
+ def __init__(
201
+ self,
202
+ in_channels,
203
+ out_channels,
204
+ kernel_size=7,
205
+ dilation=1,
206
+ bias=False,
207
+ ):
208
+ super(CausalResUnit, self).__init__(
209
+ in_channels=in_channels,
210
+ out_channels=out_channels,
211
+ kernel_size=kernel_size,
212
+ dilation=dilation,
213
+ bias=bias,
214
+ )
215
+ self.conv1 = CausalConv1d(
216
+ in_channels=in_channels,
217
+ out_channels=out_channels,
218
+ kernel_size=kernel_size,
219
+ stride=1,
220
+ dilation=dilation,
221
+ bias=bias,
222
+ )
223
+
224
+ def inference(self, x):
225
+ y = self.conv1.inference(self.activation(x))
226
+ y = self.conv2(self.activation(y))
227
+ return x + y
228
+
229
+
230
+ class ResNetBlock(nn.Module):
231
+ def __init__(self,
232
+ in_channels,
233
+ out_channels,
234
+ stride,
235
+ kernel_size=7,
236
+ dilations=(1, 3, 9),
237
+ bias=True,
238
+ mode='encoder',
239
+ ):
240
+ super().__init__()
241
+ assert mode in ('encoder', 'decoder'), f"Mode ({mode}) is not supported!"
242
+
243
+ self.mode = mode
244
+ self.stride = stride
245
+
246
+ ConvUnit = CausalConv1d if mode == 'encoder' else CausalConvTranspose1d
247
+
248
+ res_channels = in_channels if mode == 'encoder' else out_channels
249
+
250
+ res_units = [CausalResUnit(
251
+ res_channels,
252
+ res_channels,
253
+ kernel_size=kernel_size,
254
+ dilation=dilation,
255
+ ) for dilation in dilations]
256
+
257
+ if in_channels == out_channels:
258
+ if mode == 'encoder':
259
+ self.pool = nn.AvgPool1d(kernel_size=stride, stride=stride)
260
+ if mode == 'decoder':
261
+ self.upsample = nn.Upsample(scale_factor=stride, mode='nearest')
262
+ conv_unit = nn.Conv1d(
263
+ in_channels=in_channels,
264
+ out_channels=out_channels,
265
+ kernel_size=1,
266
+ bias=bias,
267
+ ) if in_channels != out_channels else nn.Identity()
268
+ else:
269
+ conv_unit = ConvUnit(
270
+ in_channels=in_channels,
271
+ out_channels=out_channels,
272
+ kernel_size=(2 * stride),
273
+ stride=stride,
274
+ bias=bias,
275
+ )
276
+
277
+ if mode == 'encoder':
278
+ if in_channels == out_channels:
279
+ self.res_block = nn.Sequential(*res_units, self.pool, conv_unit)
280
+ else:
281
+ self.res_block = nn.Sequential(*res_units, conv_unit)
282
+ elif mode == 'decoder':
283
+ if in_channels == out_channels:
284
+ self.res_block = nn.Sequential(self.upsample, conv_unit, *res_units)
285
+ else:
286
+ self.res_block = nn.Sequential(conv_unit, *res_units)
287
+
288
+ def forward(self, x):
289
+ out = x
290
+ for unit in self.res_block:
291
+ out = unit(out)
292
+ return out
293
+
294
+ def inference(self, x):
295
+ for unit in self.res_block:
296
+ x = unit.inference(x)
297
+ return x
298
+
299
+
300
+
301
+
302
+ @si_module
303
+ class ResNetStack(nn.Module):
304
+ """
305
+ ResNet encoder or decoder stack. Channel ratios
306
+ and strides take the default order of from
307
+ data/io-layer, to the middle of the model.
308
+ """
309
+ class Config:
310
+ input_channels: int = 1
311
+ output_channels: int = 1
312
+ encode_channels: int = 32
313
+ decode_channel_multiplier: int = 1
314
+ latent_dim: int = None
315
+ kernel_size: int = 7
316
+ bias: bool = True
317
+ channel_ratios: Tuple[int, ...] = (2, 4, 8, 16)
318
+ strides: Tuple[int, ...] = (3, 4, 5, 5)
319
+ mode: Literal['encoder', 'decoder'] = 'encoder'
320
+
321
+ def __init__(self, c: Config):
322
+ super().__init__()
323
+ assert c.mode in ('encoder', 'decoder'), f"Mode ({c.mode}) is not supported!"
324
+
325
+ self.mode = c.mode
326
+
327
+ assert len(c.channel_ratios) == len(c.strides)
328
+ channel_ratios = (1,) + c.channel_ratios
329
+ strides = c.strides
330
+ self.middle_channels = c.encode_channels * channel_ratios[-1]
331
+ if c.mode == 'decoder':
332
+ channel_ratios = tuple(reversed(channel_ratios))
333
+ strides = tuple(reversed(strides))
334
+
335
+ self.multiplier = c.decode_channel_multiplier if c.mode == 'decoder' else 1
336
+ res_blocks = [ResNetBlock(
337
+ c.encode_channels * channel_ratios[s_idx] * self.multiplier,
338
+ c.encode_channels * channel_ratios[s_idx+1] * self.multiplier,
339
+ stride,
340
+ kernel_size=c.kernel_size,
341
+ bias=c.bias,
342
+ mode=c.mode,
343
+ ) for s_idx, stride in enumerate(strides)]
344
+
345
+ data_conv = CausalConv1d(
346
+ in_channels=c.input_channels if c.mode == 'encoder' else c.encode_channels * self.multiplier,
347
+ out_channels=c.encode_channels if c.mode == 'encoder' else c.output_channels,
348
+ kernel_size=c.kernel_size,
349
+ stride=1,
350
+ bias=False,
351
+ )
352
+
353
+ if c.mode == 'encoder':
354
+ self.res_stack = nn.Sequential(data_conv, *res_blocks)
355
+ elif c.mode == 'decoder':
356
+ self.res_stack = nn.Sequential(*res_blocks, data_conv)
357
+
358
+ if c.latent_dim is not None:
359
+ self.latent_proj = Conv1d1x1(self.middle_channels, c.latent_dim, bias=c.bias) if c.mode == 'encoder' else Conv1d1x1(c.latent_dim, self.middle_channels, bias=c.bias)
360
+ if self.multiplier != 1:
361
+ self.multiplier_proj = Conv1d1x1(self.middle_channels, self.middle_channels * self.multiplier, bias=c.bias)
362
+
363
+ def forward(self, x, return_feats=False):
364
+ if self.c.latent_dim is not None and self.mode == 'decoder':
365
+ x = self.latent_proj(x)
366
+ if self.multiplier != 1:
367
+ x = self.multiplier_proj(x)
368
+
369
+ feats = []
370
+ for block in self.res_stack:
371
+ x = block(x)
372
+ if return_feats:
373
+ feats.append(x)
374
+ if self.c.latent_dim is not None and self.mode == 'encoder':
375
+ x = self.latent_proj(x)
376
+ if return_feats:
377
+ feats.append(x)
378
+ if return_feats:
379
+ return feats
380
+ return x
381
+
382
+ def inference(self, x):
383
+ for block in self.res_stack:
384
+ x = block.inference(x)
385
+ return x
386
+
387
+ def reset_buffer(self):
388
+ def _reset_buffer(m):
389
+ if isinstance(m, CausalConv1d) or isinstance(m, CausalConvTranspose1d):
390
+ m.reset_buffer()
391
+ self.apply(_reset_buffer)
392
+
393
+ def reset_parameters(self):
394
+ def _reset_parameters(m):
395
+ if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
396
+ m.weight.data.normal_(0.0, 0.01)
397
+
398
+ self.apply(_reset_parameters)
399
+
400
+
401
+ def apply_weight_norm(self):
402
+ def _apply_weight_norm(m):
403
+ if isinstance(m, nn.Conv1d) or isinstance(
404
+ m, nn.ConvTranspose1d
405
+ ):
406
+ nn.utils.parametrizations.weight_norm(m)
407
+
408
+ self.apply(_apply_weight_norm)
409
+
410
+
411
+ def remove_weight_norm(self):
412
+ def _remove_weight_norm(m):
413
+ try:
414
+ print(m)
415
+ nn.utils.remove_weight_norm(m)
416
+ except ValueError: # this module didn't have weight norm
417
+ return
418
+
419
+ self.apply(_remove_weight_norm)
420
+
421
+
422
+
423
+ @si_module
424
+ class GaussianZ(nn.Module):
425
+ class Config:
426
+ dim: int
427
+ latent_dim: int
428
+ bias: bool = False
429
+ use_weight_norm: bool = False
430
+
431
+ def __init__(self, c: Config):
432
+ super().__init__()
433
+
434
+ self.proj_in = nn.Linear(c.dim, c.latent_dim * 2, bias=c.bias)
435
+ self.proj_out = nn.Linear(c.latent_dim, c.dim, bias=c.bias)
436
+
437
+ if c.use_weight_norm:
438
+ self.proj_in = weight_norm(self.proj_in)
439
+ self.proj_out = weight_norm(self.proj_out)
440
+
441
+ def reparam(self, mu, logvar):
442
+ std = T.exp(logvar / 2)
443
+ eps = T.randn_like(std)
444
+ return mu + eps * std
445
+
446
+ def kl_divergence(self, mu, logvar):
447
+ return T.mean(-0.5 * T.sum(
448
+ 1 + logvar - mu.pow(2) - logvar.exp(),
449
+ dim=(1, 2))
450
+ )
451
+
452
+ def repr_from_latent(self, latent: Union[dict, T.Tensor]):
453
+ if isinstance(latent, T.Tensor):
454
+ z = latent
455
+ else:
456
+ z = self.reparam(latent['mu'], latent['logvar'])
457
+ l = self.proj_out(z)
458
+ return l
459
+
460
+ def forward(self, x: T.Tensor) -> Tuple[T.Tensor, dict]:
461
+ mu, logvar = self.proj_in(x).chunk(2, dim=-1)
462
+ kl_div = self.kl_divergence(mu, logvar)
463
+ z = self.reparam(mu, logvar)
464
+ xhat = self.proj_out(z)
465
+ latent = {'mu': mu, 'logvar': logvar, 'z': z, 'kl_divergence': kl_div}
466
+ return xhat, latent
467
+
468
+
469
+
470
+ @si_module
471
+ class WaveCodec(nn.Module):
472
+ class Config:
473
+ resnet_config: ResNetStack.Config = None
474
+ sample_rate: int = 16_000
475
+ use_weight_norm: bool = False
476
+
477
+ compressor_config: dataclass = None
478
+
479
+ norm_stddev: float = 1.0
480
+
481
+ def __init__(self, c: Config):
482
+ super().__init__()
483
+ self.norm_stddev = c.norm_stddev
484
+ self.encoder = c.resnet_config(mode='encoder')
485
+ self.sample_rate = c.sample_rate
486
+
487
+ self.total_stride = 1
488
+ for stride in c.resnet_config.strides:
489
+ self.total_stride *= stride
490
+ self.tokens_per_second = self.sample_rate / self.total_stride
491
+
492
+ self.compressor = c.compressor_config(dim=self.encoder.middle_channels)
493
+
494
+ self.decoder = c.resnet_config(mode='decoder')
495
+
496
+ if c.use_weight_norm:
497
+ self.encoder.apply_weight_norm()
498
+ self.decoder.apply_weight_norm()
499
+ self.encoder.reset_parameters()
500
+ self.decoder.reset_parameters()
501
+
502
+ def encode(self, data):
503
+ return self.encoder(data/self.norm_stddev)
504
+
505
+ def decode(self, latent):
506
+ return self.decoder(latent.transpose(1, 2))*self.norm_stddev
507
+
508
+ @T.no_grad()
509
+ def latent_from_data(self, data, get_parameters=False):
510
+ x = self.encode(data)
511
+ l_in = x.transpose(1, 2)
512
+ l, latent = self.compressor(l_in)
513
+ return latent['z'] if not get_parameters else {
514
+ 'mu': latent['mu'],
515
+ 'logvar': latent['logvar'],
516
+ 'z': latent['z'],
517
+ }
518
+
519
+ @T.no_grad()
520
+ def data_from_latent(self, latent):
521
+ l = self.compressor.repr_from_latent(latent)
522
+ x = self.decode(l)
523
+ return x
524
+
525
+ def process(self, x):
526
+ return self.latent_from_data(x)
527
+
528
+ def unprocess(self, latent):
529
+ return self.data_from_latent(latent)
530
+
531
+ def forward(self, audio_input):
532
+ x = self.encode(audio_input)
533
+
534
+ l_in = x.transpose(1, 2)
535
+ l, latent = self.compressor(l_in)
536
+
537
+ xhat = self.decode(l)
538
+ return xhat, latent
539
+
540
+
541
+
542
+ def make_tokenizer(device='cuda'):
543
+ generator_config = WaveCodec.Config(
544
+ resnet_config=ResNetStack.Config(
545
+ input_channels=1,
546
+ output_channels=1,
547
+ encode_channels=16,
548
+ decode_channel_multiplier=4,
549
+ kernel_size=7,
550
+ bias=True,
551
+ channel_ratios=(4, 8, 16, 16, 16, 16),
552
+ strides=(2, 2, 4, 5, 5, 5),
553
+ mode=None,
554
+ ),
555
+ use_weight_norm=True,
556
+
557
+ compressor_config=GaussianZ.Config(
558
+ dim=None,
559
+ latent_dim=32,
560
+
561
+ bias=True,
562
+ use_weight_norm=True
563
+ ),
564
+
565
+ norm_stddev=0.05,
566
+ )
567
+ checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97")
568
+
569
+ tokenizer = generator_config()
570
+
571
+ load_result = tokenizer.load_state_dict(checkpoint, strict=False)
572
+ print_colored(f"Loaded tokenizer state dict: {load_result}", "grey")
573
+
574
+ tokenizer = tokenizer.eval()
575
+ # Only convert to bfloat16 if using CUDA
576
+ if device == 'cuda':
577
+ tokenizer = tokenizer.bfloat16()
578
+ tokenizer = tokenizer.to(device)
579
+ tokenizer.requires_grad_ = False
580
+ return tokenizer
581
+
transformer.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, MutableMapping
2
+ from typing import Union
3
+ import math
4
+ from contextlib import nullcontext
5
+
6
+ import torch
7
+ import torch as T
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch import Tensor
11
+ from torch.nn.attention import SDPBackend
12
+
13
+ from einops import rearrange
14
+
15
+ from utils import si_module, default, exists, load_ckpt
16
+
17
+ CACHE_FILL_VALUE = -1
18
+
19
+ def get_cache_len(cache: Optional[Tensor]) -> int:
20
+ """
21
+ cache: (batch, seq_len, 2, kv_heads, head_dim)
22
+ """
23
+ if cache is None:
24
+ return 0
25
+ nonzeros = T.any(cache.flatten(2) != CACHE_FILL_VALUE, dim=-1)
26
+ length = nonzeros.sum(dim=-1).int()
27
+ assert T.all(length == length[0])
28
+ return length[0]
29
+
30
+
31
+ def rotate_half(x):
32
+ x1, x2 = x.chunk(2, dim=-1)
33
+ return torch.cat((-x2, x1), dim=-1)
34
+
35
+
36
+ def apply_rotary_pos_emb(x, cos, sin, offset: int = 0):
37
+ assert (
38
+ cos.shape[1] >= offset + x.shape[1]
39
+ ), f"Offset and/or input sequence is too large,\
40
+ \n offset: {offset}, seq_len: {x.shape[1]}, max: {cos.shape[1]}"
41
+
42
+ cos_out = cos[:, offset : offset + x.shape[1], :, :]
43
+ sin_out = sin[:, offset : offset + x.shape[1], :, :]
44
+
45
+ return (x * cos_out) + (rotate_half(x) * sin_out)
46
+
47
+
48
+ # Adapted from https://github.com/foundation-model-stack/foundation-model-stack
49
+ class ShapeRotator:
50
+ def __init__(
51
+ self,
52
+ dim: int,
53
+ end: int,
54
+ theta: float = 10_000,
55
+ ):
56
+ super().__init__()
57
+ self.dim = dim
58
+ self.ratio = theta
59
+ self.cached_freqs: MutableMapping[int, MutableMapping[int, torch.Tensor]] = {}
60
+ self.max_seq_len_cached: MutableMapping[int, int] = {}
61
+ self.ntk_scaling = False
62
+ self.max_seq_len = end
63
+
64
+ def compute_freqs_cis(self, device, max_seq_len=None):
65
+ alpha = 1
66
+ dev_idx = device.index
67
+ max_seq_len = default(max_seq_len, self.max_seq_len)
68
+
69
+ if dev_idx not in self.cached_freqs:
70
+ self.cached_freqs[dev_idx] = {}
71
+ if dev_idx not in self.max_seq_len_cached:
72
+ self.max_seq_len_cached[dev_idx] = 0
73
+
74
+
75
+ if self.max_seq_len_cached[dev_idx] > 0:
76
+ return 1
77
+ max_seq_len = max(max_seq_len, self.max_seq_len)
78
+
79
+ if (
80
+ 1 in self.cached_freqs[dev_idx]
81
+ and max_seq_len <= self.max_seq_len_cached[dev_idx]
82
+ ):
83
+ return 1
84
+
85
+ ratio = self.ratio
86
+ dim = self.dim
87
+
88
+ freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2, device=device).float() / dim))
89
+
90
+ t = torch.arange(max_seq_len, device=device, dtype=freqs.dtype)
91
+ freqs = torch.einsum("i,j->ij", t, freqs)
92
+ emb = torch.cat((freqs, freqs), dim=-1).to(device)
93
+
94
+ cos_to_cache = emb.cos()[None, :, None, :]
95
+ sin_to_cache = emb.sin()[None, :, None, :]
96
+
97
+ self.max_seq_len_cached[dev_idx] = max_seq_len
98
+
99
+ self.cached_freqs[dev_idx][alpha] = torch.stack(
100
+ [
101
+ cos_to_cache,
102
+ sin_to_cache,
103
+ ],
104
+ dim=-1,
105
+ )
106
+
107
+ return alpha
108
+
109
+ def rotate(
110
+ self,
111
+ q: Tensor,
112
+ k: Tensor,
113
+ offset: int = 0,
114
+ ) -> Tuple[Tensor, Tensor]:
115
+ """
116
+ Args
117
+ ----
118
+ q : torch.Tensor
119
+ Embedded query tensor, expected size is B x S x H x Eh
120
+ k : torch.Tensor
121
+ Embedded query tensor, expected size is B x S x H x Eh
122
+ """
123
+ assert len(q.size()) == 4
124
+ assert len(k.size()) == 4
125
+
126
+ seq_len = self.max_seq_len
127
+ alpha = self.compute_freqs_cis(q.device, seq_len)
128
+ freqs = self.cached_freqs[q.device.index][alpha]
129
+
130
+ freqs = freqs.float() # 1 L D/2 2 2
131
+ q_out = apply_rotary_pos_emb(q, freqs[..., 0], freqs[..., 1], offset=offset).type_as(q)
132
+ k_out = apply_rotary_pos_emb(k, freqs[..., 0], freqs[..., 1], offset=offset).type_as(k)
133
+
134
+ return q_out.view_as(q), k_out.view_as(k)
135
+
136
+ class Linear(nn.Linear):
137
+ def __init__(self, *args, **kwargs):
138
+ super().__init__(*args, **kwargs, bias=False)
139
+
140
+ class Norm(nn.Module):
141
+ def __init__(self,
142
+ dim: int,
143
+ eps: float = 1e-5,) -> None:
144
+ super().__init__()
145
+ self.eps = eps
146
+ self.weight = nn.Parameter(T.ones((dim,)))
147
+
148
+ def forward(self, input: Tensor) -> Tensor:
149
+ return F.layer_norm(input, (self.weight.shape[0],), weight=self.weight, bias=None, eps=self.eps)
150
+
151
+
152
+ class FFNN(nn.Module):
153
+ def __init__(self,
154
+ dim: int,
155
+ expand_dim: int = None,):
156
+ super().__init__()
157
+ expand_dim = default(expand_dim, 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256))
158
+ self.dim = dim
159
+ self.expand_dim = expand_dim
160
+
161
+ self.gateup_proj = Linear(dim, 2*expand_dim)
162
+ self.down_proj = Linear(expand_dim, dim)
163
+
164
+ def forward(self, x):
165
+ gate, up = self.gateup_proj(x).chunk(2, dim=-1)
166
+ return self.down_proj(up * F.silu(gate))
167
+
168
+ class GQA(nn.Module):
169
+ def __init__(self,
170
+ dim: int,
171
+ n_head: int,
172
+ shape_rotator: ShapeRotator,
173
+ kv_heads: Optional[int] = None,
174
+ eps: float = 1e-5,
175
+ causal: bool = True,):
176
+ super().__init__()
177
+ self.n_heads = n_head
178
+ self.kv_heads = default(kv_heads, n_head)
179
+ self.head_dim = dim // n_head
180
+ self.causal = causal
181
+
182
+ self.proj_qkv = Linear(dim, self.head_dim*(n_head+2*self.kv_heads))
183
+
184
+ self.norm_q = Norm(self.head_dim*n_head, eps=eps)
185
+ self.norm_k = Norm(self.head_dim*self.kv_heads, eps=eps)
186
+
187
+ self.attn_out = Linear(dim, dim)
188
+
189
+ self.shape_rotator = shape_rotator
190
+
191
+ def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
192
+ k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
193
+ v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
194
+ with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION) if k.device.type == 'cuda' else nullcontext():
195
+ x = F.scaled_dot_product_attention(
196
+ q.transpose(1, 2),
197
+ k.transpose(1, 2),
198
+ v.transpose(1, 2),
199
+ is_causal=False if (q.size(1) != k.size(1)) else self.causal,
200
+ )
201
+ x = x.transpose(1, 2).contiguous()
202
+ return x
203
+
204
+ def _attend(self, q: Tensor, k: Tensor, v: Tensor, kv_cache: Optional[Tensor] = None,):
205
+ cache_len = get_cache_len(kv_cache)
206
+ q, k = self.shape_rotator.rotate(q, k, offset=cache_len)
207
+ if exists(kv_cache):
208
+ k = T.cat([kv_cache[:, :cache_len, 0], k], dim=1)
209
+ v = T.cat([kv_cache[:, :cache_len, 1], v], dim=1)
210
+ kv_cache[:, :k.size(1), 0] = k
211
+ kv_cache[:, :v.size(1), 1] = v
212
+ x = self._sdpa(q, k, v)
213
+ return self.attn_out(rearrange(x, 'b s h d -> b s (h d)'))
214
+
215
+ def _project(self, x):
216
+ full_q, full_k, full_v = self.proj_qkv(x).chunk(3, dim=-1)
217
+ normed_full_q = self.norm_q(full_q).to(full_q.dtype)
218
+ normed_full_k = self.norm_k(full_k).to(full_k.dtype)
219
+
220
+ q = rearrange(normed_full_q, 'b s (h d) -> b s h d', h=self.n_heads)
221
+ k = rearrange(normed_full_k, 'b s (h d) -> b s h d', h=self.kv_heads)
222
+ v = rearrange(full_v, 'b s (h d) -> b s h d', h=self.kv_heads)
223
+ return q, k, v
224
+
225
+ def forward(self,
226
+ x: Tensor,
227
+ kv: Optional[Tensor] = None,):
228
+ """
229
+ x: (B, S, D)
230
+ kv: (B, S, H, D)
231
+ """
232
+ q, k, v = self._project(x)
233
+ return self._attend(q, k, v, kv_cache=kv)
234
+
235
+
236
+ class PreNormAttn(nn.Module):
237
+ def __init__(self,
238
+ dim: int,
239
+ n_head: int,
240
+ shape_rotator: ShapeRotator,
241
+ kv_heads: Optional[int] = None,
242
+ eps: float = 1e-5,
243
+ causal: bool = True,):
244
+ super().__init__()
245
+ self.attn_norm = Norm(dim, eps=eps)
246
+ self.attn = GQA(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
247
+
248
+ def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
249
+ """
250
+ x: (B, S, D)
251
+ kv: (B, S, H, D)
252
+ """
253
+ return x + self.attn(self.attn_norm(x), kv)
254
+
255
+ class PreNormFFNN(nn.Module):
256
+ def __init__(self,
257
+ dim: int,
258
+ ff_dim: int,
259
+ eps: float = 1e-5,):
260
+ super().__init__()
261
+ self.ffnn_norm = Norm(dim, eps=eps)
262
+ self.ffnn = FFNN(dim, ff_dim)
263
+
264
+ def forward(self, x: Tensor) -> Tensor:
265
+ return x + self.ffnn(self.ffnn_norm(x))
266
+
267
+ class Block(nn.Module):
268
+ def __init__(self,
269
+ dim: int,
270
+ layer_id: int = 0,
271
+ n_head: int = 16,
272
+ kv_heads: Optional[int] = None,
273
+ ff_dim: Optional[int] = None,
274
+ eps: float = 1e-5,
275
+ causal: bool = True,
276
+ shape_rotator: ShapeRotator = None):
277
+ super().__init__()
278
+ self.attn = PreNormAttn(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
279
+ self.ffnn = PreNormFFNN(dim, ff_dim, eps=eps)
280
+ self.dim = dim
281
+ self.layer_id = layer_id
282
+ self.head_dim = dim // n_head
283
+ self.expand_dim = self.ffnn.ffnn.expand_dim
284
+
285
+ self.reset_parameters()
286
+
287
+ def reset_parameters(self):
288
+ std = 1.0 / math.sqrt(self.dim)
289
+ nn.init.trunc_normal_(self.ffnn.ffnn.gateup_proj.weight, std=std, a=-3 * std, b=3 * std)
290
+ nn.init.trunc_normal_(self.attn.attn.proj_qkv.weight, std=std, a=-3 * std, b=3 * std)
291
+ nn.init.trunc_normal_(self.attn.attn.attn_out.weight, std=std, a=-3 * std, b=3 * std)
292
+
293
+ xstd = 1.0 / math.sqrt(self.expand_dim)
294
+ nn.init.trunc_normal_(self.ffnn.ffnn.down_proj.weight, std=xstd, a=-3 * xstd, b=3 * xstd)
295
+
296
+ def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
297
+ """
298
+ x: (B, S, D)
299
+ kv: (B, S, H, D)
300
+ """
301
+ h = self.attn(x, kv)
302
+ out = self.ffnn(h)
303
+ return out
304
+
305
+
306
+
307
+ class GPTOutput(nn.Module):
308
+ def __init__(self, dim, vocab_size):
309
+ super().__init__()
310
+ self.dim = dim
311
+ self.norm = Norm(dim)
312
+ self.output = Linear(dim, vocab_size)
313
+
314
+ self.reset_parameters()
315
+
316
+ def reset_parameters(self):
317
+ std = 1.0 / math.sqrt(self.dim**2)
318
+ nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)
319
+
320
+ def forward(self, x):
321
+ return self.output(self.norm(x))
322
+
323
+ @si_module
324
+ class Stack(nn.Module):
325
+ class Config:
326
+ layers: int
327
+ dim: int
328
+ seq_len: int
329
+ n_head: int = 32
330
+ ff_dim: int = None
331
+ kv_heads: int = None
332
+ eps: float = 1e-5
333
+ theta: Union[int, float] = 10_000
334
+ causal: bool = True
335
+
336
+ from_pretrained: Optional[Tuple[str, int]] = None
337
+
338
+ def __init__(self, c: Config):
339
+ super().__init__()
340
+
341
+ from_pretrained = c.from_pretrained
342
+ if exists(from_pretrained):
343
+ checkpoint = load_ckpt(c.from_pretrained)
344
+
345
+ self.shape_rotator = ShapeRotator(c.dim//c.n_head, c.seq_len, theta=c.theta)
346
+
347
+ self.layers = nn.ModuleList([
348
+ Block(
349
+ dim=c.dim,
350
+ layer_id=l,
351
+ n_head=c.n_head,
352
+ kv_heads=c.kv_heads,
353
+ ff_dim=c.ff_dim,
354
+ eps=c.eps,
355
+ causal=c.causal,
356
+ shape_rotator=self.shape_rotator,
357
+ ) for l in range(c.layers)
358
+ ])
359
+
360
+ kv_heads = c.kv_heads or c.n_head
361
+ head_dim = c.dim // c.n_head
362
+ cache_shape = [c.layers, c.seq_len, 2, kv_heads, head_dim]
363
+ self.cache_shape = cache_shape
364
+ self.cache = [None] * c.layers
365
+
366
+ if exists(from_pretrained):
367
+ self.load_state_dict(checkpoint)
368
+
369
+ def init_cache(self, bsize, device, dtype, length:int=None):
370
+ if self.cache_shape is None:
371
+ return
372
+ cache_shape = self.cache_shape.copy()
373
+ cache_shape[1] = length or cache_shape[1]
374
+ self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
375
+
376
+ def deinit_cache(self):
377
+ self.cache = [None] * len(self.cache)
378
+
379
+ def forward(self, x: Tensor) -> Tensor:
380
+ for l, layer in enumerate(self.layers):
381
+ x = layer(x, kv=self.cache[l])
382
+ return x
utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .blocks import *
2
+ from .dist import *
3
+ from .interp import *
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (211 Bytes). View file
 
utils/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (3.73 kB). View file
 
utils/__pycache__/dist.cpython-310.pyc ADDED
Binary file (3.65 kB). View file
 
utils/__pycache__/interp.cpython-310.pyc ADDED
Binary file (3.82 kB). View file
 
utils/blocks.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import TypeVar, Generic, Type, Optional
3
+ from functools import wraps
4
+ import time
5
+ import random
6
+
7
+ import torch as T
8
+ import torch.nn as nn
9
+
10
+ # @TODO: remove si_module from codebase
11
+ # we use this in our research codebase to make modules from callable configs
12
+ si_module_TpV = TypeVar('si_module_TpV')
13
+ def si_module(cls: Type[si_module_TpV]) -> Type[si_module_TpV]:
14
+ if not hasattr(cls, 'Config') or not isinstance(cls.Config, type):
15
+ class Config:
16
+ pass
17
+ cls.Config = Config
18
+
19
+ cls.Config = dataclass(cls.Config)
20
+
21
+ class ConfigWrapper(cls.Config, Generic[si_module_TpV]):
22
+ def __call__(self, *args, **kwargs) -> si_module_TpV:
23
+ if len(kwargs) > 0:
24
+ config_dict = {field.name: getattr(self, field.name) for field in self.__dataclass_fields__.values()}
25
+ config_dict.update(kwargs)
26
+ new_config = type(self)(**config_dict)
27
+ return cls(new_config)
28
+ else:
29
+ return cls(self, *args)
30
+
31
+ ConfigWrapper.__module__ = cls.__module__
32
+ ConfigWrapper.__name__ = f"{cls.__name__}Config"
33
+ ConfigWrapper.__qualname__ = f"{cls.__qualname__}.Config"
34
+
35
+ cls.Config = ConfigWrapper
36
+
37
+ original_init = cls.__init__
38
+ def new_init(self, *args, **kwargs):
39
+ self.c = next((arg for arg in args if isinstance(arg, cls.Config)), None) or next((arg for arg in kwargs.values() if isinstance(arg, cls.Config)), None)
40
+ original_init(self, *args, **kwargs)
41
+ self.register_buffer('_device_tracker', T.Tensor(), persistent=False)
42
+
43
+ cls.__init__ = new_init
44
+
45
+ @property
46
+ def device(self):
47
+ return self._device_tracker.device
48
+
49
+ @property
50
+ def dtype(self):
51
+ return self._device_tracker.dtype
52
+
53
+ cls.device = device
54
+ cls.dtype = dtype
55
+
56
+ return cls
57
+
58
+
59
+ def get_activation(nonlinear_activation, nonlinear_activation_params={}):
60
+ if hasattr(nn, nonlinear_activation):
61
+ return getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
62
+ else:
63
+ raise NotImplementedError(f"Activation {nonlinear_activation} not found in torch.nn")
64
+
65
+
66
+ def exists(v):
67
+ return v is not None
68
+
69
+ def isnt(v):
70
+ return not exists(v)
71
+
72
+ def truthyexists(v):
73
+ return exists(v) and v is not False
74
+
75
+ def truthyattr(obj, attr):
76
+ return hasattr(obj, attr) and truthyexists(getattr(obj, attr))
77
+
78
+ defaultT = TypeVar('defaultT')
79
+
80
+ def default(*args: Optional[defaultT]) -> Optional[defaultT]:
81
+ for arg in args:
82
+ if exists(arg):
83
+ return arg
84
+ return None
85
+
86
+ def maybe(fn):
87
+ @wraps(fn)
88
+ def inner(x, *args, **kwargs):
89
+ if not exists(x):
90
+ return x
91
+ return fn(x, *args, **kwargs)
92
+ return inner
utils/dist.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch as T
3
+ import re
4
+ from tqdm import tqdm
5
+ from datetime import timedelta
6
+
7
+ import requests
8
+ import hashlib
9
+
10
+ from io import BytesIO
11
+
12
+ def rank0():
13
+ rank = os.environ.get('RANK')
14
+ if rank is None or rank == '0':
15
+ return True
16
+ else:
17
+ return False
18
+
19
+ def local0():
20
+ local_rank = os.environ.get('LOCAL_RANK')
21
+ if local_rank is None or local_rank == '0':
22
+ return True
23
+ else:
24
+ return False
25
+ class tqdm0(tqdm):
26
+ def __init__(self, *args, **kwargs):
27
+ total = kwargs.get('total', None)
28
+ if total is None and len(args) > 0:
29
+ try:
30
+ total = len(args[0])
31
+ except TypeError:
32
+ pass
33
+ if total is not None:
34
+ kwargs['miniters'] = max(1, total // 20)
35
+ super().__init__(*args, **kwargs, disable=not rank0(), bar_format='{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]')
36
+
37
+ def print0(*args, **kwargs):
38
+ if rank0():
39
+ print(*args, **kwargs)
40
+
41
+ _PRINTED_IDS = set()
42
+
43
+ def printonce(*args, id=None, **kwargs):
44
+ if id is None:
45
+ id = ' '.join(map(str, args))
46
+
47
+ if id not in _PRINTED_IDS:
48
+ print(*args, **kwargs)
49
+ _PRINTED_IDS.add(id)
50
+
51
+ def print0once(*args, **kwargs):
52
+ if rank0():
53
+ printonce(*args, **kwargs)
54
+
55
+ def init_dist():
56
+ if T.distributed.is_initialized():
57
+ print0('Distributed already initialized')
58
+ rank = T.distributed.get_rank()
59
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
60
+ world_size = T.distributed.get_world_size()
61
+ else:
62
+ try:
63
+ rank = int(os.environ['RANK'])
64
+ local_rank = int(os.environ['LOCAL_RANK'])
65
+ world_size = int(os.environ['WORLD_SIZE'])
66
+ device = f'cuda:{local_rank}'
67
+ T.cuda.set_device(device)
68
+ T.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=30), rank=rank, world_size=world_size, device_id=T.device(device))
69
+ print(f'Rank {rank} of {world_size}.')
70
+ except Exception as e:
71
+ print0once(f'Not initializing distributed env: {e}')
72
+ rank = 0
73
+ local_rank = 0
74
+ world_size = 1
75
+ return rank, local_rank, world_size
76
+
77
+ def load_ckpt(load_from_location, expected_hash=None):
78
+ if local0():
79
+ os.makedirs('ckpt', exist_ok=True)
80
+ url = f"https://ckpt.si.inc/hertz-dev/{load_from_location}.pt"
81
+ save_path = f"ckpt/{load_from_location}.pt"
82
+ if not os.path.exists(save_path):
83
+ response = requests.get(url, stream=True)
84
+ total_size = int(response.headers.get('content-length', 0))
85
+ with open(save_path, 'wb') as f, tqdm(total=total_size, desc=f'Downloading {load_from_location}.pt', unit='GB', unit_scale=1/(1024*1024*1024)) as pbar:
86
+ for chunk in response.iter_content(chunk_size=8192):
87
+ f.write(chunk)
88
+ pbar.update(len(chunk))
89
+ if expected_hash is not None:
90
+ with open(save_path, 'rb') as f:
91
+ file_hash = hashlib.md5(f.read()).hexdigest()
92
+ if file_hash != expected_hash:
93
+ print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. Deleting checkpoint and trying again.')
94
+ os.remove(save_path)
95
+ return load_ckpt(load_from_location, expected_hash)
96
+ if T.distributed.is_initialized():
97
+ T.distributed.barrier() # so that ranks don't try to load checkpoint before it's finished downloading
98
+ loaded = T.load(f"ckpt/{load_from_location}.pt", weights_only=False, map_location='cpu')
99
+ return loaded
utils/interp.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as T
2
+ import os
3
+
4
+ def rank0():
5
+ rank = os.environ.get('RANK')
6
+ if rank is None or rank == '0':
7
+ return True
8
+ else:
9
+ return False
10
+
11
+ def print_colored(message, color='reset', bold=False, **kwargs):
12
+ color_dict = {
13
+ 'bold': '\033[1m',
14
+ 'green': '\033[92m',
15
+ 'yellow': '\033[93m',
16
+ 'red': '\033[91m',
17
+ 'blue': '\033[94m',
18
+ 'grey': '\033[90m',
19
+ 'white': '\033[97m',
20
+ 'reset': '\033[0m'
21
+ }
22
+
23
+ color_code = color_dict.get(color.lower(), color_dict['reset'])
24
+ prefix = color_dict['bold'] if bold else ''
25
+ print(f"{prefix}{color_code}{message}{color_dict['reset']}", **kwargs)
26
+
27
+ def print0_colored(*args, **kwargs):
28
+ if rank0():
29
+ print_colored(*args, **kwargs)
30
+
31
+ def param_count(module):
32
+ def count_parameters(model):
33
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
34
+
35
+ total_params = count_parameters(module)
36
+ output = [f'Total model parameters: {total_params:,}', '---------------------------']
37
+
38
+ for name, child in module.named_children():
39
+ params = count_parameters(child)
40
+ output.append(f'{name} parameters: {params:,}')
41
+
42
+ return '\n'.join(output)
43
+
44
+ def model_size_estimation(module):
45
+ def estimate_size(model):
46
+ param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
47
+ buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
48
+ return param_size + buffer_size
49
+
50
+ total_size = estimate_size(module)
51
+ output = [f'Total model size: {total_size / 1024**2:.2f} MB', '---------------------------']
52
+
53
+ for name, child in module.named_children():
54
+ child_size = estimate_size(child)
55
+ output.append(f'{name} size: {child_size / 1024**2:.2f} MB')
56
+
57
+ return '\n'.join(output)
58
+
59
+ def layer_param_distribution(module):
60
+ def count_parameters(model):
61
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
62
+
63
+ def get_layer_types(model):
64
+ layer_types = {}
65
+ for name, module in model.named_modules():
66
+ layer_type = module.__class__.__name__
67
+ params = sum(p.numel() for p in module.parameters(recurse=False) if p.requires_grad)
68
+ if params > 0:
69
+ if layer_type not in layer_types:
70
+ layer_types[layer_type] = 0
71
+ layer_types[layer_type] += params
72
+ return layer_types
73
+
74
+ total_params = count_parameters(module)
75
+ layer_types = get_layer_types(module)
76
+
77
+ output = [f'Total trainable parameters: {total_params:,}', '---------------------------']
78
+
79
+ for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
80
+ percentage = (count / total_params) * 100
81
+ output.append(f'{layer_type}: {count:,} ({percentage:.2f}%)')
82
+
83
+ return '\n'.join(output)
84
+