# Meta-AudioGen-API / audio_craft_model.py
# Committed by: strumber
# Commit 92deba3: "Create audio_craft_model.py"
import torch
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write
class AudioCraftGenerator(torch.nn.Module):
    """Wrap a pretrained AudioCraft ``AudioGen`` model behind a ``torch.nn.Module`` API.

    Text descriptions go in through ``forward``. Individual generated clips are
    written to disk with ``save_wav``.

    NOTE(review): ``self.model`` is stored as a plain attribute. If ``AudioGen`` is
    not itself an ``nn.Module``, then ``.to()`` / ``.eval()`` / ``state_dict()`` on
    this wrapper will not reach it. TODO confirm before relying on that.
    """

    def __init__(self, model_name='facebook/audiogen-medium', duration=5, **generation_params):
        """Load the pretrained model and configure its generation settings.

        Args:
            model_name: Identifier passed to ``AudioGen.get_pretrained``.
            duration: Clip length forwarded to ``set_generation_params``
                (presumably seconds; confirm against the AudioCraft docs).
            **generation_params: Extra keyword arguments forwarded verbatim to
                ``set_generation_params`` (e.g. sampling options). When omitted,
                behavior is identical to passing only ``duration``.
        """
        super().__init__()
        self.model = AudioGen.get_pretrained(model_name)
        self.model.set_generation_params(duration=duration, **generation_params)

    def forward(self, descriptions):
        """Generate audio for each text description.

        Args:
            descriptions: A list of text prompts, or a single prompt string.

        Returns:
            The output of ``self.model.generate`` for the batch of prompts
            (presumably a (batch, channels, samples) tensor; TODO confirm).
        """
        # A str is itself iterable. Passing it through unchanged would likely be
        # read as a sequence of one-character prompts, so wrap it as a batch of one.
        if isinstance(descriptions, str):
            descriptions = [descriptions]
        return self.model.generate(descriptions)

    def save_wav(self, wav, idx):
        """Write one generated clip to disk using ``idx`` as the file stem.

        The clip is written with loudness normalization and compression enabled.

        Args:
            wav: A single clip as a 1-D or 2-D tensor. A 3-D batch holding exactly
                one clip (leading dimension 1) is also accepted, so the direct
                output of ``forward`` for one prompt can be saved.
            idx: Value formatted into the output stem (``f'{idx}'``).

        Returns:
            Whatever ``audio_write`` returns (presumably the written file's path).

        Raises:
            ValueError: If ``wav`` is a batch containing more than one clip.
        """
        if wav.dim() == 3:
            if wav.shape[0] != 1:
                raise ValueError(
                    f"save_wav expects a single clip, got a batch of {wav.shape[0]}; "
                    "call it once per item of the batch."
                )
            wav = wav[0]
        # Move to CPU before writing; generation may have run on another device.
        return audio_write(
            f'{idx}',
            wav.cpu(),
            self.model.sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )