Meta-AudioGen-API / huggingface_integration.py
strumber's picture
Create huggingface_integration.py
fa7a770
raw
history blame
656 Bytes
from transformers import PreTrainedModel, PretrainedConfig
import torch
from audio_craft_model import AudioCraftGenerator
class AudioCraftForHuggingFace(PreTrainedModel):
    """Wrap an ``AudioCraftGenerator`` behind the ``PreTrainedModel`` interface.

    The wrapper owns a single generator instance. ``forward`` and ``save_wav``
    route straight through to it.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        model_name='facebook/audiogen-medium',
        duration=5,
    ):
        """Initialise the base model and build the wrapped generator.

        Args:
            config: Configuration object handed to ``PreTrainedModel.__init__``.
            model_name: First positional argument to ``AudioCraftGenerator``
                (presumably a checkpoint identifier — confirm against
                audio_craft_model).
            duration: Second positional argument to ``AudioCraftGenerator``
                (presumably clip length in seconds — confirm).
        """
        super().__init__(config)
        generator = AudioCraftGenerator(model_name, duration)
        self.audio_craft_generator = generator

    def forward(self, descriptions):
        """Run the wrapped generator on ``descriptions`` with autograd disabled.

        Args:
            descriptions: Passed unchanged to the generator (presumably text
                prompts — confirm against AudioCraftGenerator).

        Returns:
            Whatever calling the wrapped generator returns.
        """
        # Gradient tracking is switched off: this path only generates audio.
        with torch.no_grad():
            generated = self.audio_craft_generator(descriptions)
        return generated

    def save_wav(self, wav, idx):
        """Hand ``wav`` and ``idx`` to the wrapped generator's ``save_wav``.

        Args:
            wav: Audio object produced by ``forward``, passed through unchanged.
            idx: Passed through unchanged (presumably used to name the output
                file — confirm).

        Returns:
            Whatever ``AudioCraftGenerator.save_wav`` returns.
        """
        delegate = self.audio_craft_generator.save_wav
        return delegate(wav, idx)