import collections.abc as collections
from pathlib import Path

import torch

# Directory two levels above this file.
GLUESTICK_ROOT = Path(__file__).parent.parent


def get_class(mod_name, base_path, BaseClass):
    """Get the class object which inherits from BaseClass and is
    defined in the module named mod_name, child of base_path.

    Args:
        mod_name: Name of the module, relative to ``base_path``.
        base_path: Dotted path of the package that contains the module.
        BaseClass: Class that the returned class must be a subclass of.

    Returns:
        The single class defined in ``base_path.mod_name`` that is a
        subclass of ``BaseClass``.

    Raises:
        AssertionError: If the module does not define exactly one such class.
    """
    import importlib
    import inspect

    mod_path = "{}.{}".format(base_path, mod_name)
    # import_module returns the leaf module itself, which is what the
    # previous ``__import__(..., fromlist=[""])`` idiom was emulating.
    mod = importlib.import_module(mod_path)
    classes = inspect.getmembers(mod, inspect.isclass)
    # Keep only the classes defined in the module itself (not imported ones)
    classes = [c for c in classes if c[1].__module__ == mod_path]
    # Keep only the classes inherited from BaseClass
    classes = [c for c in classes if issubclass(c[1], BaseClass)]
    assert len(classes) == 1, classes
    return classes[0][1]


def get_model(name):
    """Return the BaseModel subclass defined in the module ``models.<name>``.

    The module is looked up relative to this module's ``__name__``.
    """
    # Imported lazily, at call time rather than at module import time.
    from .models.base_model import BaseModel

    return get_class("models." + name, __name__, BaseModel)


def numpy_image_to_torch(image):
    """Normalize the image tensor and reorder the dimensions.

    Args:
        image: Array with 2 dimensions (HxW) or 3 dimensions (HxWxC).

    Returns:
        A float tensor laid out as CxHxW (a leading channel axis is added
        to 2-D inputs) whose values are the input values divided by 255.

    Raises:
        ValueError: If ``image`` has neither 2 nor 3 dimensions.
    """
    if image.ndim == 3:
        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    elif image.ndim == 2:
        image = image[None]  # add channel axis
    else:
        raise ValueError(f"Not an image: {image.shape}")
    # NOTE(review): dividing by 255 presumably assumes 8-bit intensities
    # in [0, 255] -- confirm against callers.
    return torch.from_numpy(image / 255.0).float()


def map_tensor(input_, func):
    """Recursively apply ``func`` to every leaf of a nested container.

    Strings and bytes are returned unchanged. Mappings are rebuilt as
    ``dict`` and sequences as ``list`` (so a tuple comes back as a list).
    Any other object is treated as a leaf and passed to ``func``.
    """
    # str/bytes are Sequences too, so they must be tested first to avoid
    # recursing into them character by character.
    if isinstance(input_, (str, bytes)):
        return input_
    elif isinstance(input_, collections.Mapping):
        return {k: map_tensor(sample, func) for k, sample in input_.items()}
    elif isinstance(input_, collections.Sequence):
        return [map_tensor(sample, func) for sample in input_]
    else:
        return func(input_)


def batch_to_np(batch):
    """Convert every leaf tensor of ``batch`` to a NumPy array.

    Each leaf is detached, moved to the CPU, converted to NumPy and then
    indexed with ``[0]``, i.e. only its first entry along the leading axis
    is kept (presumably the batch dimension -- confirm against callers).
    """
    return map_tensor(batch, lambda t: t.detach().cpu().numpy()[0])