strumber commited on
Commit
92deba3
1 Parent(s): ed34a50

Create audio_craft_model.py

Browse files
Files changed (1) hide show
  1. audio_craft_model.py +16 -0
audio_craft_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from audiocraft.models import AudioGen
3
+ from audiocraft.data.audio import audio_write
4
+
5
+ class AudioCraftGenerator(torch.nn.Module):
6
+ def __init__(self, model_name='facebook/audiogen-medium', duration=5):
7
+ super(AudioCraftGenerator, self).__init__()
8
+ self.model = AudioGen.get_pretrained(model_name)
9
+ self.model.set_generation_params(duration=duration)
10
+
11
+ def forward(self, descriptions):
12
+ wav = self.model.generate(descriptions)
13
+ return wav
14
+
15
+ def save_wav(self, wav, idx):
16
+ audio_write(f'{idx}', wav.cpu(), self.model.sample_rate, strategy="loudness", loudness_compressor=True)