VidTune-Gradio / engine /audio_generator.py
animikhaich's picture
Pre-Final Changes, Release Ready
ee4f393
import os
import warnings
import traceback
warnings.simplefilter("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import io
import torch
import numpy as np
from audiocraft.models import musicgen
from scipy.io.wavfile import write as wav_write
try:
from logger import logging
except:
import logging
class GenerateAudio:
def __init__(self, model="musicgen-stereo-small"):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model_name = self.get_model_name(model)
self.model = self.get_model(self.model_name, self.device)
self.generated_audio = None
self.sampling_rate = None
@staticmethod
def get_model(model, device):
try:
model = musicgen.MusicGen.get_pretrained(model, device=device)
logging.info(f"Loaded model: {model}")
return model
except Exception as e:
logging.error(
f"Failed to load model: {e}, Traceback: {traceback.format_exc()}"
)
raise ValueError(f"Failed to load model: {e}")
return
@staticmethod
def get_model_name(model_name):
if model_name.startswith("facebook/"):
return model_name
return f"facebook/{model_name}"
@staticmethod
def duration_sanity_check(duration):
if duration < 1:
logging.warning(
"Duration is less than 1 second. Setting duration to 1 second."
)
return 1
elif duration > 30:
logging.warning(
"Duration is greater than 30 seconds. Setting duration to 30 seconds."
)
return 30
return duration
@staticmethod
def prompts_sanity_check(prompts):
if isinstance(prompts, str):
prompts = [prompts]
elif not isinstance(prompts, list):
raise ValueError("Prompts should be a string or a list of strings.")
else:
for prompt in prompts:
if not isinstance(prompt, str):
raise ValueError("Prompts should be a string or a list of strings.")
if len(prompts) > 8: # Too many prompts will cause OOM error
raise ValueError("Maximum number of prompts allowed is 8.")
return prompts
def generate_audio(self, prompts, duration=10):
duration = self.duration_sanity_check(duration)
prompts = self.prompts_sanity_check(prompts)
try:
self.sampling_rate = self.model.sample_rate
if duration <= 30:
self.model.set_generation_params(duration=duration)
result = self.model.generate(prompts, progress=False)
elif duration > 30:
self.model.set_generation_params(duration=30)
result = self.model.generate(prompts, progress=False)
self.model.set_generation_params(duration=duration)
result = self.model.generate_with_chroma(
prompts,
result,
melody_sample_rate=self.sampling_rate,
progress=False,
)
self.result = result.cpu().numpy().T
self.result = self.result.transpose((2, 0, 1))
logging.info(
f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz"
)
return self.sampling_rate, self.result
except Exception as e:
logging.error(
f"Failed to generate audio: {e}, Traceback: {traceback.format_exc()}"
)
raise ValueError(f"Failed to generate audio: {e}")
def save_audio(self, audio_dir="generated_audio"):
if self.result is None:
raise ValueError("Audio is not generated yet.")
if self.sampling_rate is None:
raise ValueError("Sampling rate is not available.")
paths = []
os.makedirs(audio_dir, exist_ok=True)
for i, audio in enumerate(self.result):
path = os.path.join(audio_dir, f"audio_{i}.wav")
wav_write(path, self.sampling_rate, audio)
paths.append(path)
return paths
def get_audio_buffer(self):
if self.result is None:
raise ValueError("Audio is not generated yet.")
if self.sampling_rate is None:
raise ValueError("Sampling rate is not available.")
buffers = []
for audio in self.result:
buffer = io.BytesIO()
wav_write(buffer, self.sampling_rate, audio)
buffer.seek(0)
buffers.append(buffer)
return buffers
if __name__ == "__main__":
audio_gen = GenerateAudio()
sample_rate, result = audio_gen.generate_audio(
[
"A piano playing a jazz melody",
"A guitar playing a rock riff",
"A LoFi music for coding",
],
duration=10,
)
paths = audio_gen.save_audio()
print(f"Saved audio to: {paths}")
buffers = audio_gen.get_audio_buffer()
print(f"Audio buffers: {buffers}")