File size: 5,135 Bytes
d50bd1e
 
ee4f393
d50bd1e
 
 
8a2882e
d50bd1e
 
 
 
 
ee4f393
 
 
 
d50bd1e
 
 
 
 
 
 
8a2882e
 
d50bd1e
 
 
 
 
 
 
 
ee4f393
 
 
d50bd1e
 
 
 
 
 
 
 
ee4f393
8a2882e
 
 
ee4f393
 
 
8a2882e
 
ee4f393
 
 
8a2882e
 
 
 
 
 
 
 
 
 
 
 
 
ee4f393
8a2882e
 
 
 
 
 
d50bd1e
 
ee4f393
ae68709
 
 
 
 
 
 
ee4f393
 
 
 
 
 
8a2882e
 
5254bd1
d50bd1e
8a2882e
d50bd1e
8a2882e
d50bd1e
ee4f393
 
 
d50bd1e
 
8a2882e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d50bd1e
ee4f393
d50bd1e
8a2882e
0c4c7bf
 
ee4f393
 
 
 
 
0c4c7bf
8a2882e
 
 
ee4f393
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import warnings
import traceback

warnings.simplefilter("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import io
import torch
import numpy as np
from audiocraft.models import musicgen
from scipy.io.wavfile import write as wav_write

try:
    from logger import logging
except:
    import logging


class GenerateAudio:
    """Generate music clips from text prompts with a pretrained MusicGen model.

    Typical use is ``generate_audio()`` followed by ``save_audio()`` and/or
    ``get_audio_buffer()``; the latter two operate on the most recent
    generation stored in ``self.result``.
    """

    def __init__(self, model="musicgen-stereo-small"):
        """Pick a device and load the requested checkpoint.

        Args:
            model: Checkpoint name, with or without the ``facebook/`` prefix.

        Raises:
            ValueError: If the checkpoint cannot be loaded.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = self.get_model_name(model)
        self.model = self.get_model(self.model_name, self.device)
        # ``result`` is what save_audio()/get_audio_buffer() read. It must
        # exist before the first generation so those methods raise the
        # intended ValueError instead of an AttributeError.
        self.result = None
        # Older attribute name, kept so existing readers keep working; it is
        # updated together with ``result`` in generate_audio().
        self.generated_audio = None
        self.sampling_rate = None

    @staticmethod
    def get_model(model, device):
        """Load a pretrained MusicGen checkpoint onto ``device``.

        Args:
            model: Full checkpoint name passed to ``MusicGen.get_pretrained``.
            device: Device string, e.g. ``"cuda"`` or ``"cpu"``.

        Returns:
            The loaded model object.

        Raises:
            ValueError: If loading fails for any reason; the original
                exception is attached as ``__cause__``.
        """
        try:
            loaded = musicgen.MusicGen.get_pretrained(model, device=device)
        except Exception as e:
            logging.error(
                f"Failed to load model: {e}, Traceback: {traceback.format_exc()}"
            )
            raise ValueError(f"Failed to load model: {e}") from e
        # Log the checkpoint name rather than the model object's repr.
        logging.info(f"Loaded model: {model}")
        return loaded

    @staticmethod
    def get_model_name(model_name):
        """Return ``model_name`` with the ``facebook/`` prefix ensured."""
        if model_name.startswith("facebook/"):
            return model_name
        return f"facebook/{model_name}"

    @staticmethod
    def duration_sanity_check(duration):
        """Clamp ``duration`` (seconds) into the supported 1-30 range.

        Out-of-range values are replaced by the nearest bound and a warning
        is logged; in-range values are returned unchanged.
        """
        if duration < 1:
            logging.warning(
                "Duration is less than 1 second. Setting duration to 1 second."
            )
            return 1
        if duration > 30:
            logging.warning(
                "Duration is greater than 30 seconds. Setting duration to 30 seconds."
            )
            return 30
        return duration

    @staticmethod
    def prompts_sanity_check(prompts):
        """Normalise ``prompts`` to a list of strings and validate it.

        Args:
            prompts: A single prompt string or a list of prompt strings.

        Returns:
            A list of prompt strings (a lone string is wrapped in a list).

        Raises:
            ValueError: If ``prompts`` is neither a string nor a list of
                strings, or if the list holds more than 8 prompts.
        """
        if isinstance(prompts, str):
            return [prompts]
        if not isinstance(prompts, list) or not all(
            isinstance(prompt, str) for prompt in prompts
        ):
            raise ValueError("Prompts should be a string or a list of strings.")
        if len(prompts) > 8:  # Too many prompts will cause OOM error
            raise ValueError("Maximum number of prompts allowed is 8.")
        return prompts

    def generate_audio(self, prompts, duration=10):
        """Generate one clip per prompt and store the result on the instance.

        Args:
            prompts: A prompt string or a list of up to 8 prompt strings.
            duration: Requested clip length in seconds; clamped to 1-30.

        Returns:
            A ``(sampling_rate, audio)`` tuple. ``audio`` is a NumPy array
            whose first axis indexes the prompts.

        Raises:
            ValueError: If the prompts are invalid or generation fails; a
                generation failure carries the original exception as
                ``__cause__``.
        """
        duration = self.duration_sanity_check(duration)
        prompts = self.prompts_sanity_check(prompts)

        try:
            self.sampling_rate = self.model.sample_rate
            # duration is already clamped to <= 30 above, so a single
            # generate() call covers every reachable case.
            self.model.set_generation_params(duration=duration)
            generated = self.model.generate(prompts, progress=False)
            # Swap the last two axes so each per-prompt slice is laid out the
            # way scipy's wav writer expects (samples first, channels last).
            # NOTE(review): assumes generate() returns a 3-D
            # (batch, channels, samples) tensor, per the audiocraft docs.
            self.result = generated.cpu().numpy().transpose(0, 2, 1)
            self.generated_audio = self.result

            logging.info(
                f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz"
            )
            return self.sampling_rate, self.result
        except Exception as e:
            logging.error(
                f"Failed to generate audio: {e}, Traceback: {traceback.format_exc()}"
            )
            raise ValueError(f"Failed to generate audio: {e}") from e

    def _require_audio(self):
        """Raise ``ValueError`` unless a generation result is available."""
        if self.result is None:
            raise ValueError("Audio is not generated yet.")
        if self.sampling_rate is None:
            raise ValueError("Sampling rate is not available.")

    def save_audio(self, audio_dir="generated_audio"):
        """Write each generated clip to ``audio_dir`` as ``audio_<i>.wav``.

        Args:
            audio_dir: Output directory; created if it does not exist.

        Returns:
            The list of written file paths, in prompt order.

        Raises:
            ValueError: If no audio has been generated yet.
        """
        self._require_audio()

        os.makedirs(audio_dir, exist_ok=True)
        paths = []
        for i, audio in enumerate(self.result):
            path = os.path.join(audio_dir, f"audio_{i}.wav")
            wav_write(path, self.sampling_rate, audio)
            paths.append(path)
        return paths

    def get_audio_buffer(self):
        """Return each generated clip as an in-memory WAV file.

        Returns:
            A list of ``io.BytesIO`` objects, one per clip, each rewound to
            position 0 so it can be read immediately.

        Raises:
            ValueError: If no audio has been generated yet.
        """
        self._require_audio()

        buffers = []
        for audio in self.result:
            buffer = io.BytesIO()
            wav_write(buffer, self.sampling_rate, audio)
            buffer.seek(0)
            buffers.append(buffer)
        return buffers


if __name__ == "__main__":
    # Demo run: generate three 10-second clips with the default checkpoint,
    # write them to disk, then also materialise them as in-memory WAV buffers.
    demo_prompts = [
        "A piano playing a jazz melody",
        "A guitar playing a rock riff",
        "A LoFi music for coding",
    ]

    audio_gen = GenerateAudio()
    sample_rate, result = audio_gen.generate_audio(demo_prompts, duration=10)

    paths = audio_gen.save_audio()
    print(f"Saved audio to: {paths}")

    buffers = audio_gen.get_audio_buffer()
    print(f"Audio buffers: {buffers}")