import torch
from audiocraft.data.audio import audio_write
from audiocraft.models import AudioGen


class AudioCraftGenerator(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around a pretrained AudioCraft ``AudioGen`` model.

    ``forward`` turns text descriptions into audio and ``save_wav`` writes the
    result to disk with loudness normalisation.

    NOTE(review): ``self.model`` is stored as a plain attribute. If ``AudioGen``
    is not itself an ``nn.Module`` (verify against the installed audiocraft
    version), calls such as ``.to(device)``, ``.eval()`` or ``state_dict()`` on
    this wrapper will not reach the underlying networks.
    """

    def __init__(self, model_name='facebook/audiogen-medium', duration=5, **generation_params):
        """Load a pretrained AudioGen checkpoint and configure generation.

        Args:
            model_name: Checkpoint name passed to ``AudioGen.get_pretrained``.
            duration: Clip length forwarded to ``set_generation_params``
                (seconds, per audiocraft's API).
            **generation_params: Any further keyword arguments accepted by
                ``AudioGen.set_generation_params`` (e.g. ``top_k``,
                ``temperature``, ``cfg_coef``), forwarded unchanged.
        """
        super().__init__()
        self.model = AudioGen.get_pretrained(model_name)
        self.model.set_generation_params(duration=duration, **generation_params)

    def forward(self, descriptions):
        """Generate audio for one or more text descriptions.

        Args:
            descriptions: A list of text prompts, or a single prompt string.
                A bare ``str`` is wrapped in a one-element list so it is
                treated as one prompt rather than iterated character by
                character.

        Returns:
            Whatever ``AudioGen.generate`` returns. This is presumably a float
            tensor of shape (batch, channels, samples); confirm against
            audiocraft.
        """
        if isinstance(descriptions, str):
            descriptions = [descriptions]
        return self.model.generate(descriptions)

    def save_wav(self, wav, idx):
        """Write generated audio to disk with loudness normalisation.

        Args:
            wav: Audio tensor. A 1-D or 2-D (channels, samples) tensor is
                written as one file with stem ``str(idx)``. A 3-D tensor, such
                as the batched output of ``forward``, is split along its first
                dimension, and each clip is written with stem ``f"{idx}_{i}"``.
            idx: Used as the file stem. By default ``audio_write`` appends the
                format suffix.

        Returns:
            The path returned by ``audio_write`` for a single clip, or a list
            of such paths for a 3-D batch.
        """
        wav = wav.cpu()
        if wav.dim() == 3:
            # audiocraft's audio_write accepts at most 2-D input, so a batch
            # has to be written one clip at a time.
            return [self._write_clip(clip, f'{idx}_{i}') for i, clip in enumerate(wav)]
        return self._write_clip(wav, f'{idx}')

    def _write_clip(self, clip, stem):
        """Write one clip at the model's sample rate using loudness normalisation."""
        return audio_write(stem, clip, self.model.sample_rate, strategy="loudness", loudness_compressor=True)