Spaces:
Runtime error
Runtime error
Update OmniAvatar/models/model_manager.py
Browse files
OmniAvatar/models/model_manager.py
CHANGED
|
@@ -254,49 +254,8 @@ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
|
| 254 |
loaded_model_names += loaded_model_names_
|
| 255 |
loaded_models += loaded_models_
|
| 256 |
return loaded_model_names, loaded_models
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
class ModelDetectorFromHuggingfaceFolder:
|
| 261 |
-
def __init__(self, model_loader_configs=[]):
|
| 262 |
-
self.architecture_dict = {}
|
| 263 |
-
for metadata in model_loader_configs:
|
| 264 |
-
self.add_model_metadata(*metadata)
|
| 265 |
|
| 266 |
|
| 267 |
-
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
| 268 |
-
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
def match(self, file_path="", state_dict={}):
|
| 272 |
-
if not isinstance(file_path, str) or os.path.isfile(file_path):
|
| 273 |
-
return False
|
| 274 |
-
file_list = os.listdir(file_path)
|
| 275 |
-
if "config.json" not in file_list:
|
| 276 |
-
return False
|
| 277 |
-
with open(os.path.join(file_path, "config.json"), "r") as f:
|
| 278 |
-
config = json.load(f)
|
| 279 |
-
if "architectures" not in config and "_class_name" not in config:
|
| 280 |
-
return False
|
| 281 |
-
return True
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
| 285 |
-
with open(os.path.join(file_path, "config.json"), "r") as f:
|
| 286 |
-
config = json.load(f)
|
| 287 |
-
loaded_model_names, loaded_models = [], []
|
| 288 |
-
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
|
| 289 |
-
for architecture in architectures:
|
| 290 |
-
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
| 291 |
-
if redirected_architecture is not None:
|
| 292 |
-
architecture = redirected_architecture
|
| 293 |
-
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
| 294 |
-
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
| 295 |
-
loaded_model_names += loaded_model_names_
|
| 296 |
-
loaded_models += loaded_models_
|
| 297 |
-
return loaded_model_names, loaded_models
|
| 298 |
-
|
| 299 |
-
|
| 300 |
|
| 301 |
class ModelDetectorFromPatchedSingleFile:
|
| 302 |
def __init__(self, model_loader_configs=[]):
|
|
@@ -357,7 +316,6 @@ class ModelManager:
|
|
| 357 |
self.model_detector = [
|
| 358 |
ModelDetectorFromSingleFile(model_loader_configs),
|
| 359 |
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
| 360 |
-
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
| 361 |
]
|
| 362 |
self.load_models(downloaded_files + file_path_list)
|
| 363 |
|
|
|
|
| 254 |
loaded_model_names += loaded_model_names_
|
| 255 |
loaded_models += loaded_models_
|
| 256 |
return loaded_model_names, loaded_models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
class ModelDetectorFromPatchedSingleFile:
|
| 261 |
def __init__(self, model_loader_configs=[]):
|
|
|
|
| 316 |
self.model_detector = [
|
| 317 |
ModelDetectorFromSingleFile(model_loader_configs),
|
| 318 |
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
|
|
|
| 319 |
]
|
| 320 |
self.load_models(downloaded_files + file_path_list)
|
| 321 |
|