File size: 917 Bytes
fd01725 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
# https://github.com/cvg/pixloc
# Released under the Apache License 2.0
import inspect
from .base import BaseModel
def get_class(mod_name, base_path, BaseClass):
"""Get the class object which inherits from BaseClass and is defined in
the module named mod_name, child of base_path.
"""
mod_path = "{}.{}".format(base_path, mod_name)
mod = __import__(mod_path, fromlist=[""])
classes = inspect.getmembers(mod, inspect.isclass)
# Filter classes defined in the module
classes = [c for c in classes if c[1].__module__ == mod_path]
# Filter classes inherited from BaseModel
classes = [c for c in classes if issubclass(c[1], BaseClass)]
assert len(classes) == 1, classes
return classes[0][1]
def get_model(name):
return get_class(name, __name__, BaseModel)
|