sungeuns commited on
Commit
7ccdf69
1 Parent(s): 6b3bb53

Model can be loaded from local directory (#69)

Browse files
audiocraft/models/loaders.py CHANGED
@@ -51,6 +51,10 @@ def _get_state_dict(
51
  if os.path.isfile(file_or_url_or_id):
52
  return torch.load(file_or_url_or_id, map_location=device)
53
 
 
 
 
 
54
  elif file_or_url_or_id.startswith('https://'):
55
  return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
56
 
 
51
  if os.path.isfile(file_or_url_or_id):
52
  return torch.load(file_or_url_or_id, map_location=device)
53
 
54
+ if os.path.isdir(file_or_url_or_id):
55
+ file = f"{file_or_url_or_id}/{filename}"
56
+ return torch.load(file, map_location=device)
57
+
58
  elif file_or_url_or_id.startswith('https://'):
59
  return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
60
 
audiocraft/models/musicgen.py CHANGED
@@ -89,10 +89,11 @@ class MusicGen:
89
  return MusicGen(name, compression_model, lm)
90
 
91
  if name not in HF_MODEL_CHECKPOINTS_MAP:
92
- raise ValueError(
93
- f"{name} is not a valid checkpoint name. "
94
- f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
95
- )
 
96
 
97
  cache_dir = os.environ.get('MUSICGEN_ROOT', None)
98
  compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
 
89
  return MusicGen(name, compression_model, lm)
90
 
91
  if name not in HF_MODEL_CHECKPOINTS_MAP:
92
+ if not os.path.isfile(name) and not os.path.isdir(name):
93
+ raise ValueError(
94
+ f"{name} is not a valid checkpoint name. "
95
+ f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
96
+ )
97
 
98
  cache_dir = os.environ.get('MUSICGEN_ROOT', None)
99
  compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)