# albef-vqa/utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from typing import Any, Dict, List

import torch
import torch.distributed as dist
from torch import nn

def setup_for_distributed(is_master):
    """
    Disable printing on non-master processes; pass force=True to print anyway.
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

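# Example (sketch, not from the original file): after calling
# setup_for_distributed(rank == 0), a non-master rank can still emit a
# message by passing the extra keyword, e.g. print("loss is nan", force=True).
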
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def init_distributed_mode(args):
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args["rank"] = int(os.environ["RANK"])
args["world_size"] = int(os.environ["WORLD_SIZE"])
args["gpu"] = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args["rank"] = int(os.environ["SLURM_PROCID"])
args["gpu"] = args["rank"] % torch.cuda.device_count()
else:
print("Not using distributed mode")
args["distributed"] = False
return
args["distributed"] = True
torch.cuda.set_device(args["gpu"])
args["dist_backend"] = "nccl"
print(
"| distributed init (rank {}): {}".format(args["rank"], args["dist_url"]),
flush=True,
)
torch.distributed.init_process_group(
backend=args["dist_backend"],
init_method=args["dist_url"],
world_size=args["world_size"],
rank=args["rank"],
)
torch.distributed.barrier()
setup_for_distributed(args["rank"] == 0)
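
# Usage sketch (assumption; the launch command is not part of this file): when
# a training script is started with `torchrun --nproc_per_node=<N> train.py`,
# RANK / WORLD_SIZE / LOCAL_RANK are set for every process, so a dict such as
# {"dist_url": "env://"} is enough for init_distributed_mode to fill in
# "rank", "world_size", "gpu", and "distributed" before init_process_group runs.
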
def save_result(result, directory, file_name):
    rank_path = os.path.join(directory, "{}_rank_{}.json".format(file_name, get_rank()))
    main_path = os.path.join(directory, "{}.json".format(file_name))
    with open(rank_path, "w") as f:
        json.dump(result, f)
    if is_dist_avail_and_initialized():
        dist.barrier()
    if is_main_process():
        # Merge the per-rank files into a single result file.
        result = []
        for rank in range(get_world_size()):
            rank_path = os.path.join(
                directory, "{}_rank_{}.json".format(file_name, rank)
            )
            with open(rank_path, "r") as f:
                result += json.load(f)
        with open(main_path, "w") as f:
            json.dump(result, f)
    if is_dist_avail_and_initialized():
        dist.barrier()
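
# Usage sketch (assumption; names are illustrative): each rank calls
#   save_result(predictions, "output/vqa", "vqa_result")
# so rank r writes output/vqa/vqa_result_rank_r.json, and after the barrier the
# main process merges every per-rank file into output/vqa/vqa_result.json.
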
def add_weight_decay(model: nn.Module, weight_decay: float) -> List[Dict[str, Any]]:
decay = []
no_decay = []
for name, param in model.named_parameters():
        if not param.requires_grad:
            # Frozen parameters (e.g. the momentum models) are excluded from
            # the optimizer entirely, not just from weight decay.
            continue
if len(param.shape) == 1 or name.endswith(".bias"):
no_decay.append(param)
else:
decay.append(param)
return [
{"params": no_decay, "weight_decay": 0.0},
{"params": decay, "weight_decay": weight_decay},
]
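

# Minimal usage sketch (not part of the original ALBEF code): exercise
# add_weight_decay on a toy model and feed the resulting parameter groups to
# AdamW. The model and hyperparameters below are placeholders.
if __name__ == "__main__":
    toy_model = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 2))
    param_groups = add_weight_decay(toy_model, weight_decay=0.02)
    optimizer = torch.optim.AdamW(param_groups, lr=2e-5)
    for group in param_groups:
        # Biases and 1-D tensors (e.g. LayerNorm affine params) land in the
        # zero-decay group; weight matrices get the configured weight decay.
        print(len(group["params"]), "params, weight_decay =", group["weight_decay"])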