tsp-stefano commited on
Commit
8e6392f
Β·
verified Β·
1 Parent(s): f07085a

Initial upload: INT4 RTN quantized Voxtral Mini 4B for Jetson

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tekken.json filter=lfs diff=lfs merge=lfs -text
consolidated.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd25f9d675042c37b0a9b051a5333ef001c129d302461995bdc3e7b321c3b2b6
3
+ size 4382321392
kernels/fused_ops.cu ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Fused CUDA kernels for Voxtral decoder on Jetson Orin Nano (SM87).
2
+ //
3
+ // Three kernels that collapse ~500 PyTorch kernel launches per token into ~80:
4
+ // 1. fused_rmsnorm: 5-6 kernels β†’ 1 per call (52 calls/token)
5
+ // 2. fused_rope_hf: ~14 kernels β†’ 2 per call (26 calls/token)
6
+ // 3. fused_silu_mul: 2 kernels β†’ 1 per call (26 calls/token)
7
+ //
8
+ // Build: JIT via torch.utils.cpp_extension.load() with -arch=sm_87
9
+ // All kernels: fp16 I/O, fp32 internal accumulation where needed.
10
+
11
+ #include <torch/extension.h>
12
+ #include <cuda_fp16.h>
13
+ #include <cuda_runtime.h>
14
+
15
+ #define BLOCK_SIZE 256
16
+ #define WARP_SIZE 32
17
+
18
+ // ============================================================================
19
+ // Warp-level reduction
20
+ // ============================================================================
21
+
22
+ __device__ __forceinline__ float warp_reduce_sum(float val) {
23
+ #pragma unroll
24
+ for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1)
25
+ val += __shfl_xor_sync(0xffffffff, val, offset);
26
+ return val;
27
+ }
28
+
29
+ // ============================================================================
30
+ // Kernel 1: Fused RMSNorm
31
+ // ============================================================================
32
+ //
33
+ // Replaces: x.float() β†’ pow(2) β†’ mean β†’ add(eps) β†’ rsqrt β†’ mul(x) β†’ half β†’ mul(weight)
34
+ // One block per row. 256 threads, each handles dim/256 elements.
35
+ // Warp shuffle + shared memory for cross-warp reduction.
36
+
37
+ __global__ void fused_rmsnorm_kernel(
38
+ const half* __restrict__ x,
39
+ const half* __restrict__ w,
40
+ half* __restrict__ out,
41
+ int dim,
42
+ float eps
43
+ ) {
44
+ const int row = blockIdx.x;
45
+ const half* x_row = x + (int64_t)row * dim;
46
+ half* out_row = out + (int64_t)row * dim;
47
+
48
+ // Phase 1: partial sum of squares (fp32 accumulation)
49
+ float partial = 0.0f;
50
+ for (int i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
51
+ float v = __half2float(x_row[i]);
52
+ partial += v * v;
53
+ }
54
+
55
+ // Phase 2: warp reduction
56
+ partial = warp_reduce_sum(partial);
57
+
58
+ // Phase 3: cross-warp reduction via shared memory
59
+ __shared__ float warp_sums[BLOCK_SIZE / WARP_SIZE];
60
+ const int lane = threadIdx.x % WARP_SIZE;
61
+ const int warp_id = threadIdx.x / WARP_SIZE;
62
+
63
+ if (lane == 0)
64
+ warp_sums[warp_id] = partial;
65
+ __syncthreads();
66
+
67
+ float total = 0.0f;
68
+ if (warp_id == 0) {
69
+ total = (lane < (BLOCK_SIZE / WARP_SIZE)) ? warp_sums[lane] : 0.0f;
70
+ total = warp_reduce_sum(total);
71
+ }
72
+
73
+ // Phase 4: broadcast norm factor
74
+ __shared__ float s_norm;
75
+ if (threadIdx.x == 0)
76
+ s_norm = rsqrtf(total / (float)dim + eps);
77
+ __syncthreads();
78
+
79
+ // Phase 5: normalize and write
80
+ const float nf = s_norm;
81
+ for (int i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
82
+ float v = __half2float(x_row[i]);
83
+ float wt = __half2float(w[i]);
84
+ out_row[i] = __float2half(v * nf * wt);
85
+ }
86
+ }
87
+
88
+ torch::Tensor fused_rmsnorm(torch::Tensor x, torch::Tensor weight, float eps) {
89
+ TORCH_CHECK(x.is_cuda() && x.scalar_type() == torch::kHalf);
90
+ TORCH_CHECK(weight.is_cuda() && weight.scalar_type() == torch::kHalf);
91
+
92
+ auto x_c = x.contiguous();
93
+ auto out = torch::empty_like(x_c);
94
+ const int dim = x_c.size(-1);
95
+ const int rows = x_c.numel() / dim;
96
+
97
+ fused_rmsnorm_kernel<<<rows, BLOCK_SIZE>>>(
98
+ reinterpret_cast<const half*>(x_c.data_ptr<at::Half>()),
99
+ reinterpret_cast<const half*>(weight.data_ptr<at::Half>()),
100
+ reinterpret_cast<half*>(out.data_ptr<at::Half>()),
101
+ dim, eps
102
+ );
103
+
104
+ return out;
105
+ }
106
+
107
+ // ============================================================================
108
+ // Kernel 2: Fused Half-Rotation RoPE (in-place on contiguous copy)
109
+ // ============================================================================
110
+ //
111
+ // Half-rotation format: q[..., :hd/2] = real, q[..., hd/2:] = imaginary
112
+ // Replaces: float() cast + 4 slices + 4 muls + 2 subs + 2 cats + half() cast
113
+ //
114
+ // Works for both decode (T=1) and prefill (T>1).
115
+ // Data layout after .contiguous(): q[h, t, d] = q_ptr[h*T*hd + t*hd + d]
116
+
117
+ __global__ void fused_rope_kernel(
118
+ half* __restrict__ data, // [n_heads, T, head_dim] contiguous
119
+ const half* __restrict__ cos_ptr, // [T, half_dim]
120
+ const half* __restrict__ sin_ptr, // [T, half_dim]
121
+ int n_heads,
122
+ int seq_len,
123
+ int head_dim
124
+ ) {
125
+ const int half_dim = head_dim >> 1;
126
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
127
+ const int total = n_heads * seq_len * half_dim;
128
+ if (idx >= total) return;
129
+
130
+ const int i = idx % half_dim;
131
+ const int t = (idx / half_dim) % seq_len;
132
+ const int h = idx / (half_dim * seq_len);
133
+
134
+ const int base = h * seq_len * head_dim + t * head_dim;
135
+ const int cs = t * half_dim + i;
136
+
137
+ const float c = __half2float(cos_ptr[cs]);
138
+ const float s = __half2float(sin_ptr[cs]);
139
+ const float real = __half2float(data[base + i]);
140
+ const float imag = __half2float(data[base + half_dim + i]);
141
+
142
+ data[base + i] = __float2half(real * c - imag * s);
143
+ data[base + half_dim + i] = __float2half(real * s + imag * c);
144
+ }
145
+
146
+ std::vector<torch::Tensor> fused_rope_hf(
147
+ torch::Tensor q, // [B, n_q_heads, T, head_dim]
148
+ torch::Tensor k, // [B, n_kv_heads, T, head_dim]
149
+ torch::Tensor cos, // [T, head_dim/2]
150
+ torch::Tensor sin // [T, head_dim/2]
151
+ ) {
152
+ TORCH_CHECK(q.is_cuda() && q.scalar_type() == torch::kHalf);
153
+
154
+ // Make contiguous copies (needed because transpose makes q/k non-contiguous)
155
+ auto q_c = q.contiguous();
156
+ auto k_c = k.contiguous();
157
+ auto cos_c = cos.contiguous();
158
+ auto sin_c = sin.contiguous();
159
+
160
+ const int B = q_c.size(0);
161
+ const int n_q = q_c.size(1);
162
+ const int T = q_c.size(2);
163
+ const int hd = q_c.size(3);
164
+ const int n_kv = k_c.size(1);
165
+ const int half_dim = hd >> 1;
166
+
167
+ // Process each batch item (B=1 in practice, so no overhead)
168
+ for (int b = 0; b < B; b++) {
169
+ half* q_ptr = reinterpret_cast<half*>(q_c.data_ptr<at::Half>()) + b * n_q * T * hd;
170
+ half* k_ptr = reinterpret_cast<half*>(k_c.data_ptr<at::Half>()) + b * n_kv * T * hd;
171
+
172
+ const int total_q = n_q * T * half_dim;
173
+ const int total_kv = n_kv * T * half_dim;
174
+
175
+ fused_rope_kernel<<<(total_q + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
176
+ q_ptr,
177
+ reinterpret_cast<const half*>(cos_c.data_ptr<at::Half>()),
178
+ reinterpret_cast<const half*>(sin_c.data_ptr<at::Half>()),
179
+ n_q, T, hd
180
+ );
181
+
182
+ fused_rope_kernel<<<(total_kv + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
183
+ k_ptr,
184
+ reinterpret_cast<const half*>(cos_c.data_ptr<at::Half>()),
185
+ reinterpret_cast<const half*>(sin_c.data_ptr<at::Half>()),
186
+ n_kv, T, hd
187
+ );
188
+ }
189
+
190
+ return {q_c, k_c};
191
+ }
192
+
193
+ // ============================================================================
194
+ // Kernel 3: Fused SiLU * Multiply
195
+ // ============================================================================
196
+ //
197
+ // Replaces: F.silu(gate) * up (2 separate kernels)
198
+ // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
199
+
200
+ __global__ void fused_silu_mul_kernel(
201
+ const half* __restrict__ gate,
202
+ const half* __restrict__ up,
203
+ half* __restrict__ out,
204
+ int n
205
+ ) {
206
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
207
+ if (idx < n) {
208
+ const float g = __half2float(gate[idx]);
209
+ const float u = __half2float(up[idx]);
210
+ out[idx] = __float2half((g / (1.0f + expf(-g))) * u);
211
+ }
212
+ }
213
+
214
+ torch::Tensor fused_silu_mul(torch::Tensor gate, torch::Tensor up) {
215
+ TORCH_CHECK(gate.is_cuda() && gate.scalar_type() == torch::kHalf);
216
+ TORCH_CHECK(up.is_cuda() && up.scalar_type() == torch::kHalf);
217
+
218
+ auto gate_c = gate.contiguous();
219
+ auto up_c = up.contiguous();
220
+ auto out = torch::empty_like(gate_c);
221
+ const int n = gate_c.numel();
222
+
223
+ fused_silu_mul_kernel<<<(n + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
224
+ reinterpret_cast<const half*>(gate_c.data_ptr<at::Half>()),
225
+ reinterpret_cast<const half*>(up_c.data_ptr<at::Half>()),
226
+ reinterpret_cast<half*>(out.data_ptr<at::Half>()),
227
+ n
228
+ );
229
+
230
+ return out;
231
+ }
232
+
233
+ // ============================================================================
234
+ // Module
235
+ // ============================================================================
236
+
237
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
238
+ m.def("fused_rmsnorm", &fused_rmsnorm, "Fused RMSNorm (fp16 I/O, fp32 accumulation)");
239
+ m.def("fused_rope_hf", &fused_rope_hf, "Fused half-rotation RoPE (fp16, in-place on contiguous copy)");
240
+ m.def("fused_silu_mul", &fused_silu_mul, "Fused SiLU * mul (fp16)");
241
+ }
params.json ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dim": 3072,
3
+ "n_layers": 26,
4
+ "head_dim": 128,
5
+ "hidden_dim": 9216,
6
+ "n_heads": 32,
7
+ "n_kv_heads": 8,
8
+ "use_biases": false,
9
+ "causal": true,
10
+ "rope_theta": 1000000.0,
11
+ "norm_eps": 1e-05,
12
+ "vocab_size": 131072,
13
+ "model_parallel": 1,
14
+ "tied_embeddings": true,
15
+ "sliding_window": 8192,
16
+ "model_max_length": 131072,
17
+ "multimodal": {
18
+ "whisper_model_args": {
19
+ "encoder_args": {
20
+ "audio_encoding_args": {
21
+ "sampling_rate": 16000,
22
+ "frame_rate": 12.5,
23
+ "num_mel_bins": 128,
24
+ "hop_length": 160,
25
+ "window_size": 400,
26
+ "chunk_length_s": null,
27
+ "global_log_mel_max": 1.5,
28
+ "transcription_format": "streaming"
29
+ },
30
+ "dim": 1280,
31
+ "n_layers": 32,
32
+ "head_dim": 64,
33
+ "hidden_dim": 5120,
34
+ "n_heads": 32,
35
+ "vocab_size": 131072,
36
+ "n_kv_heads": 32,
37
+ "use_biases": true,
38
+ "use_cache": false,
39
+ "rope_theta": 1000000.0,
40
+ "causal": true,
41
+ "norm_eps": 1e-05,
42
+ "pos_embed": "rope",
43
+ "max_source_positions": null,
44
+ "ffn_type": "swiglu",
45
+ "norm_type": "rms_norm",
46
+ "sliding_window": 750,
47
+ "ragged_attention": "750"
48
+ },
49
+ "downsample_args": {
50
+ "downsample_factor": 4
51
+ }
52
+ }
53
+ },
54
+ "ada_rms_norm_t_cond": true,
55
+ "ada_rms_norm_t_cond_dim": 32,
56
+ "quantization_config": {
57
+ "quant_method": "gptq",
58
+ "bits": 4,
59
+ "group_size": 128,
60
+ "desc_act": false,
61
+ "sym": true,
62
+ "checkpoint_format": "gptq",
63
+ "pack_dtype": "int32"
64
+ }
65
+ }
scripts/jetson_serve_sdpa.py ADDED
@@ -0,0 +1,1277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Voxtral Mini 4B Realtime β€” Jetson Orin Nano 8GB inference server.
3
+
4
+ Loads INT4-packed GPTQ weights from Mistral native format and serves
5
+ transcription via WebSocket at ws://localhost:8000/v1/realtime.
6
+
7
+ Key architecture detail: at each decoder position, the input embedding is
8
+ adapter_out[pos] + tok_embed(token_id)
9
+ where token_id is BOS/STREAMING_PAD during the prompt, then previously
10
+ generated text tokens during generation. Generation runs for exactly
11
+ T_adapter total positions.
12
+
13
+ Memory budget: ~6 GB total (model 4.08 GB + runtime 1.5 GB + KV cache 0.2 GB).
14
+
15
+ Optimizations over baseline:
16
+ - Marlin fused INT4 dequant+matmul (24x faster generation)
17
+ - F.scaled_dot_product_attention (fused attention kernel)
18
+ - Pre-allocated KV cache (eliminates torch.cat per token per layer)
19
+ - Optional torch.compile (fuses pointwise ops between matmuls)
20
+
21
+ Usage (inside PyTorch Jetson container):
22
+ pip install safetensors websockets soundfile numpy librosa
23
+ python3 jetson_serve.py --test /workspace/voxtral-quant/test_audio_0.wav
24
+ python3 jetson_serve.py # starts WebSocket server on port 8000
25
+ """
26
+
27
+ import argparse
28
+ import asyncio
29
+ import base64
30
+ import json
31
+ import math
32
+ import os
33
+ import sys
34
+ import time
35
+ import uuid
36
+ import warnings
37
+ from typing import List, Optional, Tuple
38
+
39
+ # Suppress harmless warnings
40
+ warnings.filterwarnings('ignore', message='.*divide by zero.*')
41
+ warnings.filterwarnings('ignore', message='.*FutureWarning.*tokenizer.*')
42
+
43
+ # Set CUDA allocator config before importing torch (critical for Jetson)
44
+ os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
45
+
46
+ import numpy as np
47
+ import torch
48
+ import torch.nn as nn
49
+ import torch.nn.functional as F
50
+ from safetensors import safe_open
51
+
52
+ # Try to import Marlin fused INT4 kernel (50x faster than on-the-fly dequantization)
53
+ try:
54
+ import marlin as _marlin
55
+ HAS_MARLIN = True
56
+ except ImportError:
57
+ HAS_MARLIN = False
58
+
59
+ # Try to JIT-compile fused CUDA kernels (collapses ~500 kernel launches/token to ~80)
60
+ HAS_FUSED = False
61
+ try:
62
+ from torch.utils.cpp_extension import load as _load_ext
63
+ _cu_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
64
+ '..', 'kernels', 'fused_ops.cu')
65
+ if not os.path.exists(_cu_path):
66
+ # Fallback for container mount path
67
+ _cu_path = '/workspace/voxtral-quant/kernels/fused_ops.cu'
68
+ if os.path.exists(_cu_path):
69
+ voxtral_kernels = _load_ext(
70
+ name='voxtral_kernels',
71
+ sources=[_cu_path],
72
+ extra_cuda_cflags=['-O3', '--use_fast_math', '-arch=sm_87'],
73
+ verbose=True
74
+ )
75
+ HAS_FUSED = True
76
+ print(f"Fused CUDA kernels loaded from {_cu_path}")
77
+ else:
78
+ print(f"Fused kernels .cu not found, using PyTorch fallback")
79
+ except Exception as e:
80
+ print(f"Fused kernel JIT failed ({e}), using PyTorch fallback")
81
+
82
+ # ─── Constants ───────────────────────────────────────────────────────────────
83
+
84
+ BITS = 4
85
+ GROUP_SIZE = 128
86
+ PACK_FACTOR = 32 // BITS # 8 int4 values per int32
87
+ BIAS = 1 << (BITS - 1) # 8 (uint4b8 encoding)
88
+
89
+ TOKEN_BOS = 1
90
+ TOKEN_EOS = 2
91
+ TOKEN_STREAMING_PAD = 32
92
+
93
+ SAMPLE_RATE = 16000
94
+ HOP_LENGTH = 160
95
+ N_MELS = 128
96
+ WINDOW_SIZE = 400
97
+ GLOBAL_LOG_MEL_MAX = 1.5 # Fixed constant from params.json (NOT per-sample)
98
+ RAW_AUDIO_PER_TOK = 1280 # Raw audio samples per adapter token
99
+ N_LEFT_PAD_TOKENS = 32 # Left silence padding (aligns with SPAD prompt)
100
+ N_RIGHT_PAD_BASE = 17 # Right silence padding (delay+1 + OFFLINE_BUFFER)
101
+ DOWNSAMPLE_FACTOR = 4
102
+
103
+
104
+ # ─── Marlin Fused INT4 Linear ────────────────────────────────────────────────
105
+
106
+ class MarlinLinear(nn.Module):
107
+ """Linear layer using Marlin fused INT4 dequant+matmul CUDA kernel.
108
+
109
+ Repacks GPTQ INT4 weights into Marlin's optimized format at construction time.
110
+ Forward pass is a single fused kernel call β€” ~50x faster than on-the-fly dequant.
111
+ Memory footprint is identical to GPTQ INT4 (no extra memory needed).
112
+ """
113
+
114
+ def __init__(self, qweight, scales, qzeros, unpermute=None):
115
+ super().__init__()
116
+ in_features = qweight.shape[0] * PACK_FACTOR
117
+ out_features = qweight.shape[1]
118
+ n_groups = scales.shape[0]
119
+ self.in_features = in_features
120
+ self.out_features = out_features
121
+
122
+ # Dequantize GPTQ β†’ fp16, then repack into Marlin format
123
+ shifts = torch.arange(0, 32, BITS, device=qweight.device, dtype=torch.int32)
124
+ unpacked = (qweight.unsqueeze(0) >> shifts.view(-1, 1, 1)) & 0xF
125
+ unpacked = unpacked.permute(1, 0, 2).reshape(in_features, out_features)
126
+ unpacked = unpacked.T.reshape(out_features, n_groups, GROUP_SIZE)
127
+ s = scales.T.float().unsqueeze(-1)
128
+ w_fp16 = ((unpacked.float() - BIAS) * s).reshape(out_features, in_features).half()
129
+ del unpacked, s
130
+
131
+ if unpermute is not None:
132
+ n_heads, hidden_size = unpermute
133
+ head_dim = w_fp16.shape[0] // n_heads
134
+ w_fp16 = (w_fp16.view(n_heads, 2, head_dim // 2, hidden_size)
135
+ .transpose(1, 2)
136
+ .reshape(out_features, in_features))
137
+
138
+ # Create temporary nn.Linear for Marlin's pack()
139
+ linear = nn.Linear(in_features, out_features, bias=False,
140
+ dtype=torch.half, device=qweight.device)
141
+ linear.weight.data = w_fp16
142
+
143
+ # Create Marlin layer and pack (handles permutation + bit packing)
144
+ ml = _marlin.Layer(in_features, out_features, groupsize=GROUP_SIZE)
145
+ ml.pack(linear, scales.T)
146
+ del linear, w_fp16
147
+
148
+ # Store Marlin buffers
149
+ self.register_buffer('B', ml.B.to(qweight.device))
150
+ self.register_buffer('s', ml.s.to(qweight.device))
151
+ self.register_buffer('workspace',
152
+ torch.zeros(out_features // 128 * 16,
153
+ dtype=torch.int, device=qweight.device),
154
+ persistent=False)
155
+
156
+ def forward(self, x):
157
+ out_shape = x.shape[:-1] + (self.out_features,)
158
+ C = torch.empty(out_shape, dtype=x.dtype, device=x.device)
159
+ _marlin.mul(x.view(-1, x.shape[-1]), self.B, C.view(-1, self.out_features),
160
+ self.s, self.workspace)
161
+ return C
162
+
163
+
164
+ # ─── GPTQ INT4 Dequantization (fallback when Marlin unavailable) ────────────
165
+
166
+ class DequantLinear(nn.Module):
167
+ """Linear layer with INT4 GPTQ packed weights.
168
+
169
+ Supports two modes:
170
+ - On-the-fly dequantization (default): dequantizes each forward call
171
+ - Cached mode: stores pre-dequantized fp16 weight for fast matmul
172
+ """
173
+
174
+ _shifts = None # class-level cached shifts tensor
175
+
176
+ def __init__(self, qweight, scales, qzeros, unpermute=None):
177
+ super().__init__()
178
+ self.register_buffer('qweight', qweight)
179
+ self.register_buffer('scales', scales)
180
+ self.register_buffer('qzeros', qzeros)
181
+ self.in_features = qweight.shape[0] * PACK_FACTOR
182
+ self.out_features = qweight.shape[1]
183
+ self.unpermute = unpermute
184
+ self._cached_w = None # pre-dequantized fp16 weight [out, in]
185
+
186
+ def cache_weight(self, free_int4=True):
187
+ """Pre-dequantize and cache the fp16 weight.
188
+ If free_int4=True, frees INT4 buffers (saves memory, not reversible).
189
+ """
190
+ self._cached_w = self._dequantize()
191
+ if free_int4:
192
+ self.qweight = None
193
+ self.scales = None
194
+ self.qzeros = None
195
+
196
+ def uncache_weight(self):
197
+ """Free the cached weight (e.g., before re-loading INT4 weights)."""
198
+ self._cached_w = None
199
+
200
+ @property
201
+ def cached_bytes(self):
202
+ """Memory used by cached weight in bytes."""
203
+ if self._cached_w is not None:
204
+ return self._cached_w.nelement() * self._cached_w.element_size()
205
+ return 0
206
+
207
+ def _dequantize(self):
208
+ """Dequantize INT4 packed weights to fp16 [out, in]."""
209
+ qw = self.qweight
210
+ in_packed, out = qw.shape
211
+ n_groups = self.scales.shape[0]
212
+
213
+ # Cached shifts tensor (shared across all instances)
214
+ if DequantLinear._shifts is None or DequantLinear._shifts.device != qw.device:
215
+ DequantLinear._shifts = torch.arange(0, 32, BITS, device=qw.device, dtype=torch.int32)
216
+ shifts = DequantLinear._shifts
217
+
218
+ # Vectorized unpack: [8, in/8, out]
219
+ unpacked = (qw.unsqueeze(0) >> shifts.view(-1, 1, 1)) & 0xF
220
+ # Interleave to [in, out] then transpose+group to [out, groups, GROUP_SIZE]
221
+ unpacked = unpacked.permute(1, 0, 2).reshape(self.in_features, out)
222
+ unpacked = unpacked.T.reshape(out, n_groups, GROUP_SIZE)
223
+ # Scale: (val - 8) * scale
224
+ s = self.scales.T.float().unsqueeze(-1)
225
+ w = ((unpacked.float() - BIAS) * s).reshape(out, self.in_features).half()
226
+ del unpacked, s
227
+
228
+ if self.unpermute is not None:
229
+ n_heads, hidden_size = self.unpermute
230
+ head_dim = w.shape[0] // n_heads
231
+ w = (w.view(n_heads, 2, head_dim // 2, hidden_size)
232
+ .transpose(1, 2)
233
+ .reshape(out, self.in_features))
234
+ return w
235
+
236
+ def forward(self, x):
237
+ if self._cached_w is not None:
238
+ return F.linear(x, self._cached_w)
239
+ w = self._dequantize()
240
+ result = F.linear(x, w)
241
+ del w
242
+ return result
243
+
244
+
245
+ # ─── Building Blocks ─────────────────────────────────────────────────────────
246
+
247
+ class RMSNorm(nn.Module):
248
+ def __init__(self, dim, eps=1e-5):
249
+ super().__init__()
250
+ self.eps = eps
251
+ self.weight = nn.Parameter(torch.ones(dim))
252
+
253
+ def forward(self, x):
254
+ return (x.float() * (x.float().pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()).type_as(x) * self.weight
255
+
256
+
257
+ class RMSNormFP16(nn.Module):
258
+ """RMSNorm that stays in fp16 β€” avoids 2x float32 conversion overhead.
259
+ Safe for decoder hidden dim=3072 where values stay in fp16 range.
260
+ Uses fused CUDA kernel when available (single kernel vs 5-6 PyTorch kernels)."""
261
+ def __init__(self, dim, eps=1e-5):
262
+ super().__init__()
263
+ self.eps = eps
264
+ self.weight = nn.Parameter(torch.ones(dim))
265
+
266
+ def forward(self, x):
267
+ if HAS_FUSED:
268
+ return voxtral_kernels.fused_rmsnorm(x.contiguous(), self.weight, self.eps)
269
+ rms = x.pow(2).mean(-1, keepdim=True).add_(self.eps).rsqrt_()
270
+ return x * rms * self.weight
271
+
272
+
273
+ def apply_rotary_emb(q, k, cos, sin):
274
+ """Interleaved RoPE for Mistral-format Q/K: q,k [B, heads, T, dim], cos,sin [T, dim/2].
275
+ Pairs: (d0,d1), (d2,d3), etc."""
276
+ cos = cos.unsqueeze(0).unsqueeze(0)
277
+ sin = sin.unsqueeze(0).unsqueeze(0)
278
+ q_r, q_i = q.float().reshape(*q.shape[:-1], -1, 2).unbind(-1)
279
+ k_r, k_i = k.float().reshape(*k.shape[:-1], -1, 2).unbind(-1)
280
+ return (
281
+ torch.stack([q_r*cos - q_i*sin, q_r*sin + q_i*cos], -1).flatten(-2).type_as(q),
282
+ torch.stack([k_r*cos - k_i*sin, k_r*sin + k_i*cos], -1).flatten(-2).type_as(k),
283
+ )
284
+
285
+
286
+ def apply_rotary_emb_hf(q, k, cos, sin):
287
+ """Half-rotation RoPE for HF-format Q/K in fp16 (avoids float32 conversion).
288
+ q,k: [B, heads, T, dim]. cos,sin: [T, dim/2].
289
+ Uses fused CUDA kernel when available (~14 PyTorch kernels β†’ 2)."""
290
+ if HAS_FUSED:
291
+ return voxtral_kernels.fused_rope_hf(q, k, cos, sin)
292
+ cos = cos.unsqueeze(0).unsqueeze(0)
293
+ sin = sin.unsqueeze(0).unsqueeze(0)
294
+ half = q.shape[-1] // 2
295
+ q1, q2 = q[..., :half], q[..., half:]
296
+ k1, k2 = k[..., :half], k[..., half:]
297
+ return (
298
+ torch.cat([q1 * cos - q2 * sin, q1 * sin + q2 * cos], -1),
299
+ torch.cat([k1 * cos - k2 * sin, k1 * sin + k2 * cos], -1),
300
+ )
301
+
302
+
303
+ def make_freqs(dim, maxlen, theta=1e6, device='cpu', dtype=torch.float16):
304
+ f = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
305
+ t = torch.arange(maxlen, device=device)
306
+ a = torch.outer(t, f)
307
+ return a.cos().to(dtype), a.sin().to(dtype)
308
+
309
+
310
+ # ─── KV Cache ────────────────────────────────────────────────────────────────
311
+
312
+ class KVCache:
313
+ """Pre-allocated KV cache that eliminates per-token torch.cat allocations.
314
+
315
+ Instead of creating new tensors and concatenating on every token (3900+
316
+ allocations for a typical transcription), we pre-allocate fixed buffers
317
+ and write into them with an advancing position index.
318
+
319
+ Memory cost: 2 * n_layers * n_kv_heads * max_seq * head_dim * 2 bytes
320
+ For 26 layers, 8 heads, 200 seq, 128 dim = ~21 MB (negligible).
321
+ """
322
+
323
+ def __init__(self, n_layers, max_seq_len, n_kv_heads, head_dim, device, dtype):
324
+ self.max_seq_len = max_seq_len
325
+ self.pos = 0
326
+ # [n_layers, 1, n_kv_heads, max_seq_len, head_dim]
327
+ self.k = torch.zeros(n_layers, 1, n_kv_heads, max_seq_len, head_dim,
328
+ device=device, dtype=dtype)
329
+ self.v = torch.zeros(n_layers, 1, n_kv_heads, max_seq_len, head_dim,
330
+ device=device, dtype=dtype)
331
+
332
+ def update(self, layer_idx, k_new, v_new):
333
+ """Write new K/V into cache and return full valid slice.
334
+
335
+ Args:
336
+ layer_idx: which decoder layer
337
+ k_new, v_new: [1, n_kv_heads, T_new, head_dim] (T_new=prompt_len or 1)
338
+
339
+ Returns:
340
+ k_full, v_full: [1, n_kv_heads, pos+T_new, head_dim]
341
+ """
342
+ t_new = k_new.shape[2]
343
+ end = self.pos + t_new
344
+ self.k[layer_idx, :, :, self.pos:end, :] = k_new
345
+ self.v[layer_idx, :, :, self.pos:end, :] = v_new
346
+ return self.k[layer_idx, :, :, :end, :], self.v[layer_idx, :, :, :end, :]
347
+
348
+ def advance(self, n):
349
+ """Advance position after all layers have processed n new tokens."""
350
+ self.pos += n
351
+
352
+ def reset(self):
353
+ self.pos = 0
354
+
355
+
356
+ # ─── Audio Encoder ───────────────────────────────────────────────────────────
357
+
358
+ class EncAttn(nn.Module):
359
+ def __init__(self, h=1280, nh=32, hd=64):
360
+ super().__init__()
361
+ self.nh, self.hd = nh, hd
362
+ self.ad = nh * hd
363
+ self.scale = hd ** -0.5
364
+ self.wq = nn.Linear(h, self.ad, bias=True)
365
+ self.wk = nn.Linear(h, self.ad, bias=False)
366
+ self.wv = nn.Linear(h, self.ad, bias=True)
367
+ self.wo = nn.Linear(self.ad, h, bias=True)
368
+
369
+ def forward(self, x, cos, sin, mask):
370
+ B, T, _ = x.shape
371
+ q = self.wq(x).view(B, T, self.nh, self.hd).transpose(1, 2)
372
+ k = self.wk(x).view(B, T, self.nh, self.hd).transpose(1, 2)
373
+ v = self.wv(x).view(B, T, self.nh, self.hd).transpose(1, 2)
374
+ q, k = apply_rotary_emb(q, k, cos, sin)
375
+ # Encoder uses causal attention with explicit mask
376
+ a = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=self.scale)
377
+ return self.wo(a.transpose(1, 2).reshape(B, T, self.ad))
378
+
379
+
380
+ class EncLayer(nn.Module):
381
+ def __init__(self, h=1280, ff=5120):
382
+ super().__init__()
383
+ self.attn = EncAttn(h)
384
+ self.an = RMSNorm(h)
385
+ self.fn = RMSNorm(h)
386
+ self.w1 = nn.Linear(h, ff, bias=False)
387
+ self.w2 = nn.Linear(ff, h, bias=True)
388
+ self.w3 = nn.Linear(h, ff, bias=False)
389
+
390
+ def forward(self, x, cos, sin, mask):
391
+ x = x + self.attn(self.an(x), cos, sin, mask)
392
+ h = self.fn(x)
393
+ return x + self.w2(F.silu(self.w1(h)) * self.w3(h))
394
+
395
+
396
+ class Encoder(nn.Module):
397
+ def __init__(self, nl=32, h=1280, ff=5120, hd=64):
398
+ super().__init__()
399
+ self.conv1 = nn.Conv1d(N_MELS, h, 3, padding=1)
400
+ self.conv2 = nn.Conv1d(h, h, 3, stride=2, padding=1)
401
+ self.layers = nn.ModuleList([EncLayer(h, ff) for _ in range(nl)])
402
+ self.norm = RMSNorm(h)
403
+ self.hd = hd
404
+
405
+ def forward(self, mel):
406
+ x = F.gelu(self.conv1(mel))
407
+ x = F.gelu(self.conv2(x)).transpose(1, 2)
408
+ T = x.shape[1]
409
+ cos, sin = make_freqs(self.hd, T, device=x.device, dtype=x.dtype)
410
+ mask = torch.triu(torch.full((T, T), float('-inf'), device=x.device, dtype=x.dtype), 1).unsqueeze(0).unsqueeze(0)
411
+ for layer in self.layers:
412
+ x = layer(x, cos, sin, mask)
413
+ return self.norm(x)
414
+
415
+
416
+ # ─── Projector ───────────────────────────────────────────────────────────────
417
+
418
+ class Projector(nn.Module):
419
+ def __init__(self, inp=5120, out=3072):
420
+ super().__init__()
421
+ self.l1 = nn.Linear(inp, out, bias=False)
422
+ self.l2 = nn.Linear(out, out, bias=False)
423
+
424
+ def forward(self, x):
425
+ return self.l2(F.gelu(self.l1(x)))
426
+
427
+
428
+ # ─── LM Decoder Layer ───────────────────────────────────────────────────────
429
+
430
+ class DecAttn(nn.Module):
431
+ def __init__(self, h=3072, nh=32, nkv=8, hd=128):
432
+ super().__init__()
433
+ self.nh, self.nkv, self.hd = nh, nkv, hd
434
+ self.qd, self.kvd = nh*hd, nkv*hd
435
+ self.scale = hd ** -0.5
436
+ self.g = nh // nkv
437
+ self.q_proj = self.k_proj = self.v_proj = self.o_proj = None # set by loader
438
+
439
+ def forward(self, x, cos, sin, cache=None, layer_idx=None, is_causal=False):
440
+ B, T, _ = x.shape
441
+ q = self.q_proj(x).view(B, T, self.nh, self.hd).transpose(1, 2)
442
+ k = self.k_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
443
+ v = self.v_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
444
+ q, k = apply_rotary_emb_hf(q, k, cos, sin)
445
+
446
+ # Update KV cache if available
447
+ if cache is not None:
448
+ k, v = cache.update(layer_idx, k, v)
449
+
450
+ # GQA: expand KV heads to match query heads
451
+ if self.g > 1:
452
+ k = k.repeat_interleave(self.g, 1)
453
+ v = v.repeat_interleave(self.g, 1)
454
+
455
+ # Fused attention via SDPA
456
+ # - Prefill (T > 1): use is_causal=True for triangular mask
457
+ # - Decode (T == 1): no mask needed, single query attends to all past
458
+ out = F.scaled_dot_product_attention(
459
+ q, k, v, scale=self.scale, is_causal=is_causal)
460
+
461
+ return self.o_proj(out.transpose(1, 2).reshape(B, T, self.qd))
462
+
463
+
464
+ class DecLayer(nn.Module):
465
+ def __init__(self, h=3072, ad=32):
466
+ super().__init__()
467
+ self.an = RMSNormFP16(h)
468
+ self.attn = DecAttn()
469
+ self.fn = RMSNormFP16(h)
470
+ self.gate_proj = self.up_proj = self.down_proj = None # set by loader
471
+ self.ada0 = nn.Linear(h, ad, bias=False)
472
+ self.ada2 = nn.Linear(ad, h, bias=False)
473
+ self._ada_scale = None # pre-computed: 1 + ada2(gelu(ada0(t_cond)))
474
+
475
+ def precompute_ada(self, t_cond):
476
+ """Pre-compute the ada_rms_norm modulation scale from t_cond.
477
+ Since t_cond is constant (delay-based), this eliminates 2 matmuls +
478
+ gelu + add per layer per token (~0.45ms each, 11.7ms total for 26 layers)."""
479
+ with torch.no_grad():
480
+ self._ada_scale = (1.0 + self.ada2(F.gelu(self.ada0(t_cond)))).unsqueeze(0)
481
+
482
+ def forward(self, x, cos, sin, cache=None, layer_idx=None,
483
+ is_causal=False, t_cond=None):
484
+ h = self.attn(self.an(x), cos, sin, cache, layer_idx, is_causal)
485
+ x = x + h
486
+ h = self.fn(x)
487
+ if self._ada_scale is not None:
488
+ h = h * self._ada_scale
489
+ elif t_cond is not None:
490
+ h = h * (1.0 + self.ada2(F.gelu(self.ada0(t_cond))).unsqueeze(0))
491
+ gate_out = self.gate_proj(h)
492
+ up_out = self.up_proj(h)
493
+ if HAS_FUSED:
494
+ x = x + self.down_proj(voxtral_kernels.fused_silu_mul(gate_out, up_out))
495
+ else:
496
+ x = x + self.down_proj(F.silu(gate_out) * up_out)
497
+ return x
498
+
499
+
500
+ # ─── Full Model ──────────────────────────────────────────────────────────────
501
+
502
+ class VoxtralModel:
503
+ def __init__(self, model_path, device='cuda', dtype=torch.float16, compile=False):
504
+ self.device = device
505
+ self.dtype = dtype
506
+ self.model_path = model_path
507
+
508
+ with open(os.path.join(model_path, 'params.json')) as f:
509
+ self.p = json.load(f)
510
+
511
+ ea = self.p['multimodal']['whisper_model_args']['encoder_args']
512
+ da = self.p['multimodal']['whisper_model_args']['downsample_args']
513
+ self.ds_factor = da['downsample_factor']
514
+ self.enc_h = ea['dim']
515
+ self.n_layers = self.p['n_layers']
516
+ self.h = self.p['dim']
517
+ self.n_kv_heads = self.p.get('n_kv_heads', 8)
518
+ self.head_dim = self.p.get('head_dim', 128)
519
+ self.delay_ms = self.p['multimodal']['whisper_model_args'].get(
520
+ 'encoder_args', {}).get('audio_encoding_args', {}).get(
521
+ 'transcription_delay_ms', 480)
522
+
523
+ # Frame rate from config
524
+ self.frame_rate = ea.get('audio_encoding_args', {}).get('frame_rate', 12.5)
525
+ # delay_tokens = delay_ms / ms_per_frame
526
+ ms_per_frame = 1000.0 / self.frame_rate
527
+ self.delay_tokens = int(self.delay_ms / ms_per_frame)
528
+ self.n_left_pad = 32 # streaming_n_left_pad_tokens from audio config
529
+
530
+ self._build()
531
+ self._load()
532
+ if compile:
533
+ self._try_compile()
534
+
535
+ self.cos, self.sin = make_freqs(self.p['head_dim'], 8192, device=device, dtype=dtype)
536
+
537
+ # Time conditioning: sinusoidal embedding of delay token count
538
+ self.t_cond = self._compute_time_embedding(
539
+ float(self.delay_tokens), self.h
540
+ ).to(device=device, dtype=dtype)
541
+
542
+ # Pre-compute ada_rms_norm modulation for all decoder layers
543
+ # (t_cond is constant, so ada output is constant per layer)
544
+ for layer in self.layers:
545
+ layer.precompute_ada(self.t_cond)
546
+
547
+ def _build(self):
548
+ """Build model skeleton on meta device (zero memory allocation)."""
549
+ import gc
550
+ ea = self.p['multimodal']['whisper_model_args']['encoder_args']
551
+ with torch.device('meta'):
552
+ self.encoder = Encoder(ea['n_layers'], ea['dim'], ea['hidden_dim'], ea['head_dim'])
553
+ self.projector = Projector(ea['hidden_dim'], self.h)
554
+ self.layers = nn.ModuleList([
555
+ DecLayer(self.h, self.p.get('ada_rms_norm_t_cond_dim', 32))
556
+ for _ in range(self.n_layers)
557
+ ])
558
+ self.embed = nn.Embedding(self.p['vocab_size'], self.h)
559
+ self.norm = RMSNorm(self.h)
560
+ gc.collect()
561
+
562
+ def _try_compile(self):
563
+ """Optionally compile decoder layers with torch.compile.
564
+
565
+ Fuses pointwise ops (RMSNorm, SiLU, ada_rms_norm, residuals) between
566
+ Marlin matmuls, reducing kernel launch overhead at batch-1 decode.
567
+ Falls back gracefully if torch.compile isn't available on this platform.
568
+ """
569
+ try:
570
+ compiled = 0
571
+ for i in range(len(self.layers)):
572
+ self.layers[i] = torch.compile(self.layers[i], mode="default")
573
+ compiled += 1
574
+ print(f" torch.compile: {compiled} decoder layers compiled")
575
+ except Exception as e:
576
+ print(f" torch.compile not available ({e}), using eager mode")
577
+
578
+ def _pack_lm_head(self):
579
+ """Quantize the embedding weight to INT4 and pack for Marlin LM head.
580
+
581
+ The LM head (output projection) reads the full 768 MB fp16 embedding
582
+ matrix on every token β€” 21ms at Jetson bandwidth. By quantizing to INT4
583
+ and using Marlin, this drops to ~4ms (5.8x faster).
584
+
585
+ Must be called before decoder layers are loaded (maximum free GPU memory).
586
+ """
587
+ import gc
588
+ embed_w = self.embed.weight.data # [vocab_size, dim] fp16
589
+ out_f, in_f = embed_w.shape
590
+
591
+ if in_f % 128 != 0 or out_f % 256 != 0:
592
+ print(f" LM head: dims {out_f}x{in_f} not Marlin-compatible, keeping fp16")
593
+ self._lm_head_marlin = False
594
+ return
595
+
596
+ t0 = time.time()
597
+ # Compute per-group scales (group_size=128 along input dim)
598
+ n_groups = in_f // GROUP_SIZE
599
+ w_grouped = embed_w.float().reshape(out_f, n_groups, GROUP_SIZE)
600
+ scales = (w_grouped.abs().amax(dim=-1) / 7.0).clamp(min=1e-6).half() # [out_f, n_groups]
601
+ del w_grouped
602
+
603
+ # Pack with Marlin (re-uses embed_w in-place for the linear)
604
+ linear_tmp = nn.Linear(in_f, out_f, bias=False, dtype=torch.half, device=embed_w.device)
605
+ linear_tmp.weight.data = embed_w # share, no copy
606
+ ml = _marlin.Layer(in_f, out_f, groupsize=GROUP_SIZE)
607
+ ml.pack(linear_tmp, scales)
608
+ del linear_tmp, scales
609
+ gc.collect()
610
+ torch.cuda.empty_cache()
611
+
612
+ self._lm_B = ml.B.to(embed_w.device)
613
+ self._lm_s = ml.s.to(embed_w.device)
614
+ self._lm_ws = torch.zeros(out_f // 128 * 16, dtype=torch.int, device=embed_w.device)
615
+ self._lm_out_f = out_f
616
+ self._lm_head_marlin = True
617
+
618
+ mb = (self._lm_B.nelement() * self._lm_B.element_size() +
619
+ self._lm_s.nelement() * self._lm_s.element_size()) / 1024**2
620
+ print(f" Marlin LM head packed: {out_f}x{in_f} -> {mb:.0f} MB ({time.time()-t0:.1f}s)")
621
+
622
+ def _lm_head(self, h):
623
+ """Compute LM head logits. Uses Marlin INT4 if available, else fp16."""
624
+ if hasattr(self, '_lm_head_marlin') and self._lm_head_marlin:
625
+ h_flat = h.view(-1, h.shape[-1])
626
+ out = torch.empty(h_flat.shape[0], self._lm_out_f, dtype=h.dtype, device=h.device)
627
+ _marlin.mul(h_flat, self._lm_B, out, self._lm_s, self._lm_ws)
628
+ return out.view(*h.shape[:-1], self._lm_out_f)
629
+ return F.linear(h, self.embed.weight)
630
+
631
+ def _dql(self, f, prefix, dev, unpermute=None):
632
+ qw = f.get_tensor(f'{prefix}.qweight').to(dev)
633
+ sc = f.get_tensor(f'{prefix}.scales').to(dev)
634
+ qz = f.get_tensor(f'{prefix}.qzeros').to(dev)
635
+ in_f = qw.shape[0] * PACK_FACTOR
636
+ out_f = qw.shape[1]
637
+ if HAS_MARLIN and in_f % 128 == 0 and out_f % 256 == 0:
638
+ return MarlinLinear(qw, sc, qz, unpermute=unpermute)
639
+ return DequantLinear(qw, sc, qz, unpermute=unpermute)
640
+
641
+ def _set(self, module, name, tensor):
642
+ """Replace a meta parameter with a real CUDA tensor."""
643
+ module._parameters[name] = nn.Parameter(tensor, requires_grad=False)
644
+
645
+ @staticmethod
646
+ def _compute_time_embedding(t_value: float, dim: int, theta: float = 10000.0) -> torch.Tensor:
647
+ """Sinusoidal embedding of scalar t_value into dim-dimensional vector."""
648
+ half_dim = dim // 2
649
+ inv_freq = torch.exp(-math.log(theta) * torch.arange(half_dim).float() / half_dim)
650
+ emb = t_value * inv_freq
651
+ return torch.cat([emb.cos(), emb.sin()]) # [dim]
652
+
653
+ @staticmethod
654
+ def _evict_cache():
655
+ """Force kernel to reclaim page cache on Jetson unified memory."""
656
+ import ctypes, gc
657
+ for sz in [4, 3, 2, 1]:
658
+ try:
659
+ buf = ctypes.create_string_buffer(sz * 1024 * 1024 * 1024)
660
+ del buf
661
+ gc.collect()
662
+ return
663
+ except (MemoryError, OSError):
664
+ continue
665
+
666
+ def _load_section(self, path, load_fn):
667
+ """Open safetensors, run load_fn, close, evict cache."""
668
+ import gc
669
+ D = str(self.device)
670
+ with safe_open(path, framework='pt', device=D) as f:
671
+ load_fn(f)
672
+ gc.collect()
673
+ torch.cuda.empty_cache()
674
+ self._evict_cache()
675
+
676
+ def _load(self):
677
+ import gc
678
+ path = os.path.join(self.model_path, 'consolidated.safetensors')
679
+ print(f"Loading {path}...")
680
+ t0 = time.time()
681
+ D, T = self.device, self.dtype
682
+ ep = 'mm_streams_embeddings.embedding_module'
683
+ enc_prefix = f'{ep}.whisper_encoder'
684
+ ea = self.p['multimodal']['whisper_model_args']['encoder_args']
685
+
686
+ # Section 1: Embeddings + output norm
687
+ def load_embed(f):
688
+ self._set(self.embed, 'weight', f.get_tensor(f'{ep}.tok_embeddings.weight').to(T))
689
+ self._set(self.norm, 'weight', f.get_tensor('norm.weight').to(T))
690
+ print(f" Embeddings loaded")
691
+ self._load_section(path, load_embed)
692
+
693
+ # Section 2: Encoder convolutions
694
+ def load_enc_conv(f):
695
+ self._set(self.encoder.conv1, 'weight', f.get_tensor(f'{enc_prefix}.conv_layers.0.conv.weight').to(T))
696
+ self._set(self.encoder.conv1, 'bias', f.get_tensor(f'{enc_prefix}.conv_layers.0.conv.bias').to(T))
697
+ self._set(self.encoder.conv2, 'weight', f.get_tensor(f'{enc_prefix}.conv_layers.1.conv.weight').to(T))
698
+ self._set(self.encoder.conv2, 'bias', f.get_tensor(f'{enc_prefix}.conv_layers.1.conv.bias').to(T))
699
+ self._set(self.encoder.norm, 'weight', f.get_tensor(f'{enc_prefix}.transformer.norm.weight').to(T))
700
+ print(f" Encoder conv loaded")
701
+ self._load_section(path, load_enc_conv)
702
+
703
+ # Section 3: Encoder layers (in batches of 8 to limit mmap cache growth)
704
+ n_enc = ea['n_layers']
705
+ batch = 8
706
+ for b_start in range(0, n_enc, batch):
707
+ b_end = min(b_start + batch, n_enc)
708
+ def load_enc_batch(f, start=b_start, end=b_end):
709
+ for i in range(start, end):
710
+ lp = f'{enc_prefix}.transformer.layers.{i}'
711
+ el = self.encoder.layers[i]
712
+ self._set(el.attn.wq, 'weight', f.get_tensor(f'{lp}.attention.wq.weight').to(T))
713
+ self._set(el.attn.wq, 'bias', f.get_tensor(f'{lp}.attention.wq.bias').to(T))
714
+ self._set(el.attn.wk, 'weight', f.get_tensor(f'{lp}.attention.wk.weight').to(T))
715
+ self._set(el.attn.wv, 'weight', f.get_tensor(f'{lp}.attention.wv.weight').to(T))
716
+ self._set(el.attn.wv, 'bias', f.get_tensor(f'{lp}.attention.wv.bias').to(T))
717
+ self._set(el.attn.wo, 'weight', f.get_tensor(f'{lp}.attention.wo.weight').to(T))
718
+ self._set(el.attn.wo, 'bias', f.get_tensor(f'{lp}.attention.wo.bias').to(T))
719
+ self._set(el.an, 'weight', f.get_tensor(f'{lp}.attention_norm.weight').to(T))
720
+ self._set(el.fn, 'weight', f.get_tensor(f'{lp}.ffn_norm.weight').to(T))
721
+ self._set(el.w1, 'weight', f.get_tensor(f'{lp}.feed_forward.w1.weight').to(T))
722
+ self._set(el.w2, 'weight', f.get_tensor(f'{lp}.feed_forward.w2.weight').to(T))
723
+ self._set(el.w2, 'bias', f.get_tensor(f'{lp}.feed_forward.w2.bias').to(T))
724
+ self._set(el.w3, 'weight', f.get_tensor(f'{lp}.feed_forward.w3.weight').to(T))
725
+ print(f" Encoder layers {start}-{end-1} loaded")
726
+ self._load_section(path, load_enc_batch)
727
+
728
+ print(f" Encoder loaded ({n_enc} layers)")
729
+
730
+ # Section 4: Projector
731
+ def load_proj(f):
732
+ pp = f'{ep}.audio_language_projection'
733
+ self._set(self.projector.l1, 'weight', f.get_tensor(f'{pp}.0.weight').to(T))
734
+ self._set(self.projector.l2, 'weight', f.get_tensor(f'{pp}.2.weight').to(T))
735
+ print(f" Projector loaded")
736
+ self._load_section(path, load_proj)
737
+
738
+ # Section 4b: Pack Marlin INT4 LM head (before decoder layers, max free mem)
739
+ if HAS_MARLIN:
740
+ self._pack_lm_head()
741
+
742
+ # Section 5: LM decoder layers (in batches of 13)
743
+ dec_batch = 13
744
+ for b_start in range(0, self.n_layers, dec_batch):
745
+ b_end = min(b_start + dec_batch, self.n_layers)
746
+ def load_dec_batch(f, start=b_start, end=b_end):
747
+ for i in range(start, end):
748
+ lp = f'layers.{i}'
749
+ dl = self.layers[i]
750
+ self._set(dl.an, 'weight', f.get_tensor(f'{lp}.attention_norm.weight').to(T))
751
+ self._set(dl.fn, 'weight', f.get_tensor(f'{lp}.ffn_norm.weight').to(T))
752
+ self._set(dl.ada0, 'weight', f.get_tensor(f'{lp}.ada_rms_norm_t_cond.0.weight').to(T))
753
+ self._set(dl.ada2, 'weight', f.get_tensor(f'{lp}.ada_rms_norm_t_cond.2.weight').to(T))
754
+ dl.attn.q_proj = self._dql(f, f'{lp}.self_attn.q_proj', D)
755
+ dl.attn.k_proj = self._dql(f, f'{lp}.self_attn.k_proj', D)
756
+ dl.attn.v_proj = self._dql(f, f'{lp}.self_attn.v_proj', D)
757
+ dl.attn.o_proj = self._dql(f, f'{lp}.self_attn.o_proj', D)
758
+ dl.gate_proj = self._dql(f, f'{lp}.mlp.gate_proj', D)
759
+ dl.up_proj = self._dql(f, f'{lp}.mlp.up_proj', D)
760
+ dl.down_proj = self._dql(f, f'{lp}.mlp.down_proj', D)
761
+ print(f" LM layers {start}-{end-1} loaded")
762
+ self._load_section(path, load_dec_batch)
763
+
764
+ backend = "Marlin fused INT4" if HAS_MARLIN else "DequantLinear"
765
+ print(f" LM decoder loaded ({self.n_layers} layers, {backend})")
766
+ gc.collect()
767
+ torch.cuda.empty_cache()
768
+ mem = torch.cuda.memory_allocated() / 1024**3
769
+ print(f" Done in {time.time()-t0:.1f}s, GPU: {mem:.2f} GB")
770
+
771
+ def _load_tokenizer(self):
772
+ """Load tekken tokenizer for decoding."""
773
+ try:
774
+ from mistral_common.tokens.tokenizers.tekken import Tekkenizer
775
+ self.tokenizer = Tekkenizer.from_file(
776
+ os.path.join(self.model_path, 'tekken.json'))
777
+ print(" Tekken tokenizer loaded")
778
+ except ImportError:
779
+ try:
780
+ from mistral_common.tokens.tokenizers.tekken import Tekken
781
+ self.tokenizer = Tekken.from_file(
782
+ os.path.join(self.model_path, 'tekken.json'))
783
+ print(" Tekken tokenizer loaded (legacy)")
784
+ except ImportError:
785
+ self.tokenizer = None
786
+ print(" WARNING: mistral_common not available, using fallback decoder")
787
+
788
+ def _pre_dequant(self):
789
+ """Offload encoder to CPU and pre-dequantize decoder weights into GPU cache.
790
+
791
+ After encoding is done, the encoder (~1.86 GB) is no longer needed on GPU.
792
+ Offloading it frees memory for caching pre-dequantized decoder weights,
793
+ which eliminates the per-token dequantization overhead.
794
+ """
795
+ import gc
796
+
797
+ if hasattr(self, '_decoder_cached') and self._decoder_cached:
798
+ return # already cached
799
+
800
+ t0 = time.time()
801
+
802
+ # Move encoder + projector to CPU to free GPU memory
803
+ self.encoder.cpu()
804
+ self.projector.cpu()
805
+ gc.collect()
806
+ torch.cuda.empty_cache()
807
+ self._evict_cache()
808
+
809
+ free, _ = torch.cuda.mem_get_info(0)
810
+ print(f" After encoder offload: {free/1024**3:.2f} GB free")
811
+
812
+ # Budget: leave 500 MB for KV cache + intermediates
813
+ budget = free - 500 * 1024 * 1024
814
+ used_bytes = 0
815
+ cached_count = 0
816
+
817
+ for i, dl in enumerate(self.layers):
818
+ projs = [dl.attn.q_proj, dl.attn.k_proj, dl.attn.v_proj, dl.attn.o_proj,
819
+ dl.gate_proj, dl.up_proj, dl.down_proj]
820
+
821
+ # Estimate net memory cost (fp16 weight minus freed INT4 buffers)
822
+ layer_fp16 = sum(
823
+ p.in_features * p.out_features * 2
824
+ for p in projs if isinstance(p, DequantLinear) and p._cached_w is None
825
+ )
826
+ layer_int4 = sum(
827
+ p.qweight.nelement() * 4 + p.scales.nelement() * 2 + p.qzeros.nelement() * 4
828
+ for p in projs
829
+ if isinstance(p, DequantLinear) and p.qweight is not None
830
+ )
831
+ net = layer_fp16 - layer_int4 # net increase in memory
832
+
833
+ if used_bytes + net > budget:
834
+ break
835
+
836
+ for p in projs:
837
+ if isinstance(p, DequantLinear) and p._cached_w is None and p.qweight is not None:
838
+ p.cache_weight(free_int4=True)
839
+
840
+ used_bytes += net
841
+ cached_count += 1
842
+ # Periodic cleanup to keep peak memory low
843
+ if cached_count % 5 == 0:
844
+ gc.collect()
845
+ torch.cuda.empty_cache()
846
+
847
+ gc.collect()
848
+ torch.cuda.empty_cache()
849
+ free2, _ = torch.cuda.mem_get_info(0)
850
+ print(f" Pre-dequantized {cached_count}/{self.n_layers} layers in {time.time()-t0:.1f}s, "
851
+ f"{free2/1024**3:.2f} GB free")
852
+ self._decoder_cached = True
853
+
854
+ def _restore_encoder(self):
855
+ """Move encoder back to GPU for the next transcription.
856
+
857
+ Frees cached decoder weights first to make room, then reloads
858
+ INT4 weights for layers that had their buffers freed.
859
+ """
860
+ import gc
861
+
862
+ if not hasattr(self, '_decoder_cached') or not self._decoder_cached:
863
+ return
864
+
865
+ t0 = time.time()
866
+
867
+ # Free cached decoder weights
868
+ needs_reload = []
869
+ for i, dl in enumerate(self.layers):
870
+ for p in [dl.attn.q_proj, dl.attn.k_proj, dl.attn.v_proj, dl.attn.o_proj,
871
+ dl.gate_proj, dl.up_proj, dl.down_proj]:
872
+ if isinstance(p, DequantLinear):
873
+ if p._cached_w is not None and p.qweight is None:
874
+ needs_reload.append(i)
875
+ p.uncache_weight()
876
+
877
+ gc.collect()
878
+ torch.cuda.empty_cache()
879
+ self._evict_cache()
880
+
881
+ # Move encoder + projector back to GPU
882
+ self.encoder.to(self.device)
883
+ self.projector.to(self.device)
884
+
885
+ # Reload INT4 weights for layers that were freed
886
+ if needs_reload:
887
+ needs_reload = sorted(set(needs_reload))
888
+ path = os.path.join(self.model_path, 'consolidated.safetensors')
889
+ D = str(self.device)
890
+ with safe_open(path, framework='pt', device=D) as f:
891
+ for i in needs_reload:
892
+ lp = f'layers.{i}'
893
+ dl = self.layers[i]
894
+ dl.attn.q_proj = self._dql(f, f'{lp}.self_attn.q_proj', D)
895
+ dl.attn.k_proj = self._dql(f, f'{lp}.self_attn.k_proj', D)
896
+ dl.attn.v_proj = self._dql(f, f'{lp}.self_attn.v_proj', D)
897
+ dl.attn.o_proj = self._dql(f, f'{lp}.self_attn.o_proj', D)
898
+ dl.gate_proj = self._dql(f, f'{lp}.mlp.gate_proj', D)
899
+ dl.up_proj = self._dql(f, f'{lp}.mlp.up_proj', D)
900
+ dl.down_proj = self._dql(f, f'{lp}.mlp.down_proj', D)
901
+ gc.collect()
902
+ torch.cuda.empty_cache()
903
+ print(f" Reloaded {len(needs_reload)} decoder layers from disk")
904
+
905
+ gc.collect()
906
+ torch.cuda.empty_cache()
907
+ self._decoder_cached = False
908
+ print(f" Encoder restored in {time.time()-t0:.1f}s")
909
+
910
+ def decode_tokens(self, ids):
911
+ if self.tokenizer is not None:
912
+ try:
913
+ return self.tokenizer.decode(ids)
914
+ except:
915
+ pass
916
+ # Fallback: decode as UTF-8 byte tokens
917
+ # Token IDs 0-255 are byte tokens in tekken
918
+ result = bytearray()
919
+ for tid in ids:
920
+ if 0 <= tid <= 255:
921
+ result.append(tid)
922
+ elif 256 <= tid < 131072:
923
+ # BPE merge token β€” would need full vocab to decode
924
+ result.extend(f'[{tid}]'.encode())
925
+ try:
926
+ return result.decode('utf-8', errors='replace')
927
+ except:
928
+ return str(ids)
929
+
930
+ @staticmethod
931
+ def _pad_audio(audio: np.ndarray) -> np.ndarray:
932
+ """Pad audio for streaming alignment: left silence + alignment + right silence.
933
+
934
+ Left padding of N_LEFT_PAD_TOKENS tokens aligns with the SPAD tokens
935
+ in the prompt. Right padding accounts for delay + offline buffer.
936
+ """
937
+ n = len(audio)
938
+ left_pad = N_LEFT_PAD_TOKENS * RAW_AUDIO_PER_TOK
939
+ right_align = (RAW_AUDIO_PER_TOK - (n % RAW_AUDIO_PER_TOK)) % RAW_AUDIO_PER_TOK
940
+ right_pad = right_align + N_RIGHT_PAD_BASE * RAW_AUDIO_PER_TOK
941
+ return np.pad(audio.astype(np.float32), (left_pad, right_pad))
942
+
943
+ @torch.no_grad()
944
+ def transcribe(self, audio: np.ndarray, max_tokens: int = 512) -> str:
945
+ """Transcribe audio (float32, 16kHz mono) to text."""
946
+ t0 = time.time()
947
+
948
+ # Evict page cache before inference (Jetson unified memory)
949
+ self._evict_cache()
950
+ free, _ = torch.cuda.mem_get_info(0)
951
+ print(f" CUDA free before inference: {free/1024**3:.2f} GB")
952
+
953
+ # Restore encoder to GPU if it was offloaded (only needed without Marlin)
954
+ if not HAS_MARLIN:
955
+ self._restore_encoder()
956
+
957
+ # 0. Pad audio for streaming alignment
958
+ audio = self._pad_audio(audio)
959
+ print(f" padded audio: {len(audio)} samples ({len(audio)/SAMPLE_RATE:.1f}s)")
960
+
961
+ # 1. Mel spectrogram
962
+ mel = self._mel(audio)
963
+ # Pad mel for downsample alignment
964
+ enc_out_len = (mel.shape[2] - 1) // 2 + 1 # after conv stride-2
965
+ remainder = enc_out_len % self.ds_factor
966
+ if remainder != 0:
967
+ mel = F.pad(mel, (0, (self.ds_factor - remainder) * 2))
968
+ print(f" mel: {mel.shape}")
969
+
970
+ # 2. Encode
971
+ t1 = time.time()
972
+ enc = self.encoder(mel) # [1, T_enc, 1280]
973
+ print(f" enc: {enc.shape} ({time.time()-t1:.2f}s)")
974
+ del mel # free mel to save memory
975
+
976
+ # 3. Downsample 4x: concat groups of ds_factor frames
977
+ T_enc = enc.shape[1]
978
+ T_ds = T_enc // self.ds_factor
979
+ enc_ds = enc[:, :T_ds * self.ds_factor, :].reshape(
980
+ 1, T_ds, self.ds_factor * self.enc_h) # [1, T_ds, 5120]
981
+ del enc # free encoder output
982
+
983
+ # 4. Project (adapter)
984
+ adapter = self.projector(enc_ds) # [1, T_ds, 3072]
985
+ del enc_ds
986
+ print(f" adapter: {adapter.shape}")
987
+
988
+ # 5. Offload encoder, pre-dequantize decoder weights (only without Marlin)
989
+ if not HAS_MARLIN:
990
+ self._pre_dequant()
991
+
992
+ # 6. Build prompt: [BOS] + [SPAD] * (n_left_pad + delay_tokens)
993
+ prompt_len = 1 + self.n_left_pad + self.delay_tokens
994
+ prompt_ids = [TOKEN_BOS] + [TOKEN_STREAMING_PAD] * (self.n_left_pad + self.delay_tokens)
995
+
996
+ # Clamp prompt to adapter length
997
+ if prompt_len > T_ds:
998
+ prompt_len = T_ds
999
+ prompt_ids = prompt_ids[:T_ds]
1000
+
1001
+ print(f" prompt: {prompt_len} tokens, adapter: {T_ds} positions, gen budget: {T_ds - prompt_len}")
1002
+
1003
+ # 7. Allocate KV cache for the full sequence
1004
+ kv_cache = KVCache(self.n_layers, T_ds, self.n_kv_heads, self.head_dim,
1005
+ self.device, self.dtype)
1006
+
1007
+ # 8. Prefill: build embeddings = adapter + tok_embed for prompt positions
1008
+ prompt_tok = torch.tensor([prompt_ids], device=self.device)
1009
+ tok_emb = self.embed(prompt_tok) # [1, prompt_len, 3072]
1010
+ input_emb = adapter[:, :prompt_len, :] + tok_emb # ADDITION
1011
+
1012
+ cos = self.cos[:prompt_len]
1013
+ sin = self.sin[:prompt_len]
1014
+
1015
+ h = input_emb
1016
+ for i, layer in enumerate(self.layers):
1017
+ h = layer(h, cos, sin, cache=kv_cache, layer_idx=i,
1018
+ is_causal=True, t_cond=self.t_cond)
1019
+ kv_cache.advance(prompt_len)
1020
+
1021
+ h = self.norm(h)
1022
+
1023
+ # lm_head (Marlin INT4 if available, else fp16 tied embed)
1024
+ logits = self._lm_head(h[:, -1:, :])
1025
+ next_tok = logits.argmax(-1).item()
1026
+ # Diagnostic: top-5 tokens and logit stats
1027
+ topk = torch.topk(logits[0, 0], 10)
1028
+ print(f" prefill done ({time.time()-t1:.2f}s), first token: {next_tok}")
1029
+ print(f" logits: min={logits.min():.2f} max={logits.max():.2f} mean={logits.mean():.4f}")
1030
+ print(f" top-10: {list(zip(topk.indices.tolist(), [f'{v:.2f}' for v in topk.values.tolist()]))}")
1031
+ print(f" adapter stats: min={adapter.min():.4f} max={adapter.max():.4f} mean={adapter.mean():.6f}")
1032
+
1033
+ # 9. Generate: continue from prompt_len to T_ds
1034
+ generated = []
1035
+ pos = prompt_len
1036
+ t2 = time.time()
1037
+
1038
+ # Pre-allocate token tensor to avoid per-step allocation
1039
+ _tok_buf = torch.empty(1, 1, dtype=torch.long, device=self.device)
1040
+
1041
+ while pos < T_ds and next_tok != TOKEN_EOS and len(generated) < max_tokens:
1042
+ generated.append(next_tok)
1043
+
1044
+ # Input = adapter[pos] + tok_embed(next_tok)
1045
+ _tok_buf.fill_(next_tok)
1046
+ te = self.embed(_tok_buf)
1047
+ inp = adapter[:, pos:pos+1, :] + te
1048
+
1049
+ cos_s = self.cos[pos:pos+1]
1050
+ sin_s = self.sin[pos:pos+1]
1051
+
1052
+ h = inp
1053
+ for i, layer in enumerate(self.layers):
1054
+ h = layer(h, cos_s, sin_s, cache=kv_cache, layer_idx=i,
1055
+ is_causal=False, t_cond=self.t_cond)
1056
+ kv_cache.advance(1)
1057
+
1058
+ h = self.norm(h)
1059
+ logits = self._lm_head(h)
1060
+ next_tok = logits[:, -1, :].argmax(-1).item()
1061
+ pos += 1
1062
+
1063
+ # Progress every 25 tokens
1064
+ if len(generated) % 25 == 0:
1065
+ elapsed = time.time() - t2
1066
+ tps = len(generated) / max(elapsed, 0.001)
1067
+ n_text = sum(1 for t in generated if t >= 1000)
1068
+ print(f" step {pos}/{T_ds}: {len(generated)} tok ({n_text} text), {tps:.1f} tok/s")
1069
+
1070
+ if next_tok == TOKEN_EOS:
1071
+ generated.append(TOKEN_EOS)
1072
+
1073
+ dt = time.time() - t2
1074
+ n_gen = len(generated)
1075
+
1076
+ # Filter: text tokens are >= 1000 (special tokens are < 1000)
1077
+ text_toks = [t for t in generated if t >= 1000]
1078
+ special_toks = [t for t in generated if t < 1000]
1079
+ text = self.decode_tokens(text_toks)
1080
+
1081
+ print(f" gen: {n_gen} tokens ({len(special_toks)} special, {len(text_toks)} text) in {dt:.2f}s "
1082
+ f"({n_gen/max(dt,0.001):.1f} tok/s)")
1083
+ if special_toks:
1084
+ print(f" special tokens: {sorted(set(special_toks))[:10]}")
1085
+ if text_toks:
1086
+ print(f" first text IDs: {text_toks[:20]}")
1087
+ print(f" total: {time.time()-t0:.2f}s")
1088
+
1089
+ return text
1090
+
1091
+ def _mel(self, audio: np.ndarray) -> torch.Tensor:
1092
+ """Compute Whisper-style log-mel spectrogram.
1093
+
1094
+ Uses log10, max-relative normalization, and [0,1] scaling matching
1095
+ the standard Whisper preprocessing pipeline.
1096
+ """
1097
+ at = torch.from_numpy(audio).float()
1098
+ win = torch.hann_window(WINDOW_SIZE)
1099
+ stft = torch.stft(at, WINDOW_SIZE, HOP_LENGTH, WINDOW_SIZE, win,
1100
+ return_complex=True)
1101
+ magnitudes = stft[..., :-1].abs().pow(2)
1102
+
1103
+ # Mel filterbank
1104
+ fb = self._melfb(WINDOW_SIZE // 2 + 1, N_MELS, SAMPLE_RATE, 0.0, 8000.0)
1105
+ mel = fb @ magnitudes
1106
+
1107
+ # Voxtral log normalization: fixed global max, NOT per-sample
1108
+ log_mel = torch.log10(mel.clamp(min=1e-10))
1109
+ log_mel = torch.clamp(log_mel, min=GLOBAL_LOG_MEL_MAX - 8.0) # floor at -6.5
1110
+ log_mel = (log_mel + 4.0) / 4.0
1111
+
1112
+ return log_mel.unsqueeze(0).to(device=self.device, dtype=self.dtype)
1113
+
1114
+ @staticmethod
1115
+ def _melfb(n_fft_bins, n_mels, sr, fmin=0.0, fmax=8000.0):
1116
+ """Slaney mel filterbank matching transformers.audio_utils.mel_filter_bank."""
1117
+ # Slaney mel scale (piecewise linear/log, NOT HTK)
1118
+ f_sp = 200.0 / 3.0
1119
+ min_log_hz = 1000.0
1120
+ min_log_mel = min_log_hz / f_sp # 15.0
1121
+ logstep = np.log(6.4) / 27.0
1122
+
1123
+ def hz_to_mel(f):
1124
+ f = np.asarray(f, dtype=np.float64)
1125
+ mel = np.where(f < min_log_hz, f / f_sp,
1126
+ min_log_mel + np.log(f / min_log_hz) / logstep)
1127
+ return mel
1128
+
1129
+ def mel_to_hz(m):
1130
+ m = np.asarray(m, dtype=np.float64)
1131
+ hz = np.where(m < min_log_mel, m * f_sp,
1132
+ min_log_hz * np.exp(logstep * (m - min_log_mel)))
1133
+ return hz
1134
+
1135
+ mel_min = hz_to_mel(fmin)
1136
+ mel_max = hz_to_mel(fmax)
1137
+ mels = np.linspace(mel_min, mel_max, n_mels + 2)
1138
+ freqs = mel_to_hz(mels)
1139
+
1140
+ fft_freqs = np.linspace(0, sr / 2, n_fft_bins)
1141
+ fb = np.zeros((n_mels, n_fft_bins))
1142
+ for i in range(n_mels):
1143
+ low, center, high = freqs[i], freqs[i + 1], freqs[i + 2]
1144
+ for j in range(n_fft_bins):
1145
+ if low <= fft_freqs[j] <= center and center > low:
1146
+ fb[i, j] = (fft_freqs[j] - low) / (center - low)
1147
+ elif center < fft_freqs[j] <= high and high > center:
1148
+ fb[i, j] = (high - fft_freqs[j]) / (high - center)
1149
+
1150
+ # Slaney normalization: area = 1 per filter
1151
+ enorm = 2.0 / (freqs[2:n_mels+2] - freqs[:n_mels])
1152
+ fb *= enorm[:, np.newaxis]
1153
+
1154
+ return torch.tensor(fb, dtype=torch.float32)
1155
+
1156
+
1157
+ # ─── WebSocket Server ────────────────────────────────────────────────────────
1158
+
1159
+ async def handle_ws(ws, model):
1160
+ sid = str(uuid.uuid4())[:8]
1161
+ await ws.send(json.dumps({"type": "session.created", "session": {"id": sid}}))
1162
+
1163
+ buf = bytearray()
1164
+ async for msg in ws:
1165
+ try:
1166
+ ev = json.loads(msg)
1167
+ t = ev.get("type", "")
1168
+
1169
+ if t == "session.update":
1170
+ pass # model name ignored, we only have one
1171
+
1172
+ elif t == "input_audio_buffer.append":
1173
+ buf.extend(base64.b64decode(ev.get("audio", "")))
1174
+
1175
+ elif t == "input_audio_buffer.commit":
1176
+ if ev.get("final"):
1177
+ break
1178
+ if not buf:
1179
+ continue
1180
+
1181
+ pcm = np.frombuffer(bytes(buf), dtype=np.int16).astype(np.float32) / 32768.0
1182
+ buf = bytearray()
1183
+ dur = len(pcm) / SAMPLE_RATE
1184
+ print(f"[{sid}] {dur:.1f}s audio")
1185
+
1186
+ text = model.transcribe(pcm)
1187
+
1188
+ if text:
1189
+ await ws.send(json.dumps({"type": "transcription.delta", "delta": text}))
1190
+ await ws.send(json.dumps({
1191
+ "type": "transcription.done", "text": text,
1192
+ "usage": {"audio_duration_s": round(dur, 1)}
1193
+ }))
1194
+ print(f"[{sid}] -> {text[:100]}")
1195
+
1196
+ except Exception as e:
1197
+ import traceback
1198
+ traceback.print_exc()
1199
+ await ws.send(json.dumps({"type": "error", "error": str(e)}))
1200
+
1201
+
1202
+ async def serve(model, host='0.0.0.0', port=8000):
1203
+ import websockets
1204
+ print(f"\nWebSocket server at ws://{host}:{port}/v1/realtime")
1205
+
1206
+ async def handler(ws, path=None):
1207
+ await handle_ws(ws, model)
1208
+
1209
+ try:
1210
+ async with websockets.serve(handler, host, port):
1211
+ print("Ready.")
1212
+ await asyncio.Future()
1213
+ except TypeError:
1214
+ await websockets.serve(handler, host, port)
1215
+ print("Ready.")
1216
+ await asyncio.Future()
1217
+
1218
+
1219
+ # ─── Main ────────────────────────────────────────────────────────────────────
1220
+
1221
+ def main():
1222
+ ap = argparse.ArgumentParser()
1223
+ ap.add_argument('--model-path', default='/workspace/voxtral-quant/models/voxtral-rtn-4bit-packed-vllm')
1224
+ ap.add_argument('--port', type=int, default=8000)
1225
+ ap.add_argument('--test', help='WAV file to transcribe (skip server)')
1226
+ ap.add_argument('--device', default='cuda')
1227
+ ap.add_argument('--no-compile', action='store_true', help='Disable torch.compile')
1228
+ args = ap.parse_args()
1229
+
1230
+ print("=" * 60)
1231
+ print("Voxtral Mini 4B Realtime β€” Jetson Orin Nano 8GB")
1232
+ print("=" * 60)
1233
+
1234
+ if torch.cuda.is_available():
1235
+ props = torch.cuda.get_device_properties(0)
1236
+ print(f"GPU: {props.name}, {props.total_memory/1024**3:.1f} GB")
1237
+ free, total = torch.cuda.mem_get_info(0)
1238
+ print(f"CUDA free: {free/1024**3:.2f} GB / {total/1024**3:.2f} GB")
1239
+ # Force kernel to reclaim page cache (critical on Jetson unified memory)
1240
+ if free < 5 * 1024**3:
1241
+ import ctypes, gc
1242
+ print("Reclaiming page cache...")
1243
+ for sz in [4, 3, 2, 1]:
1244
+ try:
1245
+ buf = ctypes.create_string_buffer(sz * 1024 * 1024 * 1024)
1246
+ del buf
1247
+ gc.collect()
1248
+ break
1249
+ except (MemoryError, OSError):
1250
+ continue
1251
+ free2, _ = torch.cuda.mem_get_info(0)
1252
+ print(f"CUDA free after reclaim: {free2/1024**3:.2f} GB")
1253
+
1254
+ model = VoxtralModel(args.model_path, args.device, compile=not args.no_compile)
1255
+ model._load_tokenizer()
1256
+
1257
+ if args.test:
1258
+ import soundfile as sf
1259
+ audio, sr = sf.read(args.test, dtype='float32')
1260
+ if audio.ndim > 1:
1261
+ audio = audio.mean(1)
1262
+ if sr != SAMPLE_RATE:
1263
+ try:
1264
+ import soxr
1265
+ audio = soxr.resample(audio, sr, SAMPLE_RATE, quality='HQ')
1266
+ except ImportError:
1267
+ import librosa
1268
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
1269
+ print(f"\nTest: {args.test} ({len(audio)/SAMPLE_RATE:.1f}s)")
1270
+ text = model.transcribe(audio)
1271
+ print(f"\nResult: {text}")
1272
+ else:
1273
+ asyncio.run(serve(model, port=args.port))
1274
+
1275
+
1276
+ if __name__ == '__main__':
1277
+ main()
tekken.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8434af1d39eba99f0ef46cf1450bf1a63fa941a26933a1ef5dbbf4adf0d00e44
3
+ size 14910348
voxtral_client.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Voxtral Realtime Transcription β€” Gradio client for Jetson WebSocket API.
3
+
4
+ Captures microphone audio and streams it to the Voxtral server running on
5
+ a Jetson Orin Nano for live speech-to-text transcription.
6
+
7
+ Usage:
8
+ pip install gradio websockets numpy
9
+ python voxtral_client.py
10
+ """
11
+
12
+ import asyncio
13
+ import base64
14
+ import json
15
+ import time
16
+
17
+ import numpy as np
18
+ import gradio as gr
19
+
20
+ DEFAULT_WS_URL = "ws://localhost:8000/v1/realtime"
21
+ TARGET_SR = 16000
22
+ CHUNK_SAMPLES = TARGET_SR // 2 # 0.5s chunks for WebSocket
23
+
24
+
25
+ # ─── Audio Utilities ─────────────────────────────────────────────────────────
26
+
27
+ def resample_audio(audio, orig_sr, target_sr):
28
+ """Resample audio to target sample rate."""
29
+ if orig_sr == target_sr:
30
+ return audio
31
+ try:
32
+ from scipy.signal import resample_poly
33
+ from math import gcd
34
+ g = gcd(int(orig_sr), int(target_sr))
35
+ return resample_poly(audio, target_sr // g, orig_sr // g).astype(np.float32)
36
+ except ImportError:
37
+ # Linear interpolation fallback (no extra dependencies)
38
+ n_out = int(len(audio) * target_sr / orig_sr)
39
+ return np.interp(
40
+ np.linspace(0, len(audio) - 1, n_out),
41
+ np.arange(len(audio)),
42
+ audio,
43
+ ).astype(np.float32)
44
+
45
+
46
+ def audio_to_pcm16(audio_tuple):
47
+ """Convert Gradio audio (sr, ndarray) to 16kHz PCM16 mono. Returns (pcm16, duration)."""
48
+ if audio_tuple is None:
49
+ return None, 0.0
50
+ sr, data = audio_tuple
51
+ if data is None or len(data) == 0:
52
+ return None, 0.0
53
+
54
+ # Mono
55
+ if data.ndim > 1:
56
+ data = data.mean(axis=1)
57
+ # Normalize to float32 [-1, 1]
58
+ if np.issubdtype(data.dtype, np.integer):
59
+ data = data.astype(np.float32) / np.iinfo(data.dtype).max
60
+ else:
61
+ data = data.astype(np.float32)
62
+
63
+ if sr != TARGET_SR:
64
+ data = resample_audio(data, sr, TARGET_SR)
65
+
66
+ pcm16 = (data * 32768.0).clip(-32768, 32767).astype(np.int16)
67
+ return pcm16, len(pcm16) / TARGET_SR
68
+
69
+
70
+ # ─── WebSocket Client ────────────────────────────────────────────────────────
71
+
72
+ async def _transcribe_async(pcm16, ws_url):
73
+ """Send PCM16 audio to Voxtral WebSocket, return (text, usage, error)."""
74
+ import websockets
75
+
76
+ async with websockets.connect(ws_url, close_timeout=5, open_timeout=10) as ws:
77
+ msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=10))
78
+ if msg["type"] != "session.created":
79
+ return "", {}, f"Unexpected first message: {msg['type']}"
80
+
81
+ await ws.send(json.dumps({"type": "session.update"}))
82
+
83
+ # Send audio in 0.5s chunks
84
+ for i in range(0, len(pcm16), CHUNK_SAMPLES):
85
+ chunk = pcm16[i:i + CHUNK_SAMPLES]
86
+ await ws.send(json.dumps({
87
+ "type": "input_audio_buffer.append",
88
+ "audio": base64.b64encode(chunk.tobytes()).decode("ascii"),
89
+ }))
90
+
91
+ await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
92
+
93
+ text = ""
94
+ usage = {}
95
+ while True:
96
+ msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=120))
97
+ if msg["type"] == "transcription.delta":
98
+ text += msg.get("delta", "")
99
+ elif msg["type"] == "transcription.done":
100
+ text = msg.get("text", text)
101
+ usage = msg.get("usage", {})
102
+ break
103
+ elif msg["type"] == "error":
104
+ return "", {}, f"Server error: {msg.get('error', 'unknown')}"
105
+
106
+ # Signal done
107
+ await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))
108
+
109
+ return text, usage, None
110
+
111
+
112
+ def transcribe_ws(pcm16, ws_url):
113
+ """Synchronous wrapper β€” safe to call from Gradio thread callbacks."""
114
+ loop = asyncio.new_event_loop()
115
+ try:
116
+ return loop.run_until_complete(_transcribe_async(pcm16, ws_url))
117
+ finally:
118
+ loop.close()
119
+
120
+
121
+ # ─── Gradio Callbacks ────────────────────────────────────────────────────────
122
+
123
+ def format_history(entries):
124
+ if not entries:
125
+ return ""
126
+ lines = []
127
+ for e in entries:
128
+ lines.append(f"[{e['time']}] ({e['audio_s']}s audio, {e['elapsed_s']}s total) {e['text']}")
129
+ return "\n".join(lines)
130
+
131
+
132
+ def process_recording(audio, ws_url, history_state):
133
+ """Transcribe recorded audio and update history."""
134
+ if audio is None:
135
+ return "", history_state, "No audio recorded", format_history(history_state)
136
+
137
+ pcm16, duration = audio_to_pcm16(audio)
138
+ if pcm16 is None or len(pcm16) < TARGET_SR // 10: # < 0.1s
139
+ return "", history_state, "Audio too short", format_history(history_state)
140
+
141
+ t0 = time.time()
142
+ try:
143
+ text, usage, error = transcribe_ws(pcm16, ws_url)
144
+ except Exception as e:
145
+ err = str(e)
146
+ if "Connect call failed" in err or "refused" in err:
147
+ err = f"Cannot reach server at {ws_url} β€” is it running?"
148
+ return "", history_state, f"Error: {err}", format_history(history_state)
149
+
150
+ if error:
151
+ return "", history_state, f"Error: {error}", format_history(history_state)
152
+
153
+ elapsed = time.time() - t0
154
+ text = text.strip()
155
+
156
+ entry = {
157
+ "time": time.strftime("%H:%M:%S"),
158
+ "text": text,
159
+ "audio_s": round(duration, 1),
160
+ "elapsed_s": round(elapsed, 1),
161
+ }
162
+ history_state = history_state + [entry]
163
+
164
+ status = f"Transcribed {duration:.1f}s audio in {elapsed:.1f}s"
165
+ return text, history_state, status, format_history(history_state)
166
+
167
+
168
+ def clear_history():
169
+ return [], "", ""
170
+
171
+
172
+ # ─── UI ──────────────────────────────────────────────────────────────────────
173
+
174
+ with gr.Blocks(title="Voxtral Realtime Transcription", theme=gr.themes.Soft()) as app:
175
+ history_state = gr.State([])
176
+
177
+ gr.Markdown("# Voxtral Realtime Transcription")
178
+ gr.Markdown(
179
+ "Record from your microphone or upload a WAV file. "
180
+ "Audio is sent to Voxtral on a Jetson Orin Nano for transcription."
181
+ )
182
+
183
+ ws_url = gr.Textbox(value=DEFAULT_WS_URL, label="Jetson WebSocket URL")
184
+
185
+ with gr.Row():
186
+ with gr.Column(scale=3):
187
+ audio_input = gr.Audio(
188
+ sources=["microphone", "upload"],
189
+ type="numpy",
190
+ label="Microphone (click to record, click again to stop)",
191
+ )
192
+ with gr.Column(scale=1, min_width=120):
193
+ transcribe_btn = gr.Button("Transcribe", variant="primary", size="lg")
194
+
195
+ status = gr.Textbox(label="Status", interactive=False)
196
+
197
+ current_text = gr.Textbox(label="Current Transcription", interactive=False, lines=3)
198
+
199
+ with gr.Row():
200
+ history_text = gr.Textbox(
201
+ label="Transcript History",
202
+ interactive=False,
203
+ lines=8,
204
+ scale=5,
205
+ )
206
+ clear_btn = gr.Button("Clear History", scale=1)
207
+
208
+ # Auto-transcribe when mic recording stops
209
+ audio_input.stop_recording(
210
+ fn=process_recording,
211
+ inputs=[audio_input, ws_url, history_state],
212
+ outputs=[current_text, history_state, status, history_text],
213
+ )
214
+
215
+ # Manual transcribe button (for uploads or re-transcription)
216
+ transcribe_btn.click(
217
+ fn=process_recording,
218
+ inputs=[audio_input, ws_url, history_state],
219
+ outputs=[current_text, history_state, status, history_text],
220
+ )
221
+
222
+ clear_btn.click(fn=clear_history, outputs=[history_state, history_text, current_text])
223
+
224
+
225
+ if __name__ == "__main__":
226
+ app.launch(server_name="0.0.0.0", server_port=7860)