from transformers import PreTrainedModel, PretrainedConfig | |
import torch | |
from audio_craft_model import AudioCraftGenerator | |
class AudioCraftForHuggingFace(PreTrainedModel):
    """Expose an ``AudioCraftGenerator`` through the ``PreTrainedModel`` interface.

    The wrapped generator is stored on ``self.audio_craft_generator``;
    ``forward`` and ``save_wav`` simply delegate to it.
    """

    def __init__(self, config: PretrainedConfig, model_name='facebook/audiogen-medium', duration=5):
        """Initialise the base model and build the wrapped generator.

        Args:
            config: Configuration object forwarded to ``PreTrainedModel``.
            model_name: Identifier passed as the first positional argument to
                ``AudioCraftGenerator`` (presumably a model checkpoint name --
                confirm against ``audio_craft_model``).
            duration: Passed as the second positional argument to
                ``AudioCraftGenerator`` (presumably the clip length in
                seconds -- confirm against ``audio_craft_model``).
        """
        super().__init__(config)
        # Arguments stay positional: the generator's parameter names are not
        # visible from this file.
        self.audio_craft_generator = AudioCraftGenerator(model_name, duration)

    def forward(self, descriptions):
        """Call the wrapped generator on ``descriptions`` with autograd disabled.

        Args:
            descriptions: Input handed unchanged to the generator (presumably
                text prompts -- confirm against ``AudioCraftGenerator``).

        Returns:
            Whatever the wrapped generator returns for ``descriptions``.
        """
        # The generator is invoked without gradient tracking.
        with torch.no_grad():
            wav = self.audio_craft_generator(descriptions)
        return wav

    def save_wav(self, wav, idx):
        """Delegate to ``AudioCraftGenerator.save_wav`` and return its result.

        Args:
            wav: Forwarded unchanged as the first argument.
            idx: Forwarded unchanged as the second argument.
        """
        return self.audio_craft_generator.save_wav(wav, idx)