import torch
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write


class AudioCraftGenerator(torch.nn.Module):
    """``torch.nn.Module`` wrapper around a pretrained AudioCraft ``AudioGen`` model.

    The model is loaded once at construction time and its generation
    parameters are fixed there. ``forward`` turns text descriptions into
    audio. ``save_wav`` writes one generated clip to disk with loudness
    normalisation.
    """

    def __init__(self, model_name='facebook/audiogen-medium', duration=5, **generation_params):
        """Load the pretrained model and configure generation.

        Args:
            model_name: Identifier passed to ``AudioGen.get_pretrained``.
            duration: Value passed as ``duration`` to
                ``set_generation_params``. Presumably seconds of audio per
                generated clip; confirm against the AudioCraft docs.
            **generation_params: Extra keyword arguments forwarded unchanged
                to ``set_generation_params`` (e.g. sampling options). When
                empty, behaviour is the same as before this parameter existed.
        """
        super().__init__()
        # NOTE(review): AudioGen does not appear to be a torch.nn.Module itself,
        # so assigning it here likely does not register its weights as
        # submodules. .to()/.parameters() on this wrapper may therefore not
        # reach them; verify before relying on nn.Module device handling.
        self.model = AudioGen.get_pretrained(model_name)
        self.model.set_generation_params(duration=duration, **generation_params)

    def forward(self, descriptions):
        """Generate audio for the given text description(s).

        Args:
            descriptions: A list of text prompts, or a single prompt string.
                A single string is treated as a one-element list.

        Returns:
            The result of ``AudioGen.generate``. Presumably a
            (batch, channels, samples) tensor with one item per description;
            verify against the AudioCraft docs.
        """
        # generate() expects a sequence of prompts. A bare string would likely
        # be iterated as one-character prompts, yielding one clip per character.
        if isinstance(descriptions, str):
            descriptions = [descriptions]
        return self.model.generate(descriptions)

    def save_wav(self, wav, idx):
        """Write one generated clip to disk with loudness normalisation.

        Args:
            wav: A single generated clip, presumably a (channels, samples)
                tensor, i.e. one item of the batch returned by ``forward``.
                TODO confirm: a full batch likely needs to be split by the
                caller first.
            idx: Formatted with ``str`` and used as the output file stem.
                ``audio_write`` presumably appends the file extension.

        Returns:
            Whatever ``audio_write`` returns (in current AudioCraft releases,
            presumably the written file path). The original implementation
            discarded this value.
        """
        # Move to CPU before handing the tensor to the audio writer.
        return audio_write(
            f'{idx}',
            wav.cpu(),
            self.model.sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )