SmallO commited on
Commit
105d116
·
1 Parent(s): cba6ba1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -24
app.py CHANGED
@@ -1,34 +1,27 @@
1
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
2
- from IPython.display import Audio
3
- import scipy
4
  import torch
5
  import streamlit as st
6
 
7
-
8
  def mu_gen(prompt):
9
- processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
10
- model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
11
-
12
- device = torch.device("cpu")
13
- model.to(device)
14
 
15
- inputs = processor(
16
- text = [str(prompt)], # This line is correct
17
- padding=True,
18
- return_tensors="pt",
19
- )
20
-
21
- inputs = {key: value.to(device) for key, value in inputs.items()}
22
 
23
- # Generate audio on CPU
24
- audio_values = model.generate(**inputs, max_new_tokens=256)
25
- sampling_rate = model.config.audio_encoder.sampling_rate
 
 
26
 
27
- # Create an Audio object from the generated audio
28
- result = Audio(audio_values[0].numpy(), rate=sampling_rate)
29
 
30
- return result
 
 
31
 
 
32
 
33
  def main():
34
  st.title("Text to Music Generator")
@@ -39,12 +32,12 @@ def main():
39
  if st.button("Generate Music"):
40
  if prompt:
41
  # Call the mu_gen function to generate music
42
- generated_music = mu_gen(prompt)
43
 
44
  # Display the generated audio
45
- st.audio(generated_music, format="audio/wav")
46
  else:
47
  st.warning("Please enter a text prompt.")
48
 
49
  if __name__ == "__main__":
50
- main()
 
1
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
 
 
2
  import torch
3
  import streamlit as st
4
 
 
5
  def mu_gen(prompt):
6
+ processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
7
+ model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
 
 
 
8
 
9
+ device = torch.device("cpu")
10
+ model.to(device)
 
 
 
 
 
11
 
12
+ inputs = processor(
13
+ text=[str(prompt)],
14
+ padding=True,
15
+ return_tensors="pt",
16
+ )
17
 
18
+ inputs = {key: value.to(device) for key, value in inputs.items()}
 
19
 
20
+ # Generate audio on CPU
21
+ audio_values = model.generate(**inputs, max_new_tokens=256)
22
+ sampling_rate = model.config.audio_encoder.sampling_rate
23
 
24
+ return audio_values, sampling_rate
25
 
26
  def main():
27
  st.title("Text to Music Generator")
 
32
  if st.button("Generate Music"):
33
  if prompt:
34
  # Call the mu_gen function to generate music
35
+ generated_music, sampling_rate = mu_gen(prompt)
36
 
37
  # Display the generated audio
38
+ st.audio(generated_music[0].numpy(), format="audio/wav", sample_rate=sampling_rate)
39
  else:
40
  st.warning("Please enter a text prompt.")
41
 
42
  if __name__ == "__main__":
43
+ main()