artificialguybr commited on
Commit
7ab6d17
1 Parent(s): 9f2cee2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -21
app.py CHANGED
@@ -1,32 +1,60 @@
1
  import gradio as gr
2
- import spaces
3
  import torchaudio
4
  from audiocraft.models import MusicGen
5
  from audiocraft.data.audio import audio_write
 
 
 
 
 
 
6
 
 
7
  model = MusicGen.get_pretrained('nateraw/musicgen-songstarter-v0.2')
8
- model.set_generation_params(duration=8) # generate 8 seconds.
9
 
10
- @spaces.GPU(duration=120) # Specify duration if the function is expected to take more than 60s
11
- def generate_music(description, audio_file):
12
- if audio_file is None:
13
- wav = model.generate([description]) # generates 1 sample based on the provided description
 
 
 
 
 
 
 
 
 
14
  else:
15
- melody, sr = torchaudio.load(audio_file)
16
- wav = model.generate_with_chroma([description], melody[None], sr) # generates using the melody from the given audio and the provided description
 
 
 
 
17
 
18
- audio_write('output', wav[0].cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
19
- return 'output.wav'
 
 
 
20
 
21
- iface = gr.Interface(
22
- fn=generate_music,
23
- inputs=[
24
- gr.Text(label="Description"),
25
- gr.Audio(type="filepath", label="Audio File (optional)")
26
- ],
27
- outputs=gr.Audio(type="filepath"),
28
- title="MusicGen",
29
- description="Generate music using the MusicGen model. Provide a description and optionally an audio file for melody.",
30
- )
31
 
32
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  import torchaudio
3
  from audiocraft.models import MusicGen
4
  from audiocraft.data.audio import audio_write
5
+ import spaces
6
+ import logging
7
+ import os
8
+ import uuid
9
+ # Configura o logging
10
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11
 
12
+ logging.info("Carregando o modelo pré-treinado.")
13
  model = MusicGen.get_pretrained('nateraw/musicgen-songstarter-v0.2')
14
+ model.set_generation_params(duration=8)
15
 
16
+ @spaces.GPU(duration=120)
17
+ def generate_music(description, melody_audio):
18
+ logging.info("Iniciando a geração de música.")
19
+ if description:
20
+ description = [description]
21
+ if melody_audio:
22
+ logging.info(f"Carregando a melodia de áudio de: {melody_audio}")
23
+ melody, sr = torchaudio.load(melody_audio)
24
+ logging.info("Gerando música com descrição e melodia.")
25
+ wav = model.generate_with_chroma(description, melody[None], sr)
26
+ else:
27
+ logging.info("Gerando música apenas com descrição.")
28
+ wav = model.generate(description)
29
  else:
30
+ logging.info("Gerando música de forma incondicional.")
31
+ wav = model.generate_unconditional(1)
32
+ filename = f'{str(uuid.uuid4())}.wav'
33
+ output_path = os.path.join('./', filename) # Salva o arquivo no diretório atual
34
+ logging.info(f"Salvando a música gerada em: {output_path}")
35
+ audio_write(output_path, wav[0].cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
36
 
37
+ # Verifica a forma do tensor de áudio e se foi salvo corretamente
38
+ logging.info(f"A forma do tensor de áudio gerado: {wav[0].shape}")
39
+ logging.info("Música gerada e salva com sucesso.")
40
+ if not os.path.exists(output_path):
41
+ raise ValueError(f'Failed to save audio to {output_path}')
42
 
43
+ return output_path
44
+
45
+ # Define a interface Gradio
46
+ description = gr.Textbox(label="Description", placeholder="acoustic, guitar, melody, trap, d minor, 90 bpm")
47
+ melody_audio = gr.Audio(label="Melody Audio (optional)", type="filepath")
48
+ output_path = gr.Audio(label="Generated Music", type="filepath")
 
 
 
 
49
 
50
+ gr.Interface(
51
+ fn=generate_music,
52
+ inputs=[description, melody_audio],
53
+ outputs=output_path,
54
+ title="MusicGen Demo",
55
+ description="Generate music using the MusicGen model.",
56
+ examples=[
57
+ ["trap, synthesizer, songstarters, dark, G# minor, 140 bpm", "./assets/kalhonaho.mp3"],
58
+ ["upbeat, electronic, synth, dance, 120 bpm", None]
59
+ ]
60
+ ).launch()