Initial upload: INT4 RTN quantized Voxtral Mini 4B for Jetson
Browse files- .gitattributes +1 -0
- consolidated.safetensors +3 -0
- kernels/fused_ops.cu +241 -0
- params.json +65 -0
- scripts/jetson_serve_sdpa.py +1277 -0
- tekken.json +3 -0
- voxtral_client.py +226 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tekken.json filter=lfs diff=lfs merge=lfs -text
|
consolidated.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fd25f9d675042c37b0a9b051a5333ef001c129d302461995bdc3e7b321c3b2b6
|
| 3 |
+
size 4382321392
|
kernels/fused_ops.cu
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Fused CUDA kernels for Voxtral decoder on Jetson Orin Nano (SM87).
|
| 2 |
+
//
|
| 3 |
+
// Three kernels that collapse ~500 PyTorch kernel launches per token into ~80:
|
| 4 |
+
// 1. fused_rmsnorm: 5-6 kernels → 1 per call (52 calls/token)
|
| 5 |
+
// 2. fused_rope_hf: ~14 kernels → 2 per call (26 calls/token)
|
| 6 |
+
// 3. fused_silu_mul: 2 kernels → 1 per call (26 calls/token)
|
| 7 |
+
//
|
| 8 |
+
// Build: JIT via torch.utils.cpp_extension.load() with -arch=sm_87
|
| 9 |
+
// All kernels: fp16 I/O, fp32 internal accumulation where needed.
|
| 10 |
+
|
| 11 |
+
#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <limits>
#include <vector>
|
| 14 |
+
|
| 15 |
+
#define BLOCK_SIZE 256
|
| 16 |
+
#define WARP_SIZE 32
|
| 17 |
+
|
| 18 |
+
// ============================================================================
|
| 19 |
+
// Warp-level reduction
|
| 20 |
+
// ============================================================================
|
| 21 |
+
|
| 22 |
+
// Butterfly (XOR-shuffle) sum across the 32 lanes of a warp.
//
// Must be called by all 32 lanes together: the shuffle uses the full
// 0xffffffff participation mask. After the last step every lane holds the
// same value, the sum of the inputs of all lanes in the warp.
__device__ __forceinline__ float warp_reduce_sum(float val) {
    float acc = val;
#pragma unroll
    for (int stride = WARP_SIZE >> 1; stride >= 1; stride /= 2) {
        // Exchange with the lane whose index differs in exactly this bit.
        acc += __shfl_xor_sync(0xffffffffu, acc, stride);
    }
    return acc;
}
|
| 28 |
+
|
| 29 |
+
// ============================================================================
|
| 30 |
+
// Kernel 1: Fused RMSNorm
|
| 31 |
+
// ============================================================================
|
| 32 |
+
//
|
| 33 |
+
// Replaces: x.float() → pow(2) → mean → add(eps) → rsqrt → mul(x) → half → mul(weight)
|
| 34 |
+
// One block per row. 256 threads, each handles dim/256 elements.
|
| 35 |
+
// Warp shuffle + shared memory for cross-warp reduction.
|
| 36 |
+
|
| 37 |
+
// RMSNorm over the last dimension: out[row, i] = x[row, i] * rsqrt(mean(x[row]^2) + eps) * w[i].
//
// Launch layout: one block per row (blockIdx.x = row), BLOCK_SIZE threads per
// block. The strided loops and the shared-memory sizing both assume
// blockDim.x == BLOCK_SIZE.
// Shared memory: BLOCK_SIZE / WARP_SIZE partial sums plus one broadcast float.
//
//   x    [rows, dim] fp16, row-major contiguous
//   w    [dim]       fp16
//   out  [rows, dim] fp16, row-major contiguous
//   dim  row length in elements
//   eps  added to the mean of squares before the reciprocal square root
__global__ void fused_rmsnorm_kernel(
    const half* __restrict__ x,
    const half* __restrict__ w,
    half* __restrict__ out,
    int dim,
    float eps
) {
    constexpr int kNumWarps = BLOCK_SIZE / WARP_SIZE;
    __shared__ float warp_sums[kNumWarps];
    __shared__ float s_norm;

    const int tid = threadIdx.x;
    const int lane = tid % WARP_SIZE;
    const int warp_id = tid / WARP_SIZE;

    // 64-bit row offset so rows * dim cannot overflow int.
    const int64_t row_offset = (int64_t)blockIdx.x * dim;
    const half* src = x + row_offset;
    half* dst = out + row_offset;

    // Per-thread sum of squares, accumulated in fp32.
    float sq_sum = 0.0f;
    for (int col = tid; col < dim; col += BLOCK_SIZE) {
        const float v = __half2float(src[col]);
        sq_sum += v * v;
    }

    // Reduce within each warp, then publish one partial per warp.
    sq_sum = warp_reduce_sum(sq_sum);
    if (lane == 0) {
        warp_sums[warp_id] = sq_sum;
    }
    __syncthreads();

    // Warp 0 folds the per-warp partials; thread 0 publishes the scale factor.
    if (warp_id == 0) {
        float total = (lane < kNumWarps) ? warp_sums[lane] : 0.0f;
        total = warp_reduce_sum(total);
        if (lane == 0) {
            s_norm = rsqrtf(total / (float)dim + eps);
        }
    }
    __syncthreads();

    // Apply the normalisation factor and the per-channel weight.
    const float scale = s_norm;
    for (int col = tid; col < dim; col += BLOCK_SIZE) {
        const float v = __half2float(src[col]);
        const float gain = __half2float(w[col]);
        dst[col] = __float2half(v * scale * gain);
    }
}
|
| 87 |
+
|
| 88 |
+
// Host wrapper for fused_rmsnorm_kernel.
//
//   x       CUDA fp16 tensor of any rank >= 1; normalised over its last dim.
//   weight  CUDA fp16 tensor on the same device with exactly x.size(-1) elements.
//   eps     stabiliser added to the mean of squares.
//
// Returns a new contiguous fp16 tensor with the same shape as x.
// Raises (via TORCH_CHECK) on wrong device/dtype, mismatched weight length,
// or sizes that do not fit the kernel's 32-bit indices.
torch::Tensor fused_rmsnorm(torch::Tensor x, torch::Tensor weight, float eps) {
    TORCH_CHECK(x.is_cuda() && x.scalar_type() == torch::kHalf,
                "fused_rmsnorm: x must be a CUDA fp16 tensor");
    TORCH_CHECK(weight.is_cuda() && weight.scalar_type() == torch::kHalf,
                "fused_rmsnorm: weight must be a CUDA fp16 tensor");
    TORCH_CHECK(x.device() == weight.device(),
                "fused_rmsnorm: x and weight must be on the same device");
    TORCH_CHECK(x.dim() >= 1, "fused_rmsnorm: x must have at least one dimension");

    auto x_c = x.contiguous();
    // The kernel indexes weight as a dense [dim] array.
    auto w_c = weight.contiguous();
    auto out = torch::empty_like(x_c);

    const int64_t dim = x_c.size(-1);
    TORCH_CHECK(w_c.numel() == dim,
                "fused_rmsnorm: weight has ", w_c.numel(),
                " elements but x.size(-1) is ", dim);

    // Nothing to normalise. This also avoids dividing by dim == 0 and
    // launching a zero-sized grid, which CUDA rejects.
    if (x_c.numel() == 0) {
        return out;
    }

    const int64_t rows = x_c.numel() / dim;
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(dim <= kIntMax && rows <= kIntMax,
                "fused_rmsnorm: tensor too large (rows=", rows, ", dim=", dim, ")");

    // Launch on x's device and on PyTorch's current stream so the kernel is
    // ordered with the surrounding ops (non-default streams, graph capture).
    const c10::cuda::CUDAGuard device_guard(x_c.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    fused_rmsnorm_kernel<<<static_cast<unsigned int>(rows), BLOCK_SIZE, 0, stream>>>(
        reinterpret_cast<const half*>(x_c.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(w_c.data_ptr<at::Half>()),
        reinterpret_cast<half*>(out.data_ptr<at::Half>()),
        static_cast<int>(dim), eps
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return out;
}
|
| 106 |
+
|
| 107 |
+
// ============================================================================
|
| 108 |
+
// Kernel 2: Fused Half-Rotation RoPE (in-place on contiguous copy)
|
| 109 |
+
// ============================================================================
|
| 110 |
+
//
|
| 111 |
+
// Half-rotation format: q[..., :hd/2] = real, q[..., hd/2:] = imaginary
|
| 112 |
+
// Replaces: float() cast + 4 slices + 4 muls + 2 subs + 2 cats + half() cast
|
| 113 |
+
//
|
| 114 |
+
// Works for both decode (T=1) and prefill (T>1).
|
| 115 |
+
// Data layout after .contiguous(): q[h, t, d] = q_ptr[h*T*hd + t*hd + d]
|
| 116 |
+
|
| 117 |
+
// Half-rotation RoPE applied in place.
//
// Each thread rotates one (real, imag) pair, where for a given head h and
// position t the pair i is (data[h, t, i], data[h, t, half_dim + i]) and the
// angle comes from cos_ptr[t, i] / sin_ptr[t, i].
//
// Launch layout: 1-D grid with at least n_heads * seq_len * (head_dim / 2)
// threads in total; surplus threads exit immediately. No shared memory.
// Arithmetic is done in fp32 and rounded back to fp16 on store.
__global__ void fused_rope_kernel(
    half* __restrict__ data,            // [n_heads, T, head_dim] contiguous
    const half* __restrict__ cos_ptr,   // [T, half_dim]
    const half* __restrict__ sin_ptr,   // [T, half_dim]
    int n_heads,
    int seq_len,
    int head_dim
) {
    const int half_dim = head_dim >> 1;
    const int num_pairs = n_heads * seq_len * half_dim;
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;

    if (gid < num_pairs) {
        // Decompose the flat pair index into (head, position, pair-in-head).
        const int pair = gid % half_dim;
        const int row = gid / half_dim;   // == head * seq_len + position
        const int pos = row % seq_len;
        const int head = row / seq_len;

        const int lo = head * seq_len * head_dim + pos * head_dim + pair;
        const int hi = lo + half_dim;
        const int angle = pos * half_dim + pair;

        const float c = __half2float(cos_ptr[angle]);
        const float s = __half2float(sin_ptr[angle]);
        const float re = __half2float(data[lo]);
        const float im = __half2float(data[hi]);

        // Complex multiply (re + i*im) * (c + i*s).
        data[lo] = __float2half(re * c - im * s);
        data[hi] = __float2half(re * s + im * c);
    }
}
|
| 145 |
+
|
| 146 |
+
// Host wrapper: applies half-rotation RoPE to q and k.
//
//   q    [B, n_q_heads,  T, head_dim] CUDA fp16
//   k    [B, n_kv_heads, T, head_dim] CUDA fp16
//   cos  CUDA fp16, at least T * head_dim/2 elements, read as [T, head_dim/2]
//   sin  CUDA fp16, same layout as cos
//
// Returns {q_rot, k_rot}, both contiguous.
//
// NOTE: the rotation is written into the result of .contiguous(). When an
// input is already contiguous, .contiguous() returns the input itself, so in
// that case the caller's tensor is modified in place and also returned.
//
// The same cos/sin rows are used for every batch item, so a contiguous
// [B, n, T, hd] tensor is processed as [B*n, T, hd] in a single launch.
//
// Raises (via TORCH_CHECK) on wrong device/dtype, inconsistent shapes, odd
// head_dim, cos/sin tables that are too short, or tensors too large for the
// kernel's 32-bit indices.
std::vector<torch::Tensor> fused_rope_hf(
    torch::Tensor q,    // [B, n_q_heads, T, head_dim]
    torch::Tensor k,    // [B, n_kv_heads, T, head_dim]
    torch::Tensor cos,  // [T, head_dim/2]
    torch::Tensor sin   // [T, head_dim/2]
) {
    const auto is_cuda_half = [](const torch::Tensor& t) {
        return t.is_cuda() && t.scalar_type() == torch::kHalf;
    };
    TORCH_CHECK(is_cuda_half(q), "fused_rope_hf: q must be a CUDA fp16 tensor");
    TORCH_CHECK(is_cuda_half(k), "fused_rope_hf: k must be a CUDA fp16 tensor");
    TORCH_CHECK(is_cuda_half(cos), "fused_rope_hf: cos must be a CUDA fp16 tensor");
    TORCH_CHECK(is_cuda_half(sin), "fused_rope_hf: sin must be a CUDA fp16 tensor");
    TORCH_CHECK(k.device() == q.device() && cos.device() == q.device() &&
                    sin.device() == q.device(),
                "fused_rope_hf: q, k, cos and sin must be on the same device");
    TORCH_CHECK(q.dim() == 4 && k.dim() == 4,
                "fused_rope_hf: q and k must be 4-D [B, heads, T, head_dim]");

    // Make contiguous copies (needed because transpose makes q/k non-contiguous)
    auto q_c = q.contiguous();
    auto k_c = k.contiguous();
    auto cos_c = cos.contiguous();
    auto sin_c = sin.contiguous();

    const int64_t B = q_c.size(0);
    const int64_t n_q = q_c.size(1);
    const int64_t T = q_c.size(2);
    const int64_t hd = q_c.size(3);
    const int64_t n_kv = k_c.size(1);
    const int64_t half_dim = hd / 2;

    TORCH_CHECK(hd % 2 == 0, "fused_rope_hf: head_dim must be even, got ", hd);
    TORCH_CHECK(k_c.size(0) == B && k_c.size(2) == T && k_c.size(3) == hd,
                "fused_rope_hf: k must match q in batch, sequence and head_dim");
    // The kernel reads cos/sin at offsets up to T * half_dim - 1.
    TORCH_CHECK(cos_c.numel() >= T * half_dim && sin_c.numel() >= T * half_dim,
                "fused_rope_hf: cos/sin need at least T * head_dim/2 = ",
                T * half_dim, " elements");
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    TORCH_CHECK(q_c.numel() <= kIntMax && k_c.numel() <= kIntMax,
                "fused_rope_hf: q/k too large for 32-bit kernel indexing");

    // Launch on q's device and on PyTorch's current stream.
    const c10::cuda::CUDAGuard device_guard(q_c.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    const half* cos_ptr = reinterpret_cast<const half*>(cos_c.data_ptr<at::Half>());
    const half* sin_ptr = reinterpret_cast<const half*>(sin_c.data_ptr<at::Half>());

    // Rotates one tensor, treating batch and heads as one fused leading axis.
    const auto rotate = [&](torch::Tensor& t, int64_t fused_heads) {
        const int64_t pairs = fused_heads * T * half_dim;
        if (pairs == 0) {
            return;  // a zero-sized grid is an invalid launch configuration
        }
        const int blocks = static_cast<int>((pairs + BLOCK_SIZE - 1) / BLOCK_SIZE);
        fused_rope_kernel<<<blocks, BLOCK_SIZE, 0, stream>>>(
            reinterpret_cast<half*>(t.data_ptr<at::Half>()),
            cos_ptr,
            sin_ptr,
            static_cast<int>(fused_heads), static_cast<int>(T), static_cast<int>(hd)
        );
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    };

    rotate(q_c, B * n_q);
    rotate(k_c, B * n_kv);

    return {q_c, k_c};
}
|
| 192 |
+
|
| 193 |
+
// ============================================================================
|
| 194 |
+
// Kernel 3: Fused SiLU * Multiply
|
| 195 |
+
// ============================================================================
|
| 196 |
+
//
|
| 197 |
+
// Replaces: F.silu(gate) * up (2 separate kernels)
|
| 198 |
+
// silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
|
| 199 |
+
|
| 200 |
+
// Elementwise out[i] = silu(gate[i]) * up[i], with silu(g) = g / (1 + exp(-g)).
//
// Launch layout: 1-D grid with at least n threads in total; threads whose
// index is >= n do nothing. No shared memory. Inputs are widened to fp32 for
// the arithmetic and the result is rounded back to fp16.
__global__ void fused_silu_mul_kernel(
    const half* __restrict__ gate,
    const half* __restrict__ up,
    half* __restrict__ out,
    int n
) {
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= n) {
        return;
    }

    const float g = __half2float(gate[gid]);
    const float u = __half2float(up[gid]);
    out[gid] = __float2half((g / (1.0f + expf(-g))) * u);
}
|
| 213 |
+
|
| 214 |
+
// Host wrapper for fused_silu_mul_kernel: returns silu(gate) * up.
//
//   gate  CUDA fp16 tensor
//   up    CUDA fp16 tensor on the same device with the same number of elements
//
// Returns a new contiguous fp16 tensor shaped like gate.
// Raises (via TORCH_CHECK) on wrong device/dtype, mismatched element counts,
// or more elements than the kernel's 32-bit index can address.
torch::Tensor fused_silu_mul(torch::Tensor gate, torch::Tensor up) {
    TORCH_CHECK(gate.is_cuda() && gate.scalar_type() == torch::kHalf,
                "fused_silu_mul: gate must be a CUDA fp16 tensor");
    TORCH_CHECK(up.is_cuda() && up.scalar_type() == torch::kHalf,
                "fused_silu_mul: up must be a CUDA fp16 tensor");
    TORCH_CHECK(gate.device() == up.device(),
                "fused_silu_mul: gate and up must be on the same device");

    auto gate_c = gate.contiguous();
    auto up_c = up.contiguous();
    auto out = torch::empty_like(gate_c);

    const int64_t n = gate_c.numel();
    // The kernel reads up[i] for every i < n, so a shorter `up` would be an
    // out-of-bounds read.
    TORCH_CHECK(up_c.numel() == n,
                "fused_silu_mul: gate has ", n, " elements but up has ", up_c.numel());
    TORCH_CHECK(n <= static_cast<int64_t>(std::numeric_limits<int>::max()),
                "fused_silu_mul: tensor too large for 32-bit kernel indexing");

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (n == 0) {
        return out;
    }

    // Launch on gate's device and on PyTorch's current stream.
    const c10::cuda::CUDAGuard device_guard(gate_c.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    const int blocks = static_cast<int>((n + BLOCK_SIZE - 1) / BLOCK_SIZE);
    fused_silu_mul_kernel<<<blocks, BLOCK_SIZE, 0, stream>>>(
        reinterpret_cast<const half*>(gate_c.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(up_c.data_ptr<at::Half>()),
        reinterpret_cast<half*>(out.data_ptr<at::Half>()),
        static_cast<int>(n)
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return out;
}
|
| 232 |
+
|
| 233 |
+
// ============================================================================
|
| 234 |
+
// Module
|
| 235 |
+
// ============================================================================
|
| 236 |
+
|
| 237 |
+
// Python bindings: registers the three host wrappers defined above on the
// extension module, whose name is supplied at build time via TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 238 |
+
m.def("fused_rmsnorm", &fused_rmsnorm, "Fused RMSNorm (fp16 I/O, fp32 accumulation)");
|
| 239 |
+
m.def("fused_rope_hf", &fused_rope_hf, "Fused half-rotation RoPE (fp16, in-place on contiguous copy)");
|
| 240 |
+
m.def("fused_silu_mul", &fused_silu_mul, "Fused SiLU * mul (fp16)");
|
| 241 |
+
}
|
params.json
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"dim": 3072,
|
| 3 |
+
"n_layers": 26,
|
| 4 |
+
"head_dim": 128,
|
| 5 |
+
"hidden_dim": 9216,
|
| 6 |
+
"n_heads": 32,
|
| 7 |
+
"n_kv_heads": 8,
|
| 8 |
+
"use_biases": false,
|
| 9 |
+
"causal": true,
|
| 10 |
+
"rope_theta": 1000000.0,
|
| 11 |
+
"norm_eps": 1e-05,
|
| 12 |
+
"vocab_size": 131072,
|
| 13 |
+
"model_parallel": 1,
|
| 14 |
+
"tied_embeddings": true,
|
| 15 |
+
"sliding_window": 8192,
|
| 16 |
+
"model_max_length": 131072,
|
| 17 |
+
"multimodal": {
|
| 18 |
+
"whisper_model_args": {
|
| 19 |
+
"encoder_args": {
|
| 20 |
+
"audio_encoding_args": {
|
| 21 |
+
"sampling_rate": 16000,
|
| 22 |
+
"frame_rate": 12.5,
|
| 23 |
+
"num_mel_bins": 128,
|
| 24 |
+
"hop_length": 160,
|
| 25 |
+
"window_size": 400,
|
| 26 |
+
"chunk_length_s": null,
|
| 27 |
+
"global_log_mel_max": 1.5,
|
| 28 |
+
"transcription_format": "streaming"
|
| 29 |
+
},
|
| 30 |
+
"dim": 1280,
|
| 31 |
+
"n_layers": 32,
|
| 32 |
+
"head_dim": 64,
|
| 33 |
+
"hidden_dim": 5120,
|
| 34 |
+
"n_heads": 32,
|
| 35 |
+
"vocab_size": 131072,
|
| 36 |
+
"n_kv_heads": 32,
|
| 37 |
+
"use_biases": true,
|
| 38 |
+
"use_cache": false,
|
| 39 |
+
"rope_theta": 1000000.0,
|
| 40 |
+
"causal": true,
|
| 41 |
+
"norm_eps": 1e-05,
|
| 42 |
+
"pos_embed": "rope",
|
| 43 |
+
"max_source_positions": null,
|
| 44 |
+
"ffn_type": "swiglu",
|
| 45 |
+
"norm_type": "rms_norm",
|
| 46 |
+
"sliding_window": 750,
|
| 47 |
+
"ragged_attention": "750"
|
| 48 |
+
},
|
| 49 |
+
"downsample_args": {
|
| 50 |
+
"downsample_factor": 4
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
},
|
| 54 |
+
"ada_rms_norm_t_cond": true,
|
| 55 |
+
"ada_rms_norm_t_cond_dim": 32,
|
| 56 |
+
"quantization_config": {
|
| 57 |
+
"quant_method": "gptq",
|
| 58 |
+
"bits": 4,
|
| 59 |
+
"group_size": 128,
|
| 60 |
+
"desc_act": false,
|
| 61 |
+
"sym": true,
|
| 62 |
+
"checkpoint_format": "gptq",
|
| 63 |
+
"pack_dtype": "int32"
|
| 64 |
+
}
|
| 65 |
+
}
|
scripts/jetson_serve_sdpa.py
ADDED
|
@@ -0,0 +1,1277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Voxtral Mini 4B Realtime — Jetson Orin Nano 8GB inference server.
|
| 3 |
+
|
| 4 |
+
Loads INT4-packed GPTQ weights from Mistral native format and serves
|
| 5 |
+
transcription via WebSocket at ws://localhost:8000/v1/realtime.
|
| 6 |
+
|
| 7 |
+
Key architecture detail: at each decoder position, the input embedding is
|
| 8 |
+
adapter_out[pos] + tok_embed(token_id)
|
| 9 |
+
where token_id is BOS/STREAMING_PAD during the prompt, then previously
|
| 10 |
+
generated text tokens during generation. Generation runs for exactly
|
| 11 |
+
T_adapter total positions.
|
| 12 |
+
|
| 13 |
+
Memory budget: ~6 GB total (model 4.08 GB + runtime 1.5 GB + KV cache 0.2 GB).
|
| 14 |
+
|
| 15 |
+
Optimizations over baseline:
|
| 16 |
+
- Marlin fused INT4 dequant+matmul (24x faster generation)
|
| 17 |
+
- F.scaled_dot_product_attention (fused attention kernel)
|
| 18 |
+
- Pre-allocated KV cache (eliminates torch.cat per token per layer)
|
| 19 |
+
- Optional torch.compile (fuses pointwise ops between matmuls)
|
| 20 |
+
|
| 21 |
+
Usage (inside PyTorch Jetson container):
|
| 22 |
+
pip install safetensors websockets soundfile numpy librosa
|
| 23 |
+
python3 jetson_serve.py --test /workspace/voxtral-quant/test_audio_0.wav
|
| 24 |
+
python3 jetson_serve.py # starts WebSocket server on port 8000
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import asyncio
|
| 29 |
+
import base64
|
| 30 |
+
import json
|
| 31 |
+
import math
|
| 32 |
+
import os
|
| 33 |
+
import sys
|
| 34 |
+
import time
|
| 35 |
+
import uuid
|
| 36 |
+
import warnings
|
| 37 |
+
from typing import List, Optional, Tuple
|
| 38 |
+
|
| 39 |
+
# Suppress harmless warnings
|
| 40 |
+
warnings.filterwarnings('ignore', message='.*divide by zero.*')
|
| 41 |
+
warnings.filterwarnings('ignore', message='.*FutureWarning.*tokenizer.*')
|
| 42 |
+
|
| 43 |
+
# Set CUDA allocator config before importing torch (critical for Jetson)
|
| 44 |
+
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
|
| 45 |
+
|
| 46 |
+
import numpy as np
|
| 47 |
+
import torch
|
| 48 |
+
import torch.nn as nn
|
| 49 |
+
import torch.nn.functional as F
|
| 50 |
+
from safetensors import safe_open
|
| 51 |
+
|
| 52 |
+
# Try to import Marlin fused INT4 kernel (50x faster than on-the-fly dequantization)
|
| 53 |
+
try:
|
| 54 |
+
import marlin as _marlin
|
| 55 |
+
HAS_MARLIN = True
|
| 56 |
+
except ImportError:
|
| 57 |
+
HAS_MARLIN = False
|
| 58 |
+
|
| 59 |
+
# Try to JIT-compile fused CUDA kernels (collapses ~500 kernel launches/token to ~80)
|
| 60 |
+
HAS_FUSED = False
|
| 61 |
+
try:
|
| 62 |
+
from torch.utils.cpp_extension import load as _load_ext
|
| 63 |
+
_cu_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
| 64 |
+
'..', 'kernels', 'fused_ops.cu')
|
| 65 |
+
if not os.path.exists(_cu_path):
|
| 66 |
+
# Fallback for container mount path
|
| 67 |
+
_cu_path = '/workspace/voxtral-quant/kernels/fused_ops.cu'
|
| 68 |
+
if os.path.exists(_cu_path):
|
| 69 |
+
voxtral_kernels = _load_ext(
|
| 70 |
+
name='voxtral_kernels',
|
| 71 |
+
sources=[_cu_path],
|
| 72 |
+
extra_cuda_cflags=['-O3', '--use_fast_math', '-arch=sm_87'],
|
| 73 |
+
verbose=True
|
| 74 |
+
)
|
| 75 |
+
HAS_FUSED = True
|
| 76 |
+
print(f"Fused CUDA kernels loaded from {_cu_path}")
|
| 77 |
+
else:
|
| 78 |
+
print(f"Fused kernels .cu not found, using PyTorch fallback")
|
| 79 |
+
except Exception as e:
|
| 80 |
+
print(f"Fused kernel JIT failed ({e}), using PyTorch fallback")
|
| 81 |
+
|
| 82 |
+
# ─── Constants ──────────────────────────────────────────────────────────────
|
| 83 |
+
|
| 84 |
+
BITS = 4
|
| 85 |
+
GROUP_SIZE = 128
|
| 86 |
+
PACK_FACTOR = 32 // BITS # 8 int4 values per int32
|
| 87 |
+
BIAS = 1 << (BITS - 1) # 8 (uint4b8 encoding)
|
| 88 |
+
|
| 89 |
+
TOKEN_BOS = 1
|
| 90 |
+
TOKEN_EOS = 2
|
| 91 |
+
TOKEN_STREAMING_PAD = 32
|
| 92 |
+
|
| 93 |
+
SAMPLE_RATE = 16000
|
| 94 |
+
HOP_LENGTH = 160
|
| 95 |
+
N_MELS = 128
|
| 96 |
+
WINDOW_SIZE = 400
|
| 97 |
+
GLOBAL_LOG_MEL_MAX = 1.5 # Fixed constant from params.json (NOT per-sample)
|
| 98 |
+
RAW_AUDIO_PER_TOK = 1280 # Raw audio samples per adapter token
|
| 99 |
+
N_LEFT_PAD_TOKENS = 32 # Left silence padding (aligns with SPAD prompt)
|
| 100 |
+
N_RIGHT_PAD_BASE = 17 # Right silence padding (delay+1 + OFFLINE_BUFFER)
|
| 101 |
+
DOWNSAMPLE_FACTOR = 4
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# ─── Marlin Fused INT4 Linear ───────────────────────────────────────────────
|
| 105 |
+
|
| 106 |
+
class MarlinLinear(nn.Module):
    """Linear layer using Marlin fused INT4 dequant+matmul CUDA kernel.

    Repacks GPTQ INT4 weights into Marlin's optimized format at construction time.
    Forward pass is a single fused kernel call — ~50x faster than on-the-fly dequant.
    Memory footprint is identical to GPTQ INT4 (no extra memory needed).

    Args:
        qweight: packed int32 tensor, [in_features / PACK_FACTOR, out_features].
        scales: per-group scales, [n_groups, out_features].
        qzeros: not read by this class; the zero point is the fixed BIAS.
        unpermute: optional (n_heads, hidden_size). When given, output rows are
            regrouped per head from [2, head_dim/2] to [head_dim/2, 2] order.
    """

    def __init__(self, qweight, scales, qzeros, unpermute=None):
        super().__init__()
        in_features = qweight.shape[0] * PACK_FACTOR
        out_features = qweight.shape[1]
        n_groups = scales.shape[0]
        self.in_features = in_features
        self.out_features = out_features

        # Dequantize GPTQ → fp16, then repack into Marlin format
        shifts = torch.arange(0, 32, BITS, device=qweight.device, dtype=torch.int32)
        unpacked = (qweight.unsqueeze(0) >> shifts.view(-1, 1, 1)) & 0xF
        unpacked = unpacked.permute(1, 0, 2).reshape(in_features, out_features)
        unpacked = unpacked.T.reshape(out_features, n_groups, GROUP_SIZE)
        s = scales.T.float().unsqueeze(-1)
        w_fp16 = ((unpacked.float() - BIAS) * s).reshape(out_features, in_features).half()
        del unpacked, s

        # Scales as handed to Marlin's pack(): one row per output channel.
        pack_scales = scales.T  # [out_features, n_groups]

        if unpermute is not None:
            n_heads, hidden_size = unpermute
            head_dim = w_fp16.shape[0] // n_heads
            w_fp16 = (w_fp16.view(n_heads, 2, head_dim // 2, hidden_size)
                      .transpose(1, 2)
                      .reshape(out_features, in_features))
            # The weight rows (output channels) were just reordered, so the
            # per-channel scales must follow the same reordering. Otherwise
            # pack() divides each row by another row's scale when it
            # re-quantizes and the packed weights are wrong.
            pack_scales = (pack_scales.reshape(n_heads, 2, head_dim // 2, n_groups)
                           .transpose(1, 2)
                           .reshape(out_features, n_groups))

        # Create temporary nn.Linear for Marlin's pack()
        linear = nn.Linear(in_features, out_features, bias=False,
                           dtype=torch.half, device=qweight.device)
        linear.weight.data = w_fp16

        # Create Marlin layer and pack (handles permutation + bit packing)
        ml = _marlin.Layer(in_features, out_features, groupsize=GROUP_SIZE)
        ml.pack(linear, pack_scales)
        del linear, w_fp16

        # Store Marlin buffers
        self.register_buffer('B', ml.B.to(qweight.device))
        self.register_buffer('s', ml.s.to(qweight.device))
        self.register_buffer('workspace',
                             torch.zeros(out_features // 128 * 16,
                                         dtype=torch.int, device=qweight.device),
                             persistent=False)

    def forward(self, x):
        """Return x @ W^T with shape x.shape[:-1] + (out_features,)."""
        out_shape = x.shape[:-1] + (self.out_features,)
        C = torch.empty(out_shape, dtype=x.dtype, device=x.device)
        # Flatten leading dims and pass a dense row-major matrix to the
        # kernel; a strided view of x would otherwise be handed over as-is.
        a = x.reshape(-1, x.shape[-1]).contiguous()
        _marlin.mul(a, self.B, C.view(-1, self.out_features),
                    self.s, self.workspace)
        return C
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ─── GPTQ INT4 Dequantization (fallback when Marlin unavailable) ────────────
|
| 165 |
+
|
| 166 |
+
class DequantLinear(nn.Module):
    """Linear layer with INT4 GPTQ packed weights.

    Supports two modes:
    - On-the-fly dequantization (default): dequantizes on each forward call
    - Cached mode: stores a pre-dequantized fp16 weight for a plain matmul

    `qzeros` is stored as a buffer but is not read by `_dequantize`; the
    zero point applied there is the module-level `BIAS` constant.
    """

    _shifts = None  # class-level cached shifts tensor, shared by all instances

    def __init__(self, qweight, scales, qzeros, unpermute=None):
        """Wrap packed weights.

        Args:
            qweight: packed integer weights, shape [in_features / PACK_FACTOR, out_features].
            scales: per-group scales, shape [n_groups, out_features].
            qzeros: packed zero points (kept for completeness, unused here).
            unpermute: optional `(n_heads, hidden_size)` pair; when given, the
                dequantized weight rows are re-ordered per head in `_dequantize`.
        """
        super().__init__()
        self.register_buffer('qweight', qweight)
        self.register_buffer('scales', scales)
        self.register_buffer('qzeros', qzeros)
        self.in_features = qweight.shape[0] * PACK_FACTOR
        self.out_features = qweight.shape[1]
        self.unpermute = unpermute
        self._cached_w = None  # pre-dequantized fp16 weight [out, in]

    def cache_weight(self, free_int4=True):
        """Pre-dequantize and cache the fp16 weight.

        If free_int4=True, the INT4 buffers are dropped afterwards (saves
        memory, not reversible without reloading). Calling this again once a
        weight is cached is a no-op apart from the optional free, so it is
        safe to call repeatedly.
        """
        if self._cached_w is None:
            self._cached_w = self._dequantize()
        if free_int4:
            self.qweight = None
            self.scales = None
            self.qzeros = None

    def uncache_weight(self):
        """Free the cached weight (e.g., before re-loading INT4 weights)."""
        self._cached_w = None

    @property
    def cached_bytes(self):
        """Memory used by the cached weight in bytes (0 when nothing is cached)."""
        cached = self._cached_w
        if cached is None:
            return 0
        return cached.nelement() * cached.element_size()

    def _dequantize(self):
        """Dequantize INT4 packed weights to fp16 [out, in].

        Raises:
            RuntimeError: if the INT4 buffers were freed by `cache_weight`.
        """
        qw = self.qweight
        if qw is None or self.scales is None:
            raise RuntimeError(
                "DequantLinear: INT4 buffers were freed by cache_weight(free_int4=True); "
                "reload the packed weights before dequantizing again")
        out = qw.shape[1]
        n_groups = self.scales.shape[0]

        # Shift amounts 0, BITS, 2*BITS, ... are identical for every layer, so
        # build them once per device and reuse them.
        cached_shifts = DequantLinear._shifts
        if cached_shifts is None or cached_shifts.device != qw.device:
            cached_shifts = torch.arange(0, 32, BITS, device=qw.device, dtype=torch.int32)
            DequantLinear._shifts = cached_shifts

        # Vectorized unpack: one slice per nibble position -> [8, in/8, out]
        unpacked = (qw.unsqueeze(0) >> cached_shifts.view(-1, 1, 1)) & 0xF
        # Interleave to [in, out], then transpose + group to [out, groups, GROUP_SIZE]
        unpacked = unpacked.permute(1, 0, 2).reshape(self.in_features, out)
        unpacked = unpacked.T.reshape(out, n_groups, GROUP_SIZE)
        # Symmetric dequantization: (val - BIAS) * scale, computed in float32
        s = self.scales.T.float().unsqueeze(-1)
        w = ((unpacked.float() - BIAS) * s).reshape(out, self.in_features).half()
        del unpacked, s  # release the large float32 temporaries early

        if self.unpermute is not None:
            n_heads, hidden_size = self.unpermute
            head_dim = w.shape[0] // n_heads
            # Swap the (2, head_dim/2) sub-axes of every head's rows.
            w = (w.view(n_heads, 2, head_dim // 2, hidden_size)
                 .transpose(1, 2)
                 .reshape(out, self.in_features))
        return w

    def forward(self, x):
        """Apply the linear map, using the cached fp16 weight when present."""
        if self._cached_w is not None:
            return F.linear(x, self._cached_w)
        w = self._dequantize()
        result = F.linear(x, w)
        del w
        return result
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
# ─── Building Blocks ─────────────────────────────────────────────────────────
|
| 246 |
+
|
| 247 |
+
class RMSNorm(nn.Module):
    """RMS normalization over the last dimension, computed in float32.

    The input is up-cast to float32, scaled by the reciprocal root of its
    mean square (plus `eps`), cast back to the input dtype and multiplied by
    the learned per-channel `weight`.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        xf = x.float()  # convert once and reuse for both the statistic and the product
        inv_rms = (xf.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
        return (xf * inv_rms).type_as(x) * self.weight
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class RMSNormFP16(nn.Module):
    """RMSNorm that stays in the input dtype — avoids the float32 round trip.

    The original author notes this is safe for the decoder hidden dim (3072),
    where values stay in fp16 range. Uses the fused CUDA kernel when
    `HAS_FUSED` is set; otherwise falls back to plain PyTorch ops.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if HAS_FUSED:
            return voxtral_kernels.fused_rmsnorm(x.contiguous(), self.weight, self.eps)
        # The in-place ops act on the freshly computed mean, never on `x`.
        inv_rms = x.pow(2).mean(-1, keepdim=True).add_(self.eps).rsqrt_()
        return x * inv_rms * self.weight
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def apply_rotary_emb(q, k, cos, sin):
    """Interleaved RoPE for Mistral-format Q/K.

    Args:
        q, k: [B, heads, T, dim]; adjacent channels (d0,d1), (d2,d3), ... form
            the rotated pairs.
        cos, sin: [T, dim/2] rotation tables.

    Returns:
        `(q_rot, k_rot)` with the shapes and dtypes of `q` and `k`; the
        rotation itself is evaluated in float32.
    """
    # Broadcast the tables over batch and head dimensions.
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)

    def rotate(t):
        # Split the last dim into (pair, 2), rotate each pair, re-interleave.
        even, odd = t.float().reshape(*t.shape[:-1], -1, 2).unbind(-1)
        rotated = torch.stack([even * cos - odd * sin, even * sin + odd * cos], -1)
        return rotated.flatten(-2).type_as(t)

    return rotate(q), rotate(k)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def apply_rotary_emb_hf(q, k, cos, sin):
    """Half-rotation RoPE for HF-format Q/K, kept in the input dtype.

    Args:
        q, k: [B, heads, T, dim]; channel i is paired with channel i + dim/2.
        cos, sin: [T, dim/2] rotation tables.

    Returns:
        `(q_rot, k_rot)`.

    Uses the fused CUDA kernel when `HAS_FUSED` is set (per the original
    author, ~14 PyTorch kernels → 2); otherwise the PyTorch fallback below.
    """
    if HAS_FUSED:
        return voxtral_kernels.fused_rope_hf(q, k, cos, sin)
    # Broadcast the tables over batch and head dimensions.
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    half = q.shape[-1] // 2

    def rotate(t):
        lo, hi = t[..., :half], t[..., half:]
        return torch.cat([lo * cos - hi * sin, lo * sin + hi * cos], -1)

    return rotate(q), rotate(k)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def make_freqs(dim, maxlen, theta=1e6, device='cpu', dtype=torch.float16):
    """Build RoPE cosine/sine tables.

    Args:
        dim: rotary dimension; dim/2 frequencies are generated.
        maxlen: number of positions (rows) to tabulate.
        theta: frequency base.
        device: device on which the tables are created.
        dtype: dtype the tables are cast to after being computed in float32.

    Returns:
        `(cos, sin)`, each of shape [maxlen, dim/2].
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
    # Explicit float32 positions so the outer product does not rely on
    # implicit int64/float32 promotion.
    positions = torch.arange(maxlen, device=device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    return angles.cos().to(dtype), angles.sin().to(dtype)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# ─── KV Cache ────────────────────────────────────────────────────────────────
|
| 311 |
+
|
| 312 |
+
class KVCache:
    """Pre-allocated KV cache that eliminates per-token torch.cat allocations.

    Instead of creating new tensors and concatenating on every token, fixed
    buffers are allocated once and written with an advancing position index.

    Memory cost: 2 * n_layers * n_kv_heads * max_seq * head_dim * element size.
    For 26 layers, 8 heads, 200 seq, 128 dim in fp16 that is ~21 MB.
    """

    def __init__(self, n_layers, max_seq_len, n_kv_heads, head_dim, device, dtype):
        self.max_seq_len = max_seq_len
        self.pos = 0  # number of positions already committed via advance()
        # Both buffers: [n_layers, 1, n_kv_heads, max_seq_len, head_dim]
        shape = (n_layers, 1, n_kv_heads, max_seq_len, head_dim)
        self.k = torch.zeros(*shape, device=device, dtype=dtype)
        self.v = torch.zeros(*shape, device=device, dtype=dtype)

    def update(self, layer_idx, k_new, v_new):
        """Write new K/V into the cache and return the full valid slice.

        Args:
            layer_idx: which decoder layer
            k_new, v_new: [1, n_kv_heads, T_new, head_dim] (T_new=prompt_len or 1)

        Returns:
            k_full, v_full: [1, n_kv_heads, pos+T_new, head_dim] (views into
            the cache buffers)

        Raises:
            RuntimeError: if the write would run past `max_seq_len`.
        """
        t_new = k_new.shape[2]
        end = self.pos + t_new
        if end > self.max_seq_len:
            # Without this check a one-token write at a full cache targets an
            # empty slice, which broadcasting accepts, so the token would be
            # dropped without any error.
            raise RuntimeError(
                f"KVCache overflow: writing {t_new} position(s) at {self.pos} "
                f"exceeds max_seq_len={self.max_seq_len}")
        self.k[layer_idx, :, :, self.pos:end, :] = k_new
        self.v[layer_idx, :, :, self.pos:end, :] = v_new
        return self.k[layer_idx, :, :, :end, :], self.v[layer_idx, :, :, :end, :]

    def advance(self, n):
        """Advance position after all layers have processed n new tokens."""
        self.pos += n

    def reset(self):
        """Rewind to position 0; buffer contents are left in place."""
        self.pos = 0
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
# ─── Audio Encoder ───────────────────────────────────────────────────────────
|
| 357 |
+
|
| 358 |
+
class EncAttn(nn.Module):
    """Multi-head self-attention block of the audio encoder.

    Args:
        h: model width.
        nh: number of attention heads.
        hd: per-head dimension; the attention width is `nh * hd`.

    The key projection is created without a bias; query, value and output
    projections have one.
    """

    def __init__(self, h=1280, nh=32, hd=64):
        super().__init__()
        self.nh, self.hd = nh, hd
        self.ad = nh * hd
        self.scale = hd ** -0.5
        self.wq = nn.Linear(h, self.ad, bias=True)
        self.wk = nn.Linear(h, self.ad, bias=False)
        self.wv = nn.Linear(h, self.ad, bias=True)
        self.wo = nn.Linear(self.ad, h, bias=True)

    def forward(self, x, cos, sin, mask):
        """Attend over `x` ([B, T, h]) with interleaved RoPE and an additive mask."""
        B, T, _ = x.shape

        def heads(t):
            # [B, T, nh*hd] -> [B, nh, T, hd]
            return t.view(B, T, self.nh, self.hd).transpose(1, 2)

        q, k, v = heads(self.wq(x)), heads(self.wk(x)), heads(self.wv(x))
        q, k = apply_rotary_emb(q, k, cos, sin)
        # Encoder uses causal attention with explicit mask
        a = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=self.scale)
        return self.wo(a.transpose(1, 2).reshape(B, T, self.ad))
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class EncLayer(nn.Module):
    """One pre-norm encoder layer: attention followed by a gated (SiLU) MLP.

    Args:
        h: model width.
        ff: hidden width of the feed-forward block.
    """

    def __init__(self, h=1280, ff=5120):
        super().__init__()
        self.attn = EncAttn(h)
        self.an = RMSNorm(h)   # norm applied before attention
        self.fn = RMSNorm(h)   # norm applied before the feed-forward block
        self.w1 = nn.Linear(h, ff, bias=False)
        self.w2 = nn.Linear(ff, h, bias=True)
        self.w3 = nn.Linear(h, ff, bias=False)

    def forward(self, x, cos, sin, mask):
        """Apply both residual sub-blocks to `x` and return the result."""
        x = x + self.attn(self.an(x), cos, sin, mask)
        h = self.fn(x)
        gated = F.silu(self.w1(h)) * self.w3(h)
        return x + self.w2(gated)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
class Encoder(nn.Module):
    """Audio encoder: two GELU convolutions followed by a stack of `EncLayer`s.

    Args:
        nl: number of transformer layers.
        h: model width.
        ff: feed-forward hidden width.
        hd: per-head dimension used to size the RoPE tables.
    """

    def __init__(self, nl=32, h=1280, ff=5120, hd=64):
        super().__init__()
        self.conv1 = nn.Conv1d(N_MELS, h, 3, padding=1)
        self.conv2 = nn.Conv1d(h, h, 3, stride=2, padding=1)  # stride 2 halves the time axis
        self.layers = nn.ModuleList([EncLayer(h, ff) for _ in range(nl)])
        self.norm = RMSNorm(h)
        self.hd = hd

    def forward(self, mel):
        """Encode a mel spectrogram.

        Args:
            mel: input to `conv1`; its channel dimension must be `N_MELS`.

        Returns:
            Normalized hidden states of shape [B, T, h], where T is the time
            length after the strided convolution.
        """
        x = F.gelu(self.conv1(mel))
        x = F.gelu(self.conv2(x)).transpose(1, 2)  # [B, h, T] -> [B, T, h]
        T = x.shape[1]
        cos, sin = make_freqs(self.hd, T, device=x.device, dtype=x.dtype)
        # Additive causal mask: -inf strictly above the diagonal, shaped [1, 1, T, T].
        neg_inf = torch.full((T, T), float('-inf'), device=x.device, dtype=x.dtype)
        mask = torch.triu(neg_inf, 1).unsqueeze(0).unsqueeze(0)
        for layer in self.layers:
            x = layer(x, cos, sin, mask)
        return self.norm(x)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
# ─── Projector ───────────────────────────────────────────────────────────────
|
| 417 |
+
|
| 418 |
+
class Projector(nn.Module):
    """Two-layer bias-free MLP with a GELU in between.

    Args:
        inp: input feature width.
        out: width of both the hidden and the output features.
    """

    def __init__(self, inp=5120, out=3072):
        super().__init__()
        self.l1 = nn.Linear(inp, out, bias=False)
        self.l2 = nn.Linear(out, out, bias=False)

    def forward(self, x):
        hidden = F.gelu(self.l1(x))
        return self.l2(hidden)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# ─── LM Decoder Layer ────────────────────────────────────────────────────────
|
| 429 |
+
|
| 430 |
+
class DecAttn(nn.Module):
    """Grouped-query self-attention for the LM decoder.

    Args:
        h: model width (not used when building this module; the projections
            are attached later by the loader).
        nh: number of query heads.
        nkv: number of key/value heads.
        hd: per-head dimension.
    """

    def __init__(self, h=3072, nh=32, nkv=8, hd=128):
        super().__init__()
        self.nh, self.nkv, self.hd = nh, nkv, hd
        self.qd, self.kvd = nh * hd, nkv * hd
        self.scale = hd ** -0.5
        self.g = nh // nkv  # query heads served by each KV head
        self.q_proj = self.k_proj = self.v_proj = self.o_proj = None  # set by loader

    def forward(self, x, cos, sin, cache=None, layer_idx=None, is_causal=False):
        """Run attention over `x` ([B, T, h]).

        Args:
            x: hidden states.
            cos, sin: RoPE tables for the positions in `x`.
            cache: optional KV cache; when given, the new K/V are written to
                it and attention runs over the slice it returns.
            layer_idx: layer index passed to `cache.update`.
            is_causal: forwarded to SDPA; the caller decides (True for
                prefill with T > 1, False for single-token decode).
        """
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        q, k = apply_rotary_emb_hf(q, k, cos, sin)

        # Update KV cache if available
        if cache is not None:
            k, v = cache.update(layer_idx, k, v)

        # GQA: expand KV heads to match query heads
        if self.g > 1:
            k = k.repeat_interleave(self.g, 1)
            v = v.repeat_interleave(self.g, 1)

        # Fused attention via SDPA
        #  - Prefill (T > 1): use is_causal=True for triangular mask
        #  - Decode (T == 1): no mask needed, single query attends to all past
        out = F.scaled_dot_product_attention(
            q, k, v, scale=self.scale, is_causal=is_causal)

        return self.o_proj(out.transpose(1, 2).reshape(B, T, self.qd))
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
class DecLayer(nn.Module):
    """One LM decoder layer: attention, then a gated MLP with an optional
    time-conditioned scale applied to the normalized MLP input.

    Args:
        h: model width.
        ad: bottleneck width of the conditioning MLP (`ada0` -> GELU -> `ada2`).
    """

    def __init__(self, h=3072, ad=32):
        super().__init__()
        self.an = RMSNormFP16(h)
        self.attn = DecAttn()
        self.fn = RMSNormFP16(h)
        self.gate_proj = self.up_proj = self.down_proj = None  # set by loader
        self.ada0 = nn.Linear(h, ad, bias=False)
        self.ada2 = nn.Linear(ad, h, bias=False)
        self._ada_scale = None  # pre-computed: 1 + ada2(gelu(ada0(t_cond)))

    def precompute_ada(self, t_cond):
        """Pre-compute the ada_rms_norm modulation scale from t_cond.

        When t_cond is constant this removes two matmuls, a GELU and an add
        from every forward call of the layer.
        """
        with torch.no_grad():
            self._ada_scale = (1.0 + self.ada2(F.gelu(self.ada0(t_cond)))).unsqueeze(0)

    def forward(self, x, cos, sin, cache=None, layer_idx=None,
                is_causal=False, t_cond=None):
        """Apply the layer to `x`.

        A scale stored by `precompute_ada` takes precedence; `t_cond` is only
        used to compute the scale on the fly when nothing was pre-computed.
        """
        x = x + self.attn(self.an(x), cos, sin, cache, layer_idx, is_causal)
        h = self.fn(x)
        if self._ada_scale is not None:
            h = h * self._ada_scale
        elif t_cond is not None:
            h = h * (1.0 + self.ada2(F.gelu(self.ada0(t_cond))).unsqueeze(0))
        gate_out = self.gate_proj(h)
        up_out = self.up_proj(h)
        if HAS_FUSED:
            act = voxtral_kernels.fused_silu_mul(gate_out, up_out)
        else:
            act = F.silu(gate_out) * up_out
        return x + self.down_proj(act)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# ─── Full Model ──────────────────────────────────────────────────────────────
|
| 501 |
+
|
| 502 |
+
class VoxtralModel:
|
| 503 |
+
def __init__(self, model_path, device='cuda', dtype=torch.float16, compile=False):
    """Read `params.json`, build and load the model, and precompute constants.

    Args:
        model_path: directory containing `params.json` and the weights.
        device: device the weights and tables are placed on.
        dtype: dtype for non-quantized weights and tables.
        compile: when True, decoder layers are passed through `_try_compile`.
    """
    self.device = device
    self.dtype = dtype
    self.model_path = model_path

    with open(os.path.join(model_path, 'params.json')) as f:
        self.p = json.load(f)

    whisper_args = self.p['multimodal']['whisper_model_args']
    ea = whisper_args['encoder_args']
    da = whisper_args['downsample_args']
    audio_args = ea.get('audio_encoding_args', {})

    self.ds_factor = da['downsample_factor']
    self.enc_h = ea['dim']
    self.n_layers = self.p['n_layers']
    self.h = self.p['dim']
    self.n_kv_heads = self.p.get('n_kv_heads', 8)
    self.head_dim = self.p.get('head_dim', 128)
    self.delay_ms = audio_args.get('transcription_delay_ms', 480)

    # Frame rate from config
    self.frame_rate = audio_args.get('frame_rate', 12.5)
    # delay_tokens = delay_ms / ms_per_frame
    ms_per_frame = 1000.0 / self.frame_rate
    self.delay_tokens = int(self.delay_ms / ms_per_frame)
    self.n_left_pad = 32  # streaming_n_left_pad_tokens from audio config

    self._build()
    self._load()
    if compile:
        self._try_compile()

    # Use the defaulted attribute: indexing self.p['head_dim'] directly would
    # raise KeyError for configs that rely on the default above.
    self.cos, self.sin = make_freqs(self.head_dim, 8192, device=device, dtype=dtype)

    # Time conditioning: sinusoidal embedding of delay token count
    self.t_cond = self._compute_time_embedding(
        float(self.delay_tokens), self.h
    ).to(device=device, dtype=dtype)

    # Pre-compute ada_rms_norm modulation for all decoder layers
    # (t_cond is constant, so ada output is constant per layer)
    for layer in self.layers:
        layer.precompute_ada(self.t_cond)
|
| 546 |
+
|
| 547 |
+
def _build(self):
    """Build the model skeleton on the meta device (no weight storage is allocated).

    Creates `encoder`, `projector`, `layers`, `embed` and `norm`; their
    parameters are replaced with real tensors by the loader.
    """
    import gc
    ea = self.p['multimodal']['whisper_model_args']['encoder_args']
    ada_dim = self.p.get('ada_rms_norm_t_cond_dim', 32)
    with torch.device('meta'):
        self.encoder = Encoder(ea['n_layers'], ea['dim'], ea['hidden_dim'], ea['head_dim'])
        self.projector = Projector(ea['hidden_dim'], self.h)
        self.layers = nn.ModuleList(
            [DecLayer(self.h, ada_dim) for _ in range(self.n_layers)])
        self.embed = nn.Embedding(self.p['vocab_size'], self.h)
        self.norm = RMSNorm(self.h)
    gc.collect()
|
| 561 |
+
|
| 562 |
+
def _try_compile(self):
    """Optionally compile decoder layers with torch.compile.

    Intended to fuse pointwise ops (RMSNorm, SiLU, ada_rms_norm, residuals)
    between the matmuls and reduce kernel-launch overhead at batch-1 decode.
    Any exception raised while wrapping is reported and the remaining layers
    are left in eager mode.
    """
    try:
        compiled = 0
        for idx in range(len(self.layers)):
            self.layers[idx] = torch.compile(self.layers[idx], mode="default")
            compiled += 1
        print(f" torch.compile: {compiled} decoder layers compiled")
    except Exception as e:
        print(f" torch.compile not available ({e}), using eager mode")
|
| 577 |
+
|
| 578 |
+
def _pack_lm_head(self):
    """Quantize the embedding weight to INT4 and pack it for a Marlin LM head.

    The LM head (output projection) otherwise reads the full fp16 embedding
    matrix on every token — the original author measured 21 ms on Jetson
    versus ~4 ms with Marlin INT4.

    Must be called before decoder layers are loaded (maximum free GPU memory).

    Sets `_lm_head_marlin`; on success also `_lm_B`, `_lm_s`, `_lm_ws` and
    `_lm_out_f`. When the dimensions do not meet the Marlin constraints
    checked below, the fp16 path is kept.
    """
    import gc
    embed_w = self.embed.weight.data  # [vocab_size, dim] fp16
    out_f, in_f = embed_w.shape

    if in_f % 128 != 0 or out_f % 256 != 0:
        print(f" LM head: dims {out_f}x{in_f} not Marlin-compatible, keeping fp16")
        self._lm_head_marlin = False
        return

    t0 = time.time()
    # Compute per-group scales (group_size=GROUP_SIZE along the input dim);
    # abs-max / 7 maps each group onto the symmetric signed 4-bit range.
    n_groups = in_f // GROUP_SIZE
    w_grouped = embed_w.float().reshape(out_f, n_groups, GROUP_SIZE)
    scales = (w_grouped.abs().amax(dim=-1) / 7.0).clamp(min=1e-6).half()  # [out_f, n_groups]
    del w_grouped

    # Marlin's packer takes an nn.Linear. Build the shell on the meta device so
    # no second vocab x dim matrix is allocated and initialised, then point
    # its weight at the embedding storage (shared, no copy).
    with torch.device('meta'):
        linear_tmp = nn.Linear(in_f, out_f, bias=False, dtype=torch.half)
    self._set(linear_tmp, 'weight', embed_w)
    ml = _marlin.Layer(in_f, out_f, groupsize=GROUP_SIZE)
    ml.pack(linear_tmp, scales)
    del linear_tmp, scales
    gc.collect()
    torch.cuda.empty_cache()

    self._lm_B = ml.B.to(embed_w.device)
    self._lm_s = ml.s.to(embed_w.device)
    self._lm_ws = torch.zeros(out_f // 128 * 16, dtype=torch.int, device=embed_w.device)
    self._lm_out_f = out_f
    self._lm_head_marlin = True

    mb = (self._lm_B.nelement() * self._lm_B.element_size() +
          self._lm_s.nelement() * self._lm_s.element_size()) / 1024**2
    print(f" Marlin LM head packed: {out_f}x{in_f} -> {mb:.0f} MB ({time.time()-t0:.1f}s)")
|
| 621 |
+
|
| 622 |
+
def _lm_head(self, h):
|
| 623 |
+
"""Compute LM head logits. Uses Marlin INT4 if available, else fp16."""
|
| 624 |
+
if hasattr(self, '_lm_head_marlin') and self._lm_head_marlin:
|
| 625 |
+
h_flat = h.view(-1, h.shape[-1])
|
| 626 |
+
out = torch.empty(h_flat.shape[0], self._lm_out_f, dtype=h.dtype, device=h.device)
|
| 627 |
+
_marlin.mul(h_flat, self._lm_B, out, self._lm_s, self._lm_ws)
|
| 628 |
+
return out.view(*h.shape[:-1], self._lm_out_f)
|
| 629 |
+
return F.linear(h, self.embed.weight)
|
| 630 |
+
|
| 631 |
+
def _dql(self, f, prefix, dev, unpermute=None):
    """Load one quantized projection and wrap it in a linear module.

    Args:
        f: open safetensors handle providing `get_tensor`.
        prefix: tensor-name prefix; `.qweight`, `.scales` and `.qzeros` are read.
        dev: device the tensors are moved to.
        unpermute: forwarded unchanged to the wrapping module.

    Returns:
        `MarlinLinear` when Marlin is available and the shape meets the
        divisibility checks below, otherwise `DequantLinear`.
    """
    qw, sc, qz = (f.get_tensor(f'{prefix}.{part}').to(dev)
                  for part in ('qweight', 'scales', 'qzeros'))
    in_f = qw.shape[0] * PACK_FACTOR
    out_f = qw.shape[1]
    marlin_ok = HAS_MARLIN and in_f % 128 == 0 and out_f % 256 == 0
    layer_cls = MarlinLinear if marlin_ok else DequantLinear
    return layer_cls(qw, sc, qz, unpermute=unpermute)
|
| 640 |
+
|
| 641 |
+
def _set(self, module, name, tensor):
|
| 642 |
+
"""Replace a meta parameter with a real CUDA tensor."""
|
| 643 |
+
module._parameters[name] = nn.Parameter(tensor, requires_grad=False)
|
| 644 |
+
|
| 645 |
+
@staticmethod
|
| 646 |
+
def _compute_time_embedding(t_value: float, dim: int, theta: float = 10000.0) -> torch.Tensor:
|
| 647 |
+
"""Sinusoidal embedding of scalar t_value into dim-dimensional vector."""
|
| 648 |
+
half_dim = dim // 2
|
| 649 |
+
inv_freq = torch.exp(-math.log(theta) * torch.arange(half_dim).float() / half_dim)
|
| 650 |
+
emb = t_value * inv_freq
|
| 651 |
+
return torch.cat([emb.cos(), emb.sin()]) # [dim]
|
| 652 |
+
|
| 653 |
+
@staticmethod
def _evict_cache():
    """Best-effort attempt to make the OS give back page cache.

    Allocates and immediately frees one large host buffer, trying 4, 3, 2
    and then 1 GiB and stopping at the first size that succeeds. The original
    author uses this on Jetson unified memory. If every size fails the call
    returns without doing anything.
    """
    import ctypes, gc
    gib = 1024 * 1024 * 1024
    for size_gib in (4, 3, 2, 1):
        try:
            buf = ctypes.create_string_buffer(size_gib * gib)
        except (MemoryError, OSError):
            continue  # too large for the current headroom, try a smaller buffer
        del buf
        gc.collect()
        return
|
| 665 |
+
|
| 666 |
+
def _load_section(self, path, load_fn):
    """Open the safetensors file, run `load_fn(handle)`, close it and clean up.

    Opening the file once per section, then collecting garbage, emptying the
    CUDA allocator cache and calling `_evict_cache`, keeps the footprint
    between sections small.
    """
    import gc
    target = str(self.device)
    with safe_open(path, framework='pt', device=target) as f:
        load_fn(f)
    gc.collect()
    torch.cuda.empty_cache()
    self._evict_cache()
|
| 675 |
+
|
| 676 |
+
def _load(self):
    """Load all weights from `consolidated.safetensors` section by section.

    Each section goes through `_load_section`, which reopens the file and
    cleans up afterwards. Order: embeddings + output norm, encoder
    convolutions, encoder layers (batches of 8), projector, optional Marlin
    LM head packing, decoder layers (batches of 13).
    """
    import gc
    path = os.path.join(self.model_path, 'consolidated.safetensors')
    print(f"Loading {path}...")
    t0 = time.time()
    D, T = self.device, self.dtype
    ep = 'mm_streams_embeddings.embedding_module'
    enc_prefix = f'{ep}.whisper_encoder'
    ea = self.p['multimodal']['whisper_model_args']['encoder_args']

    def assign(f, module, name, key):
        # Fetch one tensor, cast it to the model dtype and install it.
        self._set(module, name, f.get_tensor(key).to(T))

    # Section 1: Embeddings + output norm
    def load_embed(f):
        assign(f, self.embed, 'weight', f'{ep}.tok_embeddings.weight')
        assign(f, self.norm, 'weight', 'norm.weight')
        print(" Embeddings loaded")
    self._load_section(path, load_embed)

    # Section 2: Encoder convolutions
    def load_enc_conv(f):
        enc = self.encoder
        assign(f, enc.conv1, 'weight', f'{enc_prefix}.conv_layers.0.conv.weight')
        assign(f, enc.conv1, 'bias', f'{enc_prefix}.conv_layers.0.conv.bias')
        assign(f, enc.conv2, 'weight', f'{enc_prefix}.conv_layers.1.conv.weight')
        assign(f, enc.conv2, 'bias', f'{enc_prefix}.conv_layers.1.conv.bias')
        assign(f, enc.norm, 'weight', f'{enc_prefix}.transformer.norm.weight')
        print(" Encoder conv loaded")
    self._load_section(path, load_enc_conv)

    # Section 3: Encoder layers (in batches of 8 to limit mmap cache growth)
    n_enc = ea['n_layers']
    batch = 8
    for b_start in range(0, n_enc, batch):
        b_end = min(b_start + batch, n_enc)

        # Bind the bounds as defaults so each closure keeps its own range.
        def load_enc_batch(f, start=b_start, end=b_end):
            for i in range(start, end):
                lp = f'{enc_prefix}.transformer.layers.{i}'
                el = self.encoder.layers[i]
                targets = (
                    (el.attn.wq, 'weight', 'attention.wq.weight'),
                    (el.attn.wq, 'bias', 'attention.wq.bias'),
                    (el.attn.wk, 'weight', 'attention.wk.weight'),
                    (el.attn.wv, 'weight', 'attention.wv.weight'),
                    (el.attn.wv, 'bias', 'attention.wv.bias'),
                    (el.attn.wo, 'weight', 'attention.wo.weight'),
                    (el.attn.wo, 'bias', 'attention.wo.bias'),
                    (el.an, 'weight', 'attention_norm.weight'),
                    (el.fn, 'weight', 'ffn_norm.weight'),
                    (el.w1, 'weight', 'feed_forward.w1.weight'),
                    (el.w2, 'weight', 'feed_forward.w2.weight'),
                    (el.w2, 'bias', 'feed_forward.w2.bias'),
                    (el.w3, 'weight', 'feed_forward.w3.weight'),
                )
                for module, name, suffix in targets:
                    assign(f, module, name, f'{lp}.{suffix}')
            print(f" Encoder layers {start}-{end-1} loaded")
        self._load_section(path, load_enc_batch)

    print(f" Encoder loaded ({n_enc} layers)")

    # Section 4: Projector
    def load_proj(f):
        pp = f'{ep}.audio_language_projection'
        assign(f, self.projector.l1, 'weight', f'{pp}.0.weight')
        assign(f, self.projector.l2, 'weight', f'{pp}.2.weight')
        print(" Projector loaded")
    self._load_section(path, load_proj)

    # Section 4b: Pack Marlin INT4 LM head (before decoder layers, max free mem)
    if HAS_MARLIN:
        self._pack_lm_head()

    # Section 5: LM decoder layers (in batches of 13)
    dec_batch = 13
    for b_start in range(0, self.n_layers, dec_batch):
        b_end = min(b_start + dec_batch, self.n_layers)

        def load_dec_batch(f, start=b_start, end=b_end):
            for i in range(start, end):
                lp = f'layers.{i}'
                dl = self.layers[i]
                assign(f, dl.an, 'weight', f'{lp}.attention_norm.weight')
                assign(f, dl.fn, 'weight', f'{lp}.ffn_norm.weight')
                assign(f, dl.ada0, 'weight', f'{lp}.ada_rms_norm_t_cond.0.weight')
                assign(f, dl.ada2, 'weight', f'{lp}.ada_rms_norm_t_cond.2.weight')
                quantized = (
                    (dl.attn, 'q_proj', 'self_attn.q_proj'),
                    (dl.attn, 'k_proj', 'self_attn.k_proj'),
                    (dl.attn, 'v_proj', 'self_attn.v_proj'),
                    (dl.attn, 'o_proj', 'self_attn.o_proj'),
                    (dl, 'gate_proj', 'mlp.gate_proj'),
                    (dl, 'up_proj', 'mlp.up_proj'),
                    (dl, 'down_proj', 'mlp.down_proj'),
                )
                for owner, attr, suffix in quantized:
                    setattr(owner, attr, self._dql(f, f'{lp}.{suffix}', D))
            print(f" LM layers {start}-{end-1} loaded")
        self._load_section(path, load_dec_batch)

    backend = "Marlin fused INT4" if HAS_MARLIN else "DequantLinear"
    print(f" LM decoder loaded ({self.n_layers} layers, {backend})")
    gc.collect()
    torch.cuda.empty_cache()
    mem = torch.cuda.memory_allocated() / 1024**3
    print(f" Done in {time.time()-t0:.1f}s, GPU: {mem:.2f} GB")
|
| 770 |
+
|
| 771 |
+
def _load_tokenizer(self):
    """Load the tekken tokenizer used for decoding.

    Tries `Tekkenizer` first and the older `Tekken` name second. If neither
    can be imported, `self.tokenizer` is set to None and a warning is
    printed. Errors other than ImportError (for example a missing
    `tekken.json`) propagate to the caller.
    """
    tekken_path = os.path.join(self.model_path, 'tekken.json')
    try:
        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
        self.tokenizer = Tekkenizer.from_file(tekken_path)
        print(" Tekken tokenizer loaded")
    except ImportError:
        try:
            from mistral_common.tokens.tokenizers.tekken import Tekken
            self.tokenizer = Tekken.from_file(tekken_path)
            print(" Tekken tokenizer loaded (legacy)")
        except ImportError:
            self.tokenizer = None
            print(" WARNING: mistral_common not available, using fallback decoder")
|
| 787 |
+
|
| 788 |
+
def _pre_dequant(self):
    """Offload encoder to CPU and pre-dequantize decoder weights into GPU cache.

    After encoding is done the encoder is no longer needed on the GPU.
    Offloading it frees memory for caching pre-dequantized decoder weights,
    which eliminates the per-token dequantization overhead. Layers are
    cached in order until the estimated extra memory would exceed the budget
    (free memory minus a 500 MB reserve). Only `DequantLinear` projections
    are affected. Does nothing when the decoder is already marked as cached.
    """
    import gc

    if getattr(self, '_decoder_cached', False):
        return  # already cached

    t0 = time.time()

    # Move encoder + projector to CPU to free GPU memory
    self.encoder.cpu()
    self.projector.cpu()
    gc.collect()
    torch.cuda.empty_cache()
    self._evict_cache()

    free, _ = torch.cuda.mem_get_info(0)
    print(f" After encoder offload: {free/1024**3:.2f} GB free")

    # Budget: leave 500 MB for KV cache + intermediates
    budget = free - 500 * 1024 * 1024
    used_bytes = 0
    cached_count = 0

    def nbytes(t):
        return t.nelement() * t.element_size()

    for dl in self.layers:
        projs = [dl.attn.q_proj, dl.attn.k_proj, dl.attn.v_proj, dl.attn.o_proj,
                 dl.gate_proj, dl.up_proj, dl.down_proj]
        dequant = [p for p in projs if isinstance(p, DequantLinear)]

        # Estimate net memory cost (fp16 weight minus freed INT4 buffers)
        layer_fp16 = sum(p.in_features * p.out_features * 2
                         for p in dequant if p._cached_w is None)
        layer_int4 = sum(nbytes(p.qweight) + nbytes(p.scales) + nbytes(p.qzeros)
                         for p in dequant if p.qweight is not None)
        net = layer_fp16 - layer_int4  # net increase in memory

        if used_bytes + net > budget:
            break

        for p in dequant:
            if p._cached_w is None and p.qweight is not None:
                p.cache_weight(free_int4=True)

        used_bytes += net
        cached_count += 1
        # Periodic cleanup to keep peak memory low
        if cached_count % 5 == 0:
            gc.collect()
            torch.cuda.empty_cache()

    gc.collect()
    torch.cuda.empty_cache()
    free2, _ = torch.cuda.mem_get_info(0)
    print(f" Pre-dequantized {cached_count}/{self.n_layers} layers in {time.time()-t0:.1f}s, "
          f"{free2/1024**3:.2f} GB free")
    self._decoder_cached = True
|
| 853 |
+
|
| 854 |
+
def _restore_encoder(self):
    """Move the encoder back to the GPU for the next transcription.

    Frees cached decoder weights first to make room, then reloads INT4
    weights from disk for every layer whose packed buffers had been freed.
    Does nothing unless `_pre_dequant` marked the decoder as cached.
    """
    import gc

    if not getattr(self, '_decoder_cached', False):
        return

    def projections(dl):
        return (dl.attn.q_proj, dl.attn.k_proj, dl.attn.v_proj, dl.attn.o_proj,
                dl.gate_proj, dl.up_proj, dl.down_proj)

    # Free cached decoder weights, remembering which layers lost their INT4
    # buffers and therefore have to be re-read from disk.
    needs_reload = set()
    for i, dl in enumerate(self.layers):
        for p in projections(dl):
            if not isinstance(p, DequantLinear):
                continue
            if p._cached_w is not None and p.qweight is None:
                needs_reload.add(i)
            p.uncache_weight()

    gc.collect()
    torch.cuda.empty_cache()
    self._evict_cache()

    # Move encoder + projector back to GPU
    self.encoder.to(self.device)
    self.projector.to(self.device)

    # Reload INT4 weights for layers that were freed
    if needs_reload:
        needs_reload = sorted(needs_reload)
        path = os.path.join(self.model_path, 'consolidated.safetensors')
        D = str(self.device)
        quantized = (
            ('attn', 'q_proj', 'self_attn.q_proj'),
            ('attn', 'k_proj', 'self_attn.k_proj'),
            ('attn', 'v_proj', 'self_attn.v_proj'),
            ('attn', 'o_proj', 'self_attn.o_proj'),
            (None, 'gate_proj', 'mlp.gate_proj'),
            (None, 'up_proj', 'mlp.up_proj'),
            (None, 'down_proj', 'mlp.down_proj'),
        )
        with safe_open(path, framework='pt', device=D) as f:
            for i in needs_reload:
                lp = f'layers.{i}'
                dl = self.layers[i]
                for sub, attr, suffix in quantized:
                    owner = dl if sub is None else getattr(dl, sub)
                    setattr(owner, attr, self._dql(f, f'{lp}.{suffix}', D))
        gc.collect()
        torch.cuda.empty_cache()
        print(f" Reloaded {len(needs_reload)} decoder layers from disk")

    gc.collect()
    torch.cuda.empty_cache()
    self._decoder_cached = False
|
| 908 |
+
print(f" Encoder restored in {time.time()-t0:.1f}s")
|
| 909 |
+
|
| 910 |
+
def decode_tokens(self, ids):
    """Turn a sequence of token IDs into text.

    ``self.tokenizer.decode`` is used when a tokenizer is loaded. When there
    is no tokenizer, or decoding raises an ordinary exception, a lossy
    fallback is used instead: IDs 0-255 are emitted as raw bytes, IDs
    256-131071 as a literal ``[id]`` placeholder, and any other ID is dropped.

    Args:
        ids: iterable of integer token IDs.

    Returns:
        The decoded string. Invalid UTF-8 produced by the fallback is
        rendered with the Unicode replacement character.
    """
    if self.tokenizer is not None:
        try:
            return self.tokenizer.decode(ids)
        except Exception:
            # Deliberate best-effort: fall through to the byte-level fallback.
            # Only Exception is caught so Ctrl-C / SystemExit still propagate.
            pass

    # Fallback: treat IDs 0-255 as byte tokens.
    # NOTE(review): transcribe() treats IDs < 1000 as special tokens, so this
    # byte mapping may not match the real vocabulary; confirm against tekken.json.
    result = bytearray()
    for tid in ids:
        if 0 <= tid <= 255:
            result.append(tid)
        elif 256 <= tid < 131072:
            # Merge token: would need the full vocabulary to decode.
            result.extend(f'[{tid}]'.encode())
    # errors='replace' never raises, so no exception handling is needed here.
    return result.decode('utf-8', errors='replace')
|
| 929 |
+
|
| 930 |
+
@staticmethod
|
| 931 |
+
def _pad_audio(audio: np.ndarray) -> np.ndarray:
    """Zero-pad a waveform on both sides for streaming alignment.

    The left side receives ``N_LEFT_PAD_TOKENS`` tokens' worth of samples.
    The right side first rounds the original length up to a whole number of
    ``RAW_AUDIO_PER_TOK``-sample tokens and then adds ``N_RIGHT_PAD_BASE``
    further tokens. The result is float32.
    """
    partial = len(audio) % RAW_AUDIO_PER_TOK
    # Samples missing to complete the last token (0 when already aligned).
    fill = (RAW_AUDIO_PER_TOK - partial) % RAW_AUDIO_PER_TOK
    lead = N_LEFT_PAD_TOKENS * RAW_AUDIO_PER_TOK
    trail = fill + N_RIGHT_PAD_BASE * RAW_AUDIO_PER_TOK
    samples = audio.astype(np.float32)
    return np.pad(samples, (lead, trail))
|
| 942 |
+
|
| 943 |
+
@torch.no_grad()
|
| 944 |
+
def transcribe(self, audio: np.ndarray, max_tokens: int = 512) -> str:
    """Transcribe audio (float32, 16kHz mono) to text.

    Stages, in order: pad the waveform, compute the log-mel spectrogram, run
    the encoder, stack every ``self.ds_factor`` encoder frames into one
    position, project them with ``self.projector``, prefill the decoder with
    a BOS + streaming-pad prompt, then decode greedily one position at a
    time. Each decoder input is the projected audio frame at that position
    plus the embedding of the token fed at that position.

    Args:
        audio: waveform to transcribe.
        max_tokens: upper bound on the number of generated tokens.

    Returns:
        Text decoded from the generated tokens whose ID is >= 1000. Lower IDs
        are counted as special tokens and only reported in the log output.
    """
    started = time.time()

    # Evict page cache before inference (Jetson unified memory).
    self._evict_cache()
    free_bytes, _ = torch.cuda.mem_get_info(0)
    print(f" CUDA free before inference: {free_bytes/1024**3:.2f} GB")

    # Without Marlin the encoder may have been offloaded by the previous call.
    if not HAS_MARLIN:
        self._restore_encoder()

    # 0. Pad audio for streaming alignment.
    audio = self._pad_audio(audio)
    print(f" padded audio: {len(audio)} samples ({len(audio)/SAMPLE_RATE:.1f}s)")

    # 1. Mel spectrogram, right-padded so the encoder output length is a
    #    multiple of ds_factor (the factor 2 undoes the stride-2 conv).
    mel = self._mel(audio)
    conv_len = (mel.shape[2] - 1) // 2 + 1
    leftover = conv_len % self.ds_factor
    if leftover:
        mel = F.pad(mel, (0, (self.ds_factor - leftover) * 2))
    print(f" mel: {mel.shape}")

    # 2. Encode.
    enc_started = time.time()
    enc = self.encoder(mel)
    print(f" enc: {enc.shape} ({time.time()-enc_started:.2f}s)")
    del mel  # release as early as possible

    # 3. Downsample: concatenate groups of ds_factor frames along the feature axis.
    n_pos = enc.shape[1] // self.ds_factor
    enc_ds = enc[:, :n_pos * self.ds_factor, :].reshape(
        1, n_pos, self.ds_factor * self.enc_h)
    del enc

    # 4. Project (adapter).
    adapter = self.projector(enc_ds)
    del enc_ds
    print(f" adapter: {adapter.shape}")

    # 5. Offload encoder / pre-dequantize decoder weights (only without Marlin).
    if not HAS_MARLIN:
        self._pre_dequant()

    # 6. Prompt: [BOS] + [SPAD] * (n_left_pad + delay_tokens), clamped to the
    #    number of adapter positions.
    n_pad_tokens = self.n_left_pad + self.delay_tokens
    prompt_len = min(1 + n_pad_tokens, n_pos)
    prompt_ids = ([TOKEN_BOS] + [TOKEN_STREAMING_PAD] * n_pad_tokens)[:prompt_len]

    print(f" prompt: {prompt_len} tokens, adapter: {n_pos} positions, gen budget: {n_pos - prompt_len}")

    # 7. KV cache sized for the full sequence.
    kv_cache = KVCache(self.n_layers, n_pos, self.n_kv_heads, self.head_dim,
                       self.device, self.dtype)

    # 8. Prefill: adapter frames + token embeddings for the prompt positions.
    prompt_tok = torch.tensor([prompt_ids], device=self.device)
    h = adapter[:, :prompt_len, :] + self.embed(prompt_tok)
    cos, sin = self.cos[:prompt_len], self.sin[:prompt_len]
    for i, layer in enumerate(self.layers):
        h = layer(h, cos, sin, cache=kv_cache, layer_idx=i,
                  is_causal=True, t_cond=self.t_cond)
    kv_cache.advance(prompt_len)

    h = self.norm(h)
    logits = self._lm_head(h[:, -1:, :])
    next_tok = logits.argmax(-1).item()

    # Diagnostics: ten best candidates and logit / adapter statistics.
    topk = torch.topk(logits[0, 0], 10)
    print(f" prefill done ({time.time()-enc_started:.2f}s), first token: {next_tok}")
    print(f" logits: min={logits.min():.2f} max={logits.max():.2f} mean={logits.mean():.4f}")
    print(f" top-10: {list(zip(topk.indices.tolist(), [f'{v:.2f}' for v in topk.values.tolist()]))}")
    print(f" adapter stats: min={adapter.min():.4f} max={adapter.max():.4f} mean={adapter.mean():.6f}")

    # 9. Generate from prompt_len up to n_pos.
    generated = []
    pos = prompt_len
    gen_started = time.time()

    # One reusable token tensor instead of a fresh allocation per step.
    tok_buf = torch.empty(1, 1, dtype=torch.long, device=self.device)

    while pos < n_pos and next_tok != TOKEN_EOS and len(generated) < max_tokens:
        generated.append(next_tok)

        # Input = adapter[pos] + tok_embed(next_tok)
        tok_buf.fill_(next_tok)
        h = adapter[:, pos:pos+1, :] + self.embed(tok_buf)

        cos_s = self.cos[pos:pos+1]
        sin_s = self.sin[pos:pos+1]
        for i, layer in enumerate(self.layers):
            h = layer(h, cos_s, sin_s, cache=kv_cache, layer_idx=i,
                      is_causal=False, t_cond=self.t_cond)
        kv_cache.advance(1)

        h = self.norm(h)
        logits = self._lm_head(h)
        next_tok = logits[:, -1, :].argmax(-1).item()
        pos += 1

        # Progress every 25 tokens.
        if len(generated) % 25 == 0:
            elapsed = time.time() - gen_started
            tps = len(generated) / max(elapsed, 0.001)
            n_text = sum(1 for t in generated if t >= 1000)
            print(f" step {pos}/{n_pos}: {len(generated)} tok ({n_text} text), {tps:.1f} tok/s")

    if next_tok == TOKEN_EOS:
        generated.append(TOKEN_EOS)

    dt = time.time() - gen_started
    n_gen = len(generated)

    # Split: text tokens are >= 1000, special tokens are < 1000.
    text_toks, special_toks = [], []
    for tok in generated:
        (text_toks if tok >= 1000 else special_toks).append(tok)
    text = self.decode_tokens(text_toks)

    print(f" gen: {n_gen} tokens ({len(special_toks)} special, {len(text_toks)} text) in {dt:.2f}s "
          f"({n_gen/max(dt,0.001):.1f} tok/s)")
    if special_toks:
        print(f" special tokens: {sorted(set(special_toks))[:10]}")
    if text_toks:
        print(f" first text IDs: {text_toks[:20]}")
    print(f" total: {time.time()-started:.2f}s")

    return text
|
| 1090 |
+
|
| 1091 |
+
def _mel(self, audio: np.ndarray) -> torch.Tensor:
    """Compute the log-mel spectrogram fed to the encoder.

    Steps: Hann-windowed STFT power spectrum (last frame dropped), Slaney mel
    filterbank, ``log10``, a floor at ``GLOBAL_LOG_MEL_MAX - 8.0`` and the
    affine rescale ``(x + 4) / 4``. The floor uses a fixed global maximum,
    not the maximum of the current sample.

    Returns:
        The spectrogram with a leading axis of size 1, on ``self.device`` in
        ``self.dtype``.
    """
    waveform = torch.from_numpy(audio).float()
    spectrum = torch.stft(waveform, WINDOW_SIZE, HOP_LENGTH, WINDOW_SIZE,
                          torch.hann_window(WINDOW_SIZE), return_complex=True)
    power = spectrum[..., :-1].abs().pow(2)

    # Project the power spectrum onto the mel bands.
    filters = self._melfb(WINDOW_SIZE // 2 + 1, N_MELS, SAMPLE_RATE, 0.0, 8000.0)
    mel_power = filters @ power

    # Log compression with a fixed (not per-sample) dynamic-range floor.
    log_mel = torch.log10(mel_power.clamp(min=1e-10))
    log_mel = torch.clamp(log_mel, min=GLOBAL_LOG_MEL_MAX - 8.0)
    log_mel = (log_mel + 4.0) / 4.0

    return log_mel.unsqueeze(0).to(device=self.device, dtype=self.dtype)
|
| 1113 |
+
|
| 1114 |
+
@staticmethod
|
| 1115 |
+
def _melfb(n_fft_bins, n_mels, sr, fmin=0.0, fmax=8000.0):
|
| 1116 |
+
"""Slaney mel filterbank matching transformers.audio_utils.mel_filter_bank."""
|
| 1117 |
+
# Slaney mel scale (piecewise linear/log, NOT HTK)
|
| 1118 |
+
f_sp = 200.0 / 3.0
|
| 1119 |
+
min_log_hz = 1000.0
|
| 1120 |
+
min_log_mel = min_log_hz / f_sp # 15.0
|
| 1121 |
+
logstep = np.log(6.4) / 27.0
|
| 1122 |
+
|
| 1123 |
+
def hz_to_mel(f):
|
| 1124 |
+
f = np.asarray(f, dtype=np.float64)
|
| 1125 |
+
mel = np.where(f < min_log_hz, f / f_sp,
|
| 1126 |
+
min_log_mel + np.log(f / min_log_hz) / logstep)
|
| 1127 |
+
return mel
|
| 1128 |
+
|
| 1129 |
+
def mel_to_hz(m):
|
| 1130 |
+
m = np.asarray(m, dtype=np.float64)
|
| 1131 |
+
hz = np.where(m < min_log_mel, m * f_sp,
|
| 1132 |
+
min_log_hz * np.exp(logstep * (m - min_log_mel)))
|
| 1133 |
+
return hz
|
| 1134 |
+
|
| 1135 |
+
mel_min = hz_to_mel(fmin)
|
| 1136 |
+
mel_max = hz_to_mel(fmax)
|
| 1137 |
+
mels = np.linspace(mel_min, mel_max, n_mels + 2)
|
| 1138 |
+
freqs = mel_to_hz(mels)
|
| 1139 |
+
|
| 1140 |
+
fft_freqs = np.linspace(0, sr / 2, n_fft_bins)
|
| 1141 |
+
fb = np.zeros((n_mels, n_fft_bins))
|
| 1142 |
+
for i in range(n_mels):
|
| 1143 |
+
low, center, high = freqs[i], freqs[i + 1], freqs[i + 2]
|
| 1144 |
+
for j in range(n_fft_bins):
|
| 1145 |
+
if low <= fft_freqs[j] <= center and center > low:
|
| 1146 |
+
fb[i, j] = (fft_freqs[j] - low) / (center - low)
|
| 1147 |
+
elif center < fft_freqs[j] <= high and high > center:
|
| 1148 |
+
fb[i, j] = (high - fft_freqs[j]) / (high - center)
|
| 1149 |
+
|
| 1150 |
+
# Slaney normalization: area = 1 per filter
|
| 1151 |
+
enorm = 2.0 / (freqs[2:n_mels+2] - freqs[:n_mels])
|
| 1152 |
+
fb *= enorm[:, np.newaxis]
|
| 1153 |
+
|
| 1154 |
+
return torch.tensor(fb, dtype=torch.float32)
|
| 1155 |
+
|
| 1156 |
+
|
| 1157 |
+
# ─── WebSocket Server ────────────────────────────────────────────────────────
|
| 1158 |
+
|
| 1159 |
+
async def handle_ws(ws, model):
    """Serve one WebSocket session of the realtime transcription protocol.

    Sends ``session.created`` first, then handles JSON events:

    * ``session.update`` is accepted and ignored.
    * ``input_audio_buffer.append`` base64-decodes ``audio`` and appends the
      bytes (16-bit PCM) to the session buffer.
    * ``input_audio_buffer.commit`` transcribes the buffered audio and replies
      with ``transcription.delta`` (only for non-empty text) followed by
      ``transcription.done``. A commit with a truthy ``final`` ends the
      session; a commit with nothing buffered is ignored.

    Any exception raised while handling an event is printed and reported to
    the client as an ``error`` event; the session then continues.

    Args:
        ws: connection object supporting ``async for`` and ``await ws.send``.
        model: object whose ``transcribe(pcm)`` returns the transcript text.
    """
    sid = str(uuid.uuid4())[:8]
    await ws.send(json.dumps({"type": "session.created", "session": {"id": sid}}))

    buf = bytearray()
    async for msg in ws:
        try:
            ev = json.loads(msg)
            t = ev.get("type", "")

            if t == "session.update":
                pass  # model name ignored, we only have one

            elif t == "input_audio_buffer.append":
                buf.extend(base64.b64decode(ev.get("audio", "")))

            elif t == "input_audio_buffer.commit":
                if ev.get("final"):
                    break
                if not buf:
                    continue

                # int16 samples are 2 bytes each. np.frombuffer rejects a
                # buffer with a dangling byte, and because the buffer used to
                # be cleared only after a successful conversion, one odd-sized
                # upload broke every later commit. Drop the incomplete sample
                # and always reset the buffer.
                whole = len(buf) - (len(buf) % 2)
                raw = bytes(buf[:whole])
                buf = bytearray()
                if not raw:
                    continue

                pcm = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
                dur = len(pcm) / SAMPLE_RATE
                print(f"[{sid}] {dur:.1f}s audio")

                # NOTE(review): this call is synchronous, so it blocks the
                # event loop (and every other session) until it returns.
                text = model.transcribe(pcm)

                if text:
                    await ws.send(json.dumps({"type": "transcription.delta", "delta": text}))
                await ws.send(json.dumps({
                    "type": "transcription.done", "text": text,
                    "usage": {"audio_duration_s": round(dur, 1)}
                }))
                print(f"[{sid}] -> {text[:100]}")

        except Exception as e:
            import traceback
            traceback.print_exc()
            await ws.send(json.dumps({"type": "error", "error": str(e)}))
|
| 1200 |
+
|
| 1201 |
+
|
| 1202 |
+
async def serve(model, host='0.0.0.0', port=8000):
    """Run the WebSocket server until the task is cancelled.

    Every connection is handed to ``handle_ws`` together with ``model``. If
    entering ``websockets.serve`` as an async context manager raises
    ``TypeError``, the server is started by awaiting it directly instead.
    """
    import websockets
    print(f"\nWebSocket server at ws://{host}:{port}/v1/realtime")

    async def on_connect(ws, path=None):
        # ``path`` is optional so both one- and two-argument handler calls work.
        await handle_ws(ws, model)

    async def announce_and_wait():
        print("Ready.")
        await asyncio.Future()  # never resolves: keeps the server alive

    try:
        async with websockets.serve(on_connect, host, port):
            await announce_and_wait()
    except TypeError:
        await websockets.serve(on_connect, host, port)
        await announce_and_wait()
|
| 1217 |
+
|
| 1218 |
+
|
| 1219 |
+
# ─── Main ────────────────────────────────────────────────────────────────────
|
| 1220 |
+
|
| 1221 |
+
def main():
    """Command-line entry point.

    Parses the arguments, reports GPU memory (trying to reclaim page cache
    when less than 5 GiB is free), loads the model and tokenizer, then either
    transcribes the file given with ``--test`` or starts the WebSocket server
    on ``--port``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--model-path', default='/workspace/voxtral-quant/models/voxtral-rtn-4bit-packed-vllm')
    ap.add_argument('--port', type=int, default=8000)
    ap.add_argument('--test', help='WAV file to transcribe (skip server)')
    ap.add_argument('--device', default='cuda')
    ap.add_argument('--no-compile', action='store_true', help='Disable torch.compile')
    args = ap.parse_args()

    print("=" * 60)
    # The separator in this banner was a mis-encoded character; use a real dash.
    print("Voxtral Mini 4B Realtime — Jetson Orin Nano 8GB")
    print("=" * 60)

    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        print(f"GPU: {props.name}, {props.total_memory/1024**3:.1f} GB")
        free, total = torch.cuda.mem_get_info(0)
        print(f"CUDA free: {free/1024**3:.2f} GB / {total/1024**3:.2f} GB")
        # Force kernel to reclaim page cache (critical on Jetson unified memory):
        # allocate and immediately drop the largest host buffer that succeeds.
        if free < 5 * 1024**3:
            import ctypes, gc
            print("Reclaiming page cache...")
            for sz in [4, 3, 2, 1]:
                try:
                    buf = ctypes.create_string_buffer(sz * 1024 * 1024 * 1024)
                    del buf
                    gc.collect()
                    break
                except (MemoryError, OSError):
                    continue
            free2, _ = torch.cuda.mem_get_info(0)
            print(f"CUDA free after reclaim: {free2/1024**3:.2f} GB")

    model = VoxtralModel(args.model_path, args.device, compile=not args.no_compile)
    model._load_tokenizer()

    if args.test:
        import soundfile as sf
        audio, sr = sf.read(args.test, dtype='float32')
        if audio.ndim > 1:
            audio = audio.mean(1)  # downmix to mono
        if sr != SAMPLE_RATE:
            # Prefer soxr; fall back to librosa when it is not installed.
            try:
                import soxr
                audio = soxr.resample(audio, sr, SAMPLE_RATE, quality='HQ')
            except ImportError:
                import librosa
                audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
        print(f"\nTest: {args.test} ({len(audio)/SAMPLE_RATE:.1f}s)")
        text = model.transcribe(audio)
        print(f"\nResult: {text}")
    else:
        asyncio.run(serve(model, port=args.port))
|
| 1274 |
+
|
| 1275 |
+
|
| 1276 |
+
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
tekken.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8434af1d39eba99f0ef46cf1450bf1a63fa941a26933a1ef5dbbf4adf0d00e44
|
| 3 |
+
size 14910348
|
voxtral_client.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Voxtral Realtime Transcription β Gradio client for Jetson WebSocket API.
|
| 3 |
+
|
| 4 |
+
Captures microphone audio and streams it to the Voxtral server running on
|
| 5 |
+
a Jetson Orin Nano for live speech-to-text transcription.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
pip install gradio websockets numpy
|
| 9 |
+
python voxtral_client.py
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import asyncio
|
| 13 |
+
import base64
|
| 14 |
+
import json
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import gradio as gr
|
| 19 |
+
|
| 20 |
+
# Endpoint pre-filled in the UI's server URL field.
DEFAULT_WS_URL = "ws://localhost:8000/v1/realtime"

# Sample rate (Hz) every recording is converted to before upload.
TARGET_SR = 16000

# Samples per WebSocket message: half a second of audio.
CHUNK_SAMPLES = TARGET_SR // 2
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ─── Audio Utilities ─────────────────────────────────────────────────────────
|
| 26 |
+
|
| 27 |
+
def resample_audio(audio, orig_sr, target_sr):
    """Convert ``audio`` from ``orig_sr`` to ``target_sr``.

    Returns the input object unchanged when the rates already match.
    Otherwise uses ``scipy.signal.resample_poly`` with the rate ratio reduced
    by its greatest common divisor, or linear interpolation when SciPy cannot
    be imported. Resampled output is float32.
    """
    if orig_sr == target_sr:
        return audio
    try:
        from math import gcd
        from scipy.signal import resample_poly
        divisor = gcd(int(orig_sr), int(target_sr))
        up, down = target_sr // divisor, orig_sr // divisor
        return resample_poly(audio, up, down).astype(np.float32)
    except ImportError:
        # Linear interpolation fallback (no extra dependencies).
        src_len = len(audio)
        out_len = int(src_len * target_sr / orig_sr)
        query = np.linspace(0, src_len - 1, out_len)
        return np.interp(query, np.arange(src_len), audio).astype(np.float32)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def audio_to_pcm16(audio_tuple):
    """Convert Gradio audio ``(sr, ndarray)`` to 16 kHz mono PCM16.

    Args:
        audio_tuple: ``(sample_rate, samples)`` or ``None``. ``samples`` may be
            one-dimensional, or two-dimensional with channels on axis 1, and
            of integer or floating dtype.

    Returns:
        ``(pcm16, duration_seconds)``; ``(None, 0.0)`` when there is no audio.
    """
    if audio_tuple is None:
        return None, 0.0
    sr, data = audio_tuple
    if data is None or len(data) == 0:
        return None, 0.0

    # Scale integer samples to [-1, 1] BEFORE the channel downmix: mean() on
    # an integer array returns floats, so checking the dtype afterwards would
    # skip normalization for multi-channel integer input and the audio would
    # clip to full scale.
    if np.issubdtype(data.dtype, np.integer):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max

    # Mono
    if data.ndim > 1:
        data = data.mean(axis=1)
    data = data.astype(np.float32)

    if sr != TARGET_SR:
        data = resample_audio(data, sr, TARGET_SR)

    pcm16 = (data * 32768.0).clip(-32768, 32767).astype(np.int16)
    return pcm16, len(pcm16) / TARGET_SR
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ─── WebSocket Client ────────────────────────────────────────────────────────
|
| 71 |
+
|
| 72 |
+
async def _transcribe_async(pcm16, ws_url):
    """Upload PCM16 audio over the WebSocket and wait for the transcript.

    Protocol: expect ``session.created``, send ``session.update``, stream the
    samples in ``CHUNK_SAMPLES``-sized base64 chunks, commit, collect
    ``transcription.delta`` / ``transcription.done`` replies, then send a
    final commit to end the session.

    Returns:
        ``(text, usage, error)``. ``error`` is ``None`` on success; on a
        protocol or server error ``text`` is ``""`` and ``usage`` is ``{}``.
    """
    import websockets

    async with websockets.connect(ws_url, close_timeout=5, open_timeout=10) as ws:
        greeting = json.loads(await asyncio.wait_for(ws.recv(), timeout=10))
        if greeting["type"] != "session.created":
            return "", {}, f"Unexpected first message: {greeting['type']}"

        await ws.send(json.dumps({"type": "session.update"}))

        # Send audio in 0.5s chunks.
        for start in range(0, len(pcm16), CHUNK_SAMPLES):
            piece = pcm16[start:start + CHUNK_SAMPLES]
            encoded = base64.b64encode(piece.tobytes()).decode("ascii")
            await ws.send(json.dumps({
                "type": "input_audio_buffer.append",
                "audio": encoded,
            }))

        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))

        text, usage = "", {}
        while True:
            reply = json.loads(await asyncio.wait_for(ws.recv(), timeout=120))
            kind = reply["type"]
            if kind == "transcription.delta":
                text += reply.get("delta", "")
            elif kind == "transcription.done":
                # The final message carries the full text; keep the deltas otherwise.
                text = reply.get("text", text)
                usage = reply.get("usage", {})
                break
            elif kind == "error":
                return "", {}, f"Server error: {reply.get('error', 'unknown')}"

        # Signal done
        await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))

    return text, usage, None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def transcribe_ws(pcm16, ws_url):
    """Blocking wrapper around ``_transcribe_async`` for Gradio thread callbacks.

    Runs the coroutine on a private event loop that is always closed
    afterwards, and returns the coroutine's ``(text, usage, error)`` result.
    """
    private_loop = asyncio.new_event_loop()
    try:
        outcome = private_loop.run_until_complete(_transcribe_async(pcm16, ws_url))
    finally:
        private_loop.close()
    return outcome
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ─── Gradio Callbacks ────────────────────────────────────────────────────────
|
| 122 |
+
|
| 123 |
+
def format_history(entries):
    """Render history entries as one line each, joined with newlines.

    Each entry is a mapping with ``time``, ``audio_s``, ``elapsed_s`` and
    ``text`` keys. An empty or missing history renders as ``""``.
    """
    if not entries:
        return ""
    return "\n".join(
        f"[{item['time']}] ({item['audio_s']}s audio, {item['elapsed_s']}s total) {item['text']}"
        for item in entries
    )
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def process_recording(audio, ws_url, history_state):
    """Transcribe recorded audio and update history.

    Args:
        audio: Gradio audio value ``(sample_rate, ndarray)`` or ``None``.
        ws_url: WebSocket URL of the transcription server.
        history_state: list of earlier history entries; never mutated.

    Returns:
        ``(current_text, history_state, status, history_text)``. On any
        failure ``current_text`` is ``""``, the history is returned unchanged
        and ``status`` describes the problem.
    """
    def failure(message):
        # Shared shape of every non-success result.
        return "", history_state, message, format_history(history_state)

    if audio is None:
        return failure("No audio recorded")

    pcm16, duration = audio_to_pcm16(audio)
    if pcm16 is None or len(pcm16) < TARGET_SR // 10:  # < 0.1s
        return failure("Audio too short")

    t0 = time.time()
    try:
        text, usage, error = transcribe_ws(pcm16, ws_url)
    except Exception as e:
        err = str(e)
        if "Connect call failed" in err or "refused" in err:
            # The dash in this message was a mis-encoded character.
            err = f"Cannot reach server at {ws_url} — is it running?"
        return failure(f"Error: {err}")

    if error:
        return failure(f"Error: {error}")

    elapsed = time.time() - t0
    text = text.strip()

    entry = {
        "time": time.strftime("%H:%M:%S"),
        "text": text,
        "audio_s": round(duration, 1),
        "elapsed_s": round(elapsed, 1),
    }
    # Build a new list so the previous state object is left untouched.
    history_state = history_state + [entry]

    status = f"Transcribed {duration:.1f}s audio in {elapsed:.1f}s"
    return text, history_state, status, format_history(history_state)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def clear_history():
    """Return reset values: an empty history list and two empty strings."""
    fresh_history = []
    blank = ""
    return fresh_history, blank, blank
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ─── UI ──────────────────────────────────────────────────────────────────────
|
| 173 |
+
|
| 174 |
+
# Page layout and event wiring for the transcription client.
with gr.Blocks(title="Voxtral Realtime Transcription", theme=gr.themes.Soft()) as app:
    # Per-session list of history entries produced by process_recording.
    history_state = gr.State([])

    gr.Markdown("# Voxtral Realtime Transcription")
    gr.Markdown(
        "Record from your microphone or upload a WAV file. Audio is sent to Voxtral on a Jetson Orin Nano for transcription."
    )

    ws_url = gr.Textbox(value=DEFAULT_WS_URL, label="Jetson WebSocket URL")

    with gr.Row():
        with gr.Column(scale=3):
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Microphone (click to record, click again to stop)",
            )
        with gr.Column(scale=1, min_width=120):
            transcribe_btn = gr.Button("Transcribe", variant="primary", size="lg")

    status = gr.Textbox(label="Status", interactive=False)

    current_text = gr.Textbox(label="Current Transcription", interactive=False, lines=3)

    with gr.Row():
        history_text = gr.Textbox(
            label="Transcript History",
            interactive=False,
            lines=8,
            scale=5,
        )
        clear_btn = gr.Button("Clear History", scale=1)

    # The same callback runs when a mic recording stops (auto-transcribe) and
    # when the button is pressed (uploads or re-transcription).
    for bind_event in (audio_input.stop_recording, transcribe_btn.click):
        bind_event(
            fn=process_recording,
            inputs=[audio_input, ws_url, history_state],
            outputs=[current_text, history_state, status, history_text],
        )

    clear_btn.click(fn=clear_history, outputs=[history_state, history_text, current_text])
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# Start the Gradio app on all interfaces when run as a script.
if __name__ == '__main__':
    app.launch(server_name="0.0.0.0", server_port=7860)
|