File size: 1,620 Bytes
9e72440
 
17fb016
 
34ffbd5
9e72440
a926d98
34ffbd5
9e72440
 
34ffbd5
9e72440
34ffbd5
 
 
9e72440
34ffbd5
9e72440
a926d98
 
17fb016
9e72440
 
ca49a8a
9e72440
ca49a8a
9e72440
17fb016
9e72440
 
574dde7
9e72440
 
17fb016
9e72440
ca49a8a
17fb016
9e72440
 
34ffbd5
 
9e72440
 
34ffbd5
9e72440
 
17fb016
34ffbd5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# pip install gradio torch torchaudio snac

import torch
import torchaudio
import gradio as gr
from snac import SNAC

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Sample rate of the checkpoint loaded below; incoming audio is resampled to it
# before encoding, and reconstructions are returned at this rate.
TARGET_SR = 32_000

# Load the pretrained 32 kHz SNAC codec once at import time, put it in
# evaluation mode, and place it on DEVICE.
MODEL = SNAC.from_pretrained("hubertsiuzdak/snac_32khz")
MODEL = MODEL.eval().to(DEVICE)

def _pcm_to_float_mono(data):
    """Convert a ``(T,)`` or ``(T, C)`` sample array to mono float32 in about [-1, 1].

    Signed integer PCM (e.g. int16) is divided by ``2 ** (bits - 1)``.
    Unsigned PCM (e.g. 8-bit WAV) is first shifted so that silence sits at 0.
    Floating-point input is only cast to float32. Multi-channel input is
    averaged across the channel axis.
    """
    kind = data.dtype.kind
    # order="C" guarantees a contiguous copy, which torch.from_numpy requires
    # (it rejects negative strides).
    samples = data.astype("float32", order="C")
    if kind in "iu":
        half_range = float(2 ** (8 * data.dtype.itemsize - 1))
        if kind == "u":
            samples -= half_range
        samples /= half_range

    if samples.ndim == 2:
        samples = samples.mean(axis=1)
    return samples


def encode_then_decode(audio_in):
    """Round-trip an audio clip through the SNAC codec and return the reconstruction.

    Args:
        audio_in: ``None`` (nothing uploaded or recorded) or a
            ``(sample_rate, data)`` tuple as produced by
            ``gr.Audio(type="numpy")``. ``data`` has shape ``(T,)`` (mono) or
            ``(T, C)`` (multi-channel). Integer PCM is rescaled to [-1, 1];
            floating-point samples are used as-is.

    Returns:
        ``None`` when there is no usable input. Otherwise a
        ``(TARGET_SR, samples)`` tuple, where ``samples`` is the decoded
        waveform as a NumPy array trimmed to the length of the resampled input.
    """
    if audio_in is None:
        return None

    sr, data = audio_in
    if data.size == 0:
        # An empty clip has nothing to encode.
        return None

    # Feeding raw int16 magnitudes (up to ~32767) to the codec would be far
    # outside the [-1, 1] float range the model works with, so normalize first.
    samples = _pcm_to_float_mono(data)

    # [T] -> [1, T] so torchaudio treats it as one channel.
    x = torch.from_numpy(samples).unsqueeze(0)
    if sr != TARGET_SR:
        x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=TARGET_SR)
    num_samples = x.shape[-1]

    # [1, T] -> [B=1, C=1, T], the layout SNAC encodes.
    x = x.unsqueeze(0).to(DEVICE)
    with torch.inference_mode():
        codes = MODEL.encode(x)
        y = MODEL.decode(codes)

    # SNAC's encode() right-pads the signal to a whole number of frames, so the
    # decoded waveform may be longer than the input; drop that padded tail.
    y = y[..., :num_samples]
    return (TARGET_SR, y.squeeze().cpu().numpy())

with gr.Blocks(title="SNAC Encode→Decode (Simple)") as demo:
    gr.Markdown(
        "## 🎧 SNAC Encode → Decode (32 kHz)\n"
        "Resample → `encode()` → `decode()` — that’s it."
    )
    with gr.Row():
        # First column: the input widget and the button that triggers the codec.
        with gr.Column():
            audio_in = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Input audio",
            )
            run = gr.Button("Encode + Decode")
        # Second column: where the reconstructed clip is shown.
        with gr.Column():
            audio_out = gr.Audio(type="numpy", label="Reconstructed (32 kHz)")
    # On click, pass the input widget's value through encode_then_decode and
    # display the result in the output widget.
    run.click(fn=encode_then_decode, inputs=audio_in, outputs=audio_out)


# Only start the UI when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()