| import argparse | |
| import os | |
| import embedder | |
| import torch.distributed as dist | |
| from lmms_eval.api.registry import ALL_TASKS, GROUP_REGISTRY | |
| from lmms_eval.tasks import ( | |
| ConfigurableTask, | |
| get_task_dict, | |
| include_path, | |
| initialize_tasks, | |
| ) | |
| from lmms_eval.utils import simple_parse_args_string | |
| def rank0_print(*args): | |
| if dist.is_initialized(): | |
| if dist.get_rank() == 0: | |
| print(f"Rank {dist.get_rank()}: ", *args) | |
| else: | |
| print(*args) | |
| def parse_arguments(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--name", type=str) | |
| parser.add_argument("--output_path", type=str) | |
| parser.add_argument("--tasks", type=str, required=False, default="") | |
| parser.add_argument("--data_path", type=str, required=False, default="") | |
| parser.add_argument("--image_folder", type=str, required=False, default="") | |
| parser.add_argument("--embedder_kwargs", type=str, default="") | |
| return parser.parse_args() | |
| if __name__ == "__main__": | |
| args = parse_arguments() | |
| embedder_name = args.name | |
| output_path = args.output_path | |
| if args.tasks.lower().strip() == "all": | |
| initialize_tasks() | |
| for task in list(ALL_TASKS): | |
| if task in GROUP_REGISTRY: | |
| ALL_TASKS.remove(task) | |
| tasks = list(ALL_TASKS) | |
| else: | |
| tasks = args.tasks.split(",") | |
| cached_idx = [] | |
| for idx in range(len(tasks)): | |
| if os.path.exists(os.path.join(output_path, f"{tasks[idx]}_embed.npy")): | |
| rank0_print(f"Task {tasks[idx]} exists in cache folder, load from cache") | |
| cached_idx.append(idx) | |
| tasks = [tasks[idx] for idx in range(len(tasks)) if idx not in cached_idx] | |
| rank0_print(f"Tasks : {tasks}") | |
| embedder_kwargs = simple_parse_args_string(args.embedder_kwargs) | |
| embedder_cls = getattr(embedder, embedder_name) | |
| embedder_obj = embedder_cls(name=embedder_name, output_path=output_path, **embedder_kwargs) | |
| for task in tasks: | |
| embedder_obj.embed_task(task) | |