"""Expose an AudioCraft generator through the Hugging Face ``PreTrainedModel`` API."""

import torch
from transformers import PreTrainedModel, PretrainedConfig

from audio_craft_model import AudioCraftGenerator


class AudioCraftForHuggingFace(PreTrainedModel):
    """Thin ``PreTrainedModel`` wrapper that delegates to ``AudioCraftGenerator``.

    All audio generation and saving is performed by the wrapped
    ``AudioCraftGenerator`` instance; this class only adapts it to the
    ``transformers`` model interface.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        model_name: str = 'facebook/audiogen-medium',
        duration: float = 5,
    ):
        """Build the wrapper and its underlying generator.

        Args:
            config: Configuration forwarded to ``PreTrainedModel.__init__``.
            model_name: Checkpoint identifier passed to ``AudioCraftGenerator``.
            duration: Passed to ``AudioCraftGenerator`` unchanged
                (presumably clip length in seconds -- confirm against
                ``audio_craft_model``).
        """
        super().__init__(config)
        # NOTE(review): model_name/duration are not stored in `config`, so
        # save_pretrained()/from_pretrained() will not round-trip them.
        self.audio_craft_generator = AudioCraftGenerator(model_name, duration)

    def forward(self, descriptions):
        """Generate audio for ``descriptions`` with autograd disabled.

        Args:
            descriptions: Forwarded unchanged to the generator (presumably a
                list of text prompts -- confirm against ``AudioCraftGenerator``).

        Returns:
            Whatever ``AudioCraftGenerator.__call__`` returns (named ``wav``
            here; its exact type/shape is not visible in this file).
        """
        # Inference-only wrapper: no gradients are needed for generation.
        with torch.no_grad():
            wav = self.audio_craft_generator(descriptions)
        return wav

    def save_wav(self, wav, idx):
        """Delegate saving of ``wav`` to the generator's ``save_wav``.

        Args:
            wav: Audio as returned by :meth:`forward`.
            idx: Passed through unchanged (presumably an index used to
                build the output file name -- confirm in ``audio_craft_model``).

        Returns:
            The return value of ``AudioCraftGenerator.save_wav``.
        """
        return self.audio_craft_generator.save_wav(wav, idx)