File size: 656 Bytes
fa7a770
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from transformers import PreTrainedModel, PretrainedConfig
import torch
from audio_craft_model import AudioCraftGenerator

class AudioCraftForHuggingFace(PreTrainedModel):
    """Expose an ``AudioCraftGenerator`` through the ``transformers`` ``PreTrainedModel`` API.

    The wrapped generator is constructed from ``model_name`` and ``duration``;
    ``config`` is only forwarded to the ``PreTrainedModel`` base class.

    NOTE(review): ``model_name`` and ``duration`` are not stored on ``config``,
    so they are not part of what ``save_pretrained`` records for this model —
    confirm that is intended.
    """

    def __init__(self, config: PretrainedConfig, model_name='facebook/audiogen-medium', duration=5):
        super().__init__(config)
        # Keep this attribute name stable: callers reach the generator through it
        # (and, if the generator is an nn.Module, it also determines state-dict keys).
        self.audio_craft_generator = AudioCraftGenerator(model_name, duration)

    def forward(self, descriptions):
        """Run the wrapped generator on ``descriptions`` with autograd disabled.

        Args:
            descriptions: Passed unchanged to the generator's ``__call__``
                (presumably a list of text prompts — confirm against
                ``AudioCraftGenerator``).

        Returns:
            Whatever the generator returns for ``descriptions``.
        """
        with torch.no_grad():
            return self.audio_craft_generator(descriptions)

    def save_wav(self, wav, idx):
        """Delegate to the generator's ``save_wav`` and return its result unchanged.

        Args:
            wav: Audio to save, forwarded as-is (presumably an output of ``forward``).
            idx: Forwarded as-is (presumably used to name the output — confirm).
        """
        generator = self.audio_craft_generator
        return generator.save_wav(wav, idx)