Spaces:
Runtime error
Runtime error
""" | |
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
""" | |
class Registry:
    """Central registry mapping string keys to builders, tasks, models, etc.

    Every table lives in the class attribute ``mapping`` and every method is a
    classmethod. ``Registry.<method>(...)`` and the module-level ``registry``
    instance therefore read and write the same shared storage.
    """

    # One name -> object table per kind of registrable item.
    # ``state`` holds arbitrary (possibly nested) values stored via ``register``.
    # ``paths`` holds string paths stored via ``register_path``.
    mapping = {
        "builder_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
        "runner_name_mapping": {},
        "state": {},
        "paths": {},
    }

    @classmethod
    def register_builder(cls, name):
        r"""Register a dataset builder to registry with key 'name'.

        Args:
            name: Key with which the builder will be registered.

        Returns:
            A class decorator that records the decorated builder under
            ``name`` and returns the class unchanged.

        Raises:
            AssertionError: If the decorated class does not subclass one of
                the image / audio / multimodal base dataset builders.
            KeyError: If ``name`` is already registered.

        Usage:

            from bubogpt.common.registry import registry
            from bubogpt.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder

            @registry.register_builder("my_dataset")
            class MyDatasetBuilder(ImageBaseDatasetBuilder):
                ...
        """

        def wrap(builder_cls):
            # TODO: merge them or split builders by modality
            # Imported lazily (at decoration time), presumably to avoid
            # circular imports when this module is loaded.
            from bubogpt.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
            from bubogpt.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder
            from bubogpt.datasets.builders.multimodal_base_dataset_builder import MultimodalBaseDatasetBuilder

            assert issubclass(
                builder_cls, (ImageBaseDatasetBuilder, AudioBaseDatasetBuilder, MultimodalBaseDatasetBuilder)
            ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                builder_cls
            )
            if name in cls.mapping["builder_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["builder_name_mapping"][name]
                    )
                )
            cls.mapping["builder_name_mapping"][name] = builder_cls
            return builder_cls

        return wrap

    @classmethod
    def register_task(cls, name):
        r"""Register a task to registry with key 'name'.

        Args:
            name: Key with which the task will be registered.

        Returns:
            A class decorator that records the decorated task under ``name``
            and returns the class unchanged.

        Raises:
            AssertionError: If the decorated class does not subclass BaseTask.
            KeyError: If ``name`` is already registered.

        Usage:

            from bubogpt.common.registry import registry

            @registry.register_task("my_task")
            class MyTask(BaseTask):
                ...
        """

        def wrap(task_cls):
            from bubogpt.tasks.base_task import BaseTask

            assert issubclass(
                task_cls, BaseTask
            ), "All tasks must inherit BaseTask class"
            if name in cls.mapping["task_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["task_name_mapping"][name]
                    )
                )
            cls.mapping["task_name_mapping"][name] = task_cls
            return task_cls

        return wrap

    @classmethod
    def register_model(cls, name):
        r"""Register a model to registry with key 'name'.

        Args:
            name: Key with which the model will be registered.

        Returns:
            A class decorator that records the decorated model under ``name``
            and returns the class unchanged.

        Raises:
            AssertionError: If the decorated class does not subclass BaseModel.
            KeyError: If ``name`` is already registered.

        Usage:

            from bubogpt.common.registry import registry

            @registry.register_model("my_model")
            class MyModel(BaseModel):
                ...
        """

        def wrap(model_cls):
            from bubogpt.models import BaseModel

            assert issubclass(
                model_cls, BaseModel
            ), "All models must inherit BaseModel class"
            if name in cls.mapping["model_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["model_name_mapping"][name]
                    )
                )
            cls.mapping["model_name_mapping"][name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Register a processor to registry with key 'name'.

        Args:
            name: Key with which the processor will be registered.

        Returns:
            A class decorator that records the decorated processor under
            ``name`` and returns the class unchanged.

        Raises:
            AssertionError: If the decorated class does not subclass
                BaseProcessor.
            KeyError: If ``name`` is already registered.

        Usage:

            from bubogpt.common.registry import registry

            @registry.register_processor("my_processor")
            class MyProcessor(BaseProcessor):
                ...
        """

        def wrap(processor_cls):
            from bubogpt.processors import BaseProcessor

            assert issubclass(
                processor_cls, BaseProcessor
            ), "All processors must inherit BaseProcessor class"
            if name in cls.mapping["processor_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["processor_name_mapping"][name]
                    )
                )
            cls.mapping["processor_name_mapping"][name] = processor_cls
            return processor_cls

        return wrap

    @classmethod
    def register_lr_scheduler(cls, name):
        r"""Register a learning-rate scheduler to registry with key 'name'.

        No base-class check is performed on the decorated class.

        Args:
            name: Key with which the scheduler will be registered.

        Returns:
            A class decorator that records the decorated scheduler under
            ``name`` and returns the class unchanged.

        Raises:
            KeyError: If ``name`` is already registered.

        Usage:

            from bubogpt.common.registry import registry

            @registry.register_lr_scheduler("my_scheduler")
            class MyScheduler:
                ...
        """

        def wrap(lr_sched_cls):
            if name in cls.mapping["lr_scheduler_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["lr_scheduler_name_mapping"][name]
                    )
                )
            cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
            return lr_sched_cls

        return wrap

    @classmethod
    def register_runner(cls, name):
        r"""Register a runner to registry with key 'name'.

        No base-class check is performed on the decorated class.

        Args:
            name: Key with which the runner will be registered.

        Returns:
            A class decorator that records the decorated runner under
            ``name`` and returns the class unchanged.

        Raises:
            KeyError: If ``name`` is already registered.

        Usage:

            from bubogpt.common.registry import registry

            @registry.register_runner("my_runner")
            class MyRunner:
                ...
        """

        def wrap(runner_cls):
            if name in cls.mapping["runner_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["runner_name_mapping"][name]
                    )
                )
            cls.mapping["runner_name_mapping"][name] = runner_cls
            return runner_cls

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path to registry with key 'name'.

        Args:
            name: Key with which the path will be registered.
            path (str): The path to store.

        Raises:
            AssertionError: If ``path`` is not a ``str``.
            KeyError: If ``name`` is already registered.

        Usage:

            from bubogpt.common.registry import registry
            registry.register_path("cache_root", "/path/to/cache")
        """
        assert isinstance(path, str), "All path must be str."
        if name in cls.mapping["paths"]:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping["paths"][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'.

        A dotted ``name`` such as ``"a.b.c"`` stores ``obj`` in nested dicts
        (``state["a"]["b"]["c"]``), creating missing intermediate dicts.
        An existing value at the final key is overwritten.

        Args:
            name: Key with which the item will be registered.
            obj: The value to store.

        Usage::

            from bubogpt.common.registry import registry
            registry.register("config", {})
        """
        path = name.split(".")
        current = cls.mapping["state"]
        for part in path[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]
        current[path[-1]] = obj

    @classmethod
    def get_builder_class(cls, name):
        """Return the builder registered under ``name``, or None."""
        return cls.mapping["builder_name_mapping"].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        """Return the model registered under ``name``, or None."""
        return cls.mapping["model_name_mapping"].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        """Return the task registered under ``name``, or None."""
        return cls.mapping["task_name_mapping"].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        """Return the processor registered under ``name``, or None."""
        return cls.mapping["processor_name_mapping"].get(name, None)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        """Return the LR scheduler registered under ``name``, or None."""
        return cls.mapping["lr_scheduler_name_mapping"].get(name, None)

    @classmethod
    def get_runner_class(cls, name):
        """Return the runner registered under ``name``, or None."""
        return cls.mapping["runner_name_mapping"].get(name, None)

    @classmethod
    def list_runners(cls):
        """Return the sorted names of all registered runners."""
        return sorted(cls.mapping["runner_name_mapping"].keys())

    @classmethod
    def list_models(cls):
        """Return the sorted names of all registered models."""
        return sorted(cls.mapping["model_name_mapping"].keys())

    @classmethod
    def list_tasks(cls):
        """Return the sorted names of all registered tasks."""
        return sorted(cls.mapping["task_name_mapping"].keys())

    @classmethod
    def list_processors(cls):
        """Return the sorted names of all registered processors."""
        return sorted(cls.mapping["processor_name_mapping"].keys())

    @classmethod
    def list_lr_schedulers(cls):
        """Return the sorted names of all registered LR schedulers."""
        return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())

    @classmethod
    def list_datasets(cls):
        """Return the sorted names of all registered dataset builders."""
        return sorted(cls.mapping["builder_name_mapping"].keys())

    @classmethod
    def get_path(cls, name):
        """Return the path registered under ``name``, or None."""
        return cls.mapping["paths"].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from registry with key 'name'.

        Args:
            name (string): Key whose value needs to be retrieved. Dotted keys
                such as ``"a.b"`` walk nested containers in the state.
            default: Value returned when the key is not in the registry.
                Default: None
            no_warning (bool): If True, do not emit the missing-key warning
                through the registered ``"writer"``. Default: False

        Returns:
            The stored value. Returns ``default`` if any segment of ``name``
            is missing, or if an intermediate value has no ``get`` method
            and so cannot be indexed further.
        """
        # Private sentinel: tells "key absent" apart from a stored value that
        # happens to equal ``default``. It also avoids ``==``, which raises
        # for array-like values whose comparison is not a plain bool.
        missing = object()
        value = cls.mapping["state"]
        for subname in name.split("."):
            lookup = getattr(value, "get", None)
            if lookup is None:
                # Reached a leaf before the path ended; treat as absent.
                value = missing
                break
            value = lookup(subname, missing)
            if value is missing:
                break

        if value is not missing:
            return value

        if "writer" in cls.mapping["state"] and not no_warning:
            cls.mapping["state"]["writer"].warning(
                "Key {} is not present in registry, returning default value "
                "of {}".format(name, default)
            )
        return default

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from registry with key 'name'.

        Args:
            name: Key which needs to be removed. Dotted keys (``"a.b"``)
                address nested entries, mirroring ``register`` and ``get``.

        Returns:
            The removed value, or None if the key was not registered.

        Usage::

            from bubogpt.common.registry import registry
            config = registry.unregister("config")
        """
        *parents, leaf = name.split(".")
        container = cls.mapping["state"]
        for part in parents:
            # Descend only through dicts (what ``register`` creates); any
            # other value means the nested key cannot exist.
            container = container.get(part) if isinstance(container, dict) else None
        if not isinstance(container, dict):
            return None
        return container.pop(leaf, None)


# Shared module-level handle; all state lives on the ``Registry`` class itself.
registry = Registry()