Spaces:
Sleeping
Sleeping
| # ------------------------------------------------------------------------------ | |
| # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport | |
| # Copyright (c) 2024 Borui Zhang. All Rights Reserved. | |
| # Licensed under the MIT License [see LICENSE for details] | |
| # ------------------------------------------------------------------------------ | |
| from typing import Tuple, Union, Iterable | |
| from omegaconf import OmegaConf | |
| import os | |
| import torch | |
| import torch.distributed as dist | |
def dist_all_gather(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Gather ``x`` from every rank and concatenate the pieces along ``dim``.

    Every rank must call this collectively with a tensor of the same shape and
    dtype, because the receive buffers are allocated with ``torch.empty_like(x)``.
    Pieces are concatenated in rank order. ``dist.all_gather`` does not record
    autograd history, so the result carries no gradients.

    When ``torch.distributed`` is unavailable or no process group has been
    initialized, this behaves like a world of size one and returns a detached
    copy of ``x`` instead of raising.

    Args:
        x: Local tensor to share with the other ranks.
        dim: Dimension along which the gathered pieces are concatenated
            (default ``0``, matching the previous behavior).

    Returns:
        A new tensor holding all ranks' pieces concatenated along ``dim``.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Single process: produce what a world_size == 1 gather would,
        # i.e. a fresh copy with no autograd history.
        return x.detach().clone()
    # Collectives (notably NCCL) reject non-contiguous tensors. empty_like of a
    # contiguous tensor also yields contiguous receive buffers.
    x = x.contiguous()
    # empty_like rather than zeros_like: all_gather overwrites every buffer.
    tensor_list = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(tensor_list, x)
    return torch.cat(tensor_list, dim=dim)
def any_2tuple(data: Union[int, Iterable[int]]) -> Tuple[int, int]:
    """Normalize a size specification to a 2-tuple.

    Args:
        data: Either a single int, which is duplicated, or an iterable that
            yields exactly two items, e.g. ``(w, h)``.

    Returns:
        ``(data, data)`` for an int; otherwise the two items as a tuple.

    Raises:
        ValueError: If ``data`` is neither an int nor a non-string iterable,
            or if the iterable does not yield exactly two items.
    """
    if isinstance(data, int):
        return (data, data)
    # A string is iterable, but "64" would silently become ('6', '4').
    if isinstance(data, Iterable) and not isinstance(data, (str, bytes)):
        # Materialize first so one-shot iterables (e.g. generators) can be
        # length-checked; len() on a bare iterator raises TypeError.
        pair = tuple(data)
        # Explicit raise instead of assert so the check survives `python -O`.
        if len(pair) != 2:
            raise ValueError("target size must be tuple of (w, h)")
        return pair
    raise ValueError("target size must be int or tuple of (w, h)")