zero2story / modules /music_maker.py
chansung's picture
.
3332aa4
from typing import Literal
from tempfile import NamedTemporaryFile
from pathlib import Path
import uuid
import shutil
import json
import asyncio
import toml
import torch
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
from pydub import AudioSegment
from .utils import set_all_seeds
from modules.llms import get_llm_factory
class MusicMaker:
# TODO: DocString...
"""Class for generating music from prompts."""
def __init__(self, model_size: Literal['small', 'medium', 'melody', 'large'] = 'large',
output_format: Literal['wav', 'mp3'] = 'mp3',
device: str = None) -> None:
"""Initialize the MusicMaker class.
Args:
model_size (Literal['small', 'medium', 'melody', 'large'], optional): Model size. Defaults to 'large'.
output_format (Literal['wav', 'mp3'], optional): Output format. Defaults to 'mp3'.
device (str, optional): Device to use for the model. Defaults to None.
"""
self.__model_size = model_size
self.__output_format = output_format
self.__device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if not device else device
print("Loading the MusicGen model into memory...")
self.__mg_model = MusicGen.get_pretrained(self.model_size, device=self.device)
self.__mg_model.set_generation_params(use_sampling=True,
top_k=250,
top_p=0.0,
temperature=1.0,
cfg_coef=3.0
)
output_dir = Path('.') / 'outputs'
if not output_dir.exists():
output_dir.mkdir(parents=True, exist_ok=True)
elif output_dir.is_file():
assert False, f"A file with the same name as the desired directory ('{str(output_dir)}') already exists."
def text2music(self, prompt: str, length: int = 60, seed: int = None) -> str:
"""Generate a music from the prompt.
Args:
prompt (str): Prompt to generate the music from.
length (int, optional): Length of the music in seconds. Defaults to 60.
seed (int, optional): Seed to use for the generation. Defaults to None.
Returns:
str: Path to the generated music.
"""
def wavToMp3(src_file: str, dest_file: str) -> None:
sound = AudioSegment.from_wav(src_file)
sound.export(dest_file, format="mp3")
output_filename = Path('.') / 'outputs' / str(uuid.uuid4())
if not seed or seed == -1:
seed = torch.randint(0, 2**32 - 1, (1,)).item()
set_all_seeds(seed)
self.__mg_model.set_generation_params(duration=length)
output = self.__mg_model.generate(descriptions=[prompt], progress=True)[0]
with NamedTemporaryFile("wb", delete=True) as temp_file:
audio_write(temp_file.name, output.cpu(), self.__mg_model.sample_rate, strategy="loudness", loudness_compressor=True)
if self.output_format == 'mp3':
wavToMp3(f'{temp_file.name}.wav', str(output_filename.with_suffix('.mp3')))
else:
shutil.copy(f'{temp_file.name}.wav', str(output_filename.with_suffix('.wav')))
return str(output_filename.with_suffix('.mp3' if self.output_format == 'mp3' else '.wav'))
def generate_prompt(self, genre:str, place:str, mood:str,
title:str, chapter_title:str, chapter_plot:str,
llm_type:str="PaLM",
) -> str:
"""Generate a prompt for a background music based on given attributes.
Args:
genre (str): Genre of the story.
place (str): Place of the story.
mood (str): Mood of the story.
title (str): Title of the story.
chapter_title (str): Title of the chapter.
chapter_plot (str): Plot of the chapter.
Returns:
str: Generated prompt.
"""
factory = get_llm_factory(llm_type)
prompt_manager = factory.create_prompt_manager()
llm_service = factory.create_llm_service()
# Generate prompts with PaLM
t = prompt_manager.prompts['music_gen']['gen_prompt']
q = prompt_manager.prompts['music_gen']['query']
query_string = t.format(input=q.format(genre=genre,
place=place,
mood=mood,
title=title,
chapter_title=chapter_title,
chapter_plot=chapter_plot))
try:
response, response_txt = asyncio.run(asyncio.wait_for(
llm_service.gen_text(query_string, mode="text", use_filter=False),
timeout=10)
)
except asyncio.TimeoutError:
raise TimeoutError("The response time for PaLM API exceeded the limit.")
except:
raise Exception("PaLM API is not available.")
try:
res_json = json.loads(response_txt)
except:
print("=== PaLM Response ===")
print(response.filters)
print(response_txt)
print("=== PaLM Response ===")
raise ValueError("The response from PaLM API is not in the expected format.")
return res_json['primary_sentence']
@property
def model_size(self):
"""Model size
Returns:
Literal['small', 'medium', 'melody', 'large']: The model size (read-only)
"""
return self.__model_size
@property
def output_format(self):
"""Output format
Returns:
Literal['wav', 'mp3']: The output format (read-only)
"""
return self.__output_format
@property
def device(self):
"""Device
Returns:
str: The device (read-only)
"""
return self.__device