strumber commited on
Commit
fa7a770
1 Parent(s): 92deba3

Create huggingface_integration.py

Browse files
Files changed (1) hide show
  1. huggingface_integration.py +16 -0
huggingface_integration.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig
2
+ import torch
3
+ from audio_craft_model import AudioCraftGenerator
4
+
5
+ class AudioCraftForHuggingFace(PreTrainedModel):
6
+ def __init__(self, config: PretrainedConfig, model_name='facebook/audiogen-medium', duration=5):
7
+ super(AudioCraftForHuggingFace, self).__init__(config)
8
+ self.audio_craft_generator = AudioCraftGenerator(model_name, duration)
9
+
10
+ def forward(self, descriptions):
11
+ with torch.no_grad():
12
+ wav = self.audio_craft_generator(descriptions)
13
+ return wav
14
+
15
+ def save_wav(self, wav, idx):
16
+ return self.audio_craft_generator.save_wav(wav, idx)