File size: 638 Bytes
92deba3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write
class AudioCraftGenerator(torch.nn.Module):
    """``torch.nn.Module`` wrapper around a pretrained AudioCraft ``AudioGen`` model.

    Generates audio from text descriptions (``forward``) and writes generated
    waveforms to disk with loudness normalization (``save_wav``).

    NOTE(review): ``self.model`` holds the object returned by
    ``AudioGen.get_pretrained``; it is presumably not an ``nn.Module`` itself,
    so ``.to(device)`` / ``.eval()`` on this wrapper may not reach it -- confirm
    against audiocraft before relying on module-level device moves.
    """

    def __init__(self, model_name='facebook/audiogen-medium', duration=5,
                 **generation_params):
        """Load a pretrained AudioGen checkpoint and configure generation.

        Args:
            model_name: Checkpoint name passed to ``AudioGen.get_pretrained``.
            duration: Forwarded as ``duration`` to ``set_generation_params``
                (seconds of audio, per audiocraft docs).
            **generation_params: Extra keyword arguments forwarded unchanged to
                ``set_generation_params`` (e.g. ``top_k``, ``temperature``).
        """
        super().__init__()
        self.model = AudioGen.get_pretrained(model_name)
        self.model.set_generation_params(duration=duration, **generation_params)

    def forward(self, descriptions):
        """Generate audio for one or more text descriptions.

        Args:
            descriptions: An iterable of text prompts, or a single prompt string.

        Returns:
            The result of ``self.model.generate`` for the prompts -- per
            audiocraft docs, a tensor batched along the first dimension with one
            entry per description.
        """
        # A bare string is itself iterable; passed through unchanged it would be
        # treated as one prompt per character, so wrap it as a one-item batch.
        if isinstance(descriptions, str):
            descriptions = [descriptions]
        return self.model.generate(list(descriptions))

    def save_wav(self, wav, idx):
        """Write generated audio to disk using loudness normalization.

        Args:
            wav: A single waveform tensor (1-D or 2-D), or a batched 3-D output
                of ``forward`` with the batch along the first dimension.
            idx: Stem for the output file name (``audio_write`` presumably adds
                the file extension itself -- confirm in audiocraft docs).

        Returns:
            A list of the values returned by ``audio_write``: one entry for a
            single waveform, or one per batch item (stems ``f'{idx}_{i}'``) when
            ``wav`` is 3-D.
        """
        # audiocraft's audio_write rejects tensors with more than 2 dimensions,
        # so split batched generator output into individual waveforms.
        if wav.dim() == 3:
            items = [(f'{idx}_{i}', one_wav) for i, one_wav in enumerate(wav)]
        else:
            items = [(f'{idx}', wav)]
        return [
            audio_write(stem, one_wav.cpu(), self.model.sample_rate,
                        strategy="loudness", loudness_compressor=True)
            for stem, one_wav in items
        ]