File size: 638 Bytes
92deba3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write

class AudioCraftGenerator(torch.nn.Module):
    """``torch.nn.Module`` wrapper around a pretrained AudioCraft ``AudioGen`` model.

    Generates audio from text descriptions and writes individual clips to disk.

    NOTE(review): the ``AudioGen`` instance is stored as a plain attribute. If it
    is not itself an ``nn.Module`` (it does not appear to be), then calling
    ``.to(device)``, ``.eval()`` or ``.parameters()`` on this wrapper will not
    reach it. Confirm before relying on module-level device/mode handling.
    """

    def __init__(self, model_name='facebook/audiogen-medium', duration=5):
        """Load the pretrained model and fix the generated clip length.

        Args:
            model_name: Identifier passed to ``AudioGen.get_pretrained``.
            duration: Forwarded as the ``duration`` generation parameter
                (presumably seconds; confirm against AudioCraft docs).
        """
        super().__init__()
        self.model = AudioGen.get_pretrained(model_name)
        self.model.set_generation_params(duration=duration)

    @staticmethod
    def _as_description_list(descriptions):
        """Normalise ``descriptions`` to a non-empty list of prompt strings.

        A bare ``str`` is wrapped in a one-element list. Otherwise a string
        would be iterated character by character, producing one clip per
        character. Any other iterable is materialised as a list.

        Raises:
            ValueError: if no descriptions are supplied.
        """
        if isinstance(descriptions, str):
            return [descriptions]
        prompts = list(descriptions)
        if not prompts:
            raise ValueError("descriptions must contain at least one prompt")
        return prompts

    def forward(self, descriptions):
        """Generate one audio clip per text description.

        Args:
            descriptions: A single prompt string, or an iterable of prompt strings.

        Returns:
            The tensor returned by ``AudioGen.generate``. It is presumably
            shaped (batch, channels, samples), with one entry per description;
            confirm.
        """
        wav = self.model.generate(self._as_description_list(descriptions))
        return wav

    def save_wav(self, wav, idx):
        """Write a single generated clip to disk, named after ``idx``.

        The clip is moved to CPU and written at the model's sample rate using
        loudness normalisation with the compressor enabled.

        NOTE(review): ``audio_write`` appears to expect one clip (at most
        2-D), so pass ``forward(...)[i]`` rather than the whole batch.

        Args:
            wav: The audio tensor for one clip.
            idx: Used (via ``str``) as the output file stem.

        Returns:
            Whatever ``audio_write`` returns: the path of the written file in
            current AudioCraft releases. Previously this was discarded.
        """
        return audio_write(
            f'{idx}',
            wav.cpu(),
            self.model.sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )