---
base_model:
- descript/dac_16khz
library_name: transformers.js
---

ONNX weights for https://huggingface.co/descript/dac_16khz, exported as two separate graphs: an encoder (waveform → discrete audio codes) and a decoder (audio codes → waveform).

## Inference sample code
```py
import numpy as np
import onnxruntime as ort

encoder_session = ort.InferenceSession("encoder_model.onnx")
decoder_session = ort.InferenceSession("decoder_model.onnx")

# Dummy waveform: (batch_size, num_channels, sequence_length) of 16 kHz samples
dummy_encoder_inputs = np.random.randn(1, 1, 12340).astype(np.float32)

# Encode the waveform into discrete audio codes
encoder_inputs = {encoder_session.get_inputs()[0].name: dummy_encoder_inputs}
encoder_outputs = encoder_session.run(None, encoder_inputs)[0]

# Decode the codes back into a waveform
decoder_inputs = {decoder_session.get_inputs()[0].name: encoder_outputs}
decoder_outputs = decoder_session.run(None, decoder_inputs)[0]

# Print the results
print("Encoder output shape:", encoder_outputs.shape)  # (batch_size, n_codebooks, time_steps)
print("Decoder output shape:", decoder_outputs.shape)  # (batch_size, num_channels, sequence_length)
```
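
To encode real audio rather than random values, the checkpoint's feature extractor can prepare the input. A minimal sketch, assuming `AutoProcessor` resolves to the DAC feature extractor for this checkpoint and that the audio is 16 kHz mono (the `waveform` array here is a stand-in for a real recording):
```py
import numpy as np
import onnxruntime as ort
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("descript/dac_16khz")
encoder_session = ort.InferenceSession("encoder_model.onnx")

# Stand-in for a real recording: one second of 16 kHz mono audio
waveform = np.random.randn(16000).astype(np.float32)

# The feature extractor pads/reshapes to (batch_size, num_channels, sequence_length)
inputs = processor(raw_audio=waveform, sampling_rate=processor.sampling_rate, return_tensors="np")
input_values = inputs["input_values"].astype(np.float32)

audio_codes = encoder_session.run(None, {"input_values": input_values})[0]
print("Audio codes shape:", audio_codes.shape)  # (batch_size, n_codebooks, time_steps)
```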

## Conversion code
```py
import torch
import torch.nn as nn
from transformers import DacModel

class DacEncoder(nn.Module):
    """Wraps DacModel so that ONNX export traces only the encode path."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_values):
        # (batch, channels, samples) -> discrete codes (batch, n_codebooks, time_steps)
        return self.model.encode(input_values).audio_codes

class DacDecoder(nn.Module):
    """Wraps DacModel so that ONNX export traces only the decode path."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, audio_codes):
        # Look up the codebook embeddings, then run the decoder to reconstruct audio
        quantized_representation = self.model.quantizer.from_codes(audio_codes)[0]
        return self.model.decoder(quantized_representation)

model = DacModel.from_pretrained("descript/dac_16khz")
encoder = DacEncoder(model)
decoder = DacDecoder(model)

# Export encoder (dynamic axes keep batch size and audio length flexible at runtime)
dummy_encoder_inputs = torch.randn((4, 1, 12340))
torch.onnx.export(
    encoder,
    dummy_encoder_inputs,
    "encoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=['input_values'],
    output_names=['audio_codes'],
    dynamic_axes={
        'input_values': {0: 'batch_size', 1: 'num_channels', 2: 'sequence_length'},
        'audio_codes': {0: 'batch_size', 2: 'time_steps'},
    },
)

# Export decoder (dummy codes are random indices into the codebook)
dummy_decoder_inputs = torch.randint(model.config.codebook_size, (4, model.config.n_codebooks, 100))
torch.onnx.export(
    decoder,
    dummy_decoder_inputs,
    "decoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=['audio_codes'],
    output_names=['audio_values'],
    dynamic_axes={
        'audio_codes': {0: 'batch_size', 2: 'time_steps'},
        'audio_values': {0: 'batch_size', 1: 'num_channels', 2: 'sequence_length'},
    },
)
```
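
After exporting, it is worth a quick sanity check that the ONNX graphs reproduce the PyTorch model. A minimal sketch, run in the same session so it can reuse the `encoder`, `decoder`, and `dummy_encoder_inputs` defined above:
```py
import numpy as np
import onnxruntime as ort

encoder_session = ort.InferenceSession("encoder_model.onnx")
decoder_session = ort.InferenceSession("decoder_model.onnx")

# Reference outputs from the PyTorch wrappers
with torch.no_grad():
    torch_codes = encoder(dummy_encoder_inputs)
    torch_audio = decoder(torch_codes)

# Same inputs through the exported graphs
onnx_codes = encoder_session.run(None, {"input_values": dummy_encoder_inputs.numpy()})[0]
onnx_audio = decoder_session.run(None, {"audio_codes": onnx_codes})[0]

# Codes are discrete indices, so they should agree almost everywhere; the
# reconstructed audio is float, so compare with a tolerance instead.
print("Code mismatch rate:", (torch_codes.numpy() != onnx_codes).mean())
print("Max abs audio diff:", np.abs(torch_audio.numpy() - onnx_audio).max())
```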