Kvikontent commited on
Commit
1e126d1
1 Parent(s): 9faf42d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -4,23 +4,21 @@ import torchaudio
4
  from einops import rearrange
5
  from stable_audio_tools import get_pretrained_model
6
  from stable_audio_tools.inference.generation import generate_diffusion_cond
7
- from huggingface_hub import cached_download, hf_hub_url
8
- from transformers import AutoModelForAudioClassification
9
  import os
10
 
11
- token = os.environ.get("TOKEN")
12
- model_name = "stabilityai/stable-audio-open-1.0"
13
- model_config_url = hf_hub_url(repo_id=model_name, revision="main", filename="model_config.json")
14
- model_config = cached_download(model_config_url, use_auth_token=token)
15
-
16
- model = AutoModelForAudioClassification.from_pretrained(
17
- model_name,
18
- cache_dir=None,
19
- use_auth_token=token
20
  )
21
  sample_rate = model_config["sample_rate"]
22
  sample_size = model_config["sample_size"]
23
 
 
 
 
 
 
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
  model = model.to(device)
26
 
@@ -32,7 +30,7 @@ def generate_music(prompt, seconds_total, bpm, genre):
32
  # Set up text and timing conditioning
33
  conditioning = [{
34
  "prompt": f"{bpm} BPM {genre} {prompt}",
35
- "seconds_start": 0,
36
  "seconds_total": seconds_total
37
  }]
38
 
 
4
  from einops import rearrange
5
  from stable_audio_tools import get_pretrained_model
6
  from stable_audio_tools.inference.generation import generate_diffusion_cond
 
 
7
  import os
8
 
9
+ # Load model config from stable-audio-tools
10
+ model, model_config = get_pretrained_model(
11
+ "stabilityai/stable-audio-open-1.0", config_filename="model_config.json"
 
 
 
 
 
 
12
  )
13
  sample_rate = model_config["sample_rate"]
14
  sample_size = model_config["sample_size"]
15
 
16
+ # Load the model using the transformers library
17
+ token = os.environ.get("TOKEN")
18
+ model = AutoModelForAudioClassification.from_pretrained(
19
+ "stabilityai/stable-audio-open-1.0", use_auth_token=token, cache_dir=None
20
+ )
21
+
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
  model = model.to(device)
24
 
 
30
  # Set up text and timing conditioning
31
  conditioning = [{
32
  "prompt": f"{bpm} BPM {genre} {prompt}",
33
+ "seconds_start": 0,
34
  "seconds_total": seconds_total
35
  }]
36