# Utility functions for distributed training, local caching, batching, and model introspection.
import os
import getpass
from datetime import datetime
import torch
import random
import numpy as np
import torch.distributed as dist
import inspect
import importlib.util
import socket
from typing import Dict, Union, Type, List, Optional

def get_open_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))  # bind to all interfaces and use an OS-provided port
        return s.getsockname()[1]  # return only the port number

def get_remote_file(remote_path, local_path=None):
    """Copy a 'hostname:/path/to/file' path to local_path via scp, unless the file is already on this host."""
    hostname, path = remote_path.split(':')
    local_hostname = socket.gethostname()
    if hostname == local_hostname or hostname == local_hostname[:local_hostname.find('.')]:
        return path

    if local_path is None:
        local_path = path
    # local_path = local_path.replace('/scr-ssd', '/scr')
    if os.path.exists(local_path):
        return local_path
    local_dir = os.path.dirname(local_path)
    os.makedirs(local_dir, exist_ok=True)

    print(f'Copying {hostname}:{path} to {local_path}')
    os.system(f'scp {remote_path} {local_path}')
    return local_path

def rank0_print(*args, **kwargs):
    """Print, but only on rank 0."""
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)

def get_local_dir(prefixes_to_resolve: List[str]) -> str:
    """Return the path to the cache directory for this user."""
    for prefix in prefixes_to_resolve:
        if os.path.exists(prefix):
            return f"{prefix}/{getpass.getuser()}"
    # none of the candidate prefixes exist; create the last one and use it
    os.makedirs(prefix)
    return f"{prefix}/{getpass.getuser()}"

def get_local_run_dir(exp_name: str, local_dirs: List[str]) -> str:
    """Create a local directory to store outputs for this run, and return its path."""
    now = datetime.now()
    timestamp = now.strftime("%Y-%m-%d_%H-%M-%S_%f")
    run_dir = f"{get_local_dir(local_dirs)}/{exp_name}_{timestamp}"
    os.makedirs(run_dir, exist_ok=True)
    return run_dir

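# Usage sketch (illustrative; 'my_exp' and the '/scr-ssd' / '/scr' prefixes are assumed example values):
#   run_dir = get_local_run_dir('my_exp', ['/scr-ssd', '/scr'])
#   # -> something like '/scr-ssd/<username>/my_exp_<timestamp>'
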
def slice_and_move_batch_for_device(batch: Dict, rank: int, world_size: int, device: str) -> Dict:
    """Slice a batch into chunks, and move each chunk to the specified device."""
    chunk_size = len(list(batch.values())[0]) // world_size
    start = chunk_size * rank
    end = chunk_size * (rank + 1)
    sliced = {k: v[start:end] for k, v in batch.items()}
    on_device = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in sliced.items()}
    return on_device

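# Usage sketch (illustrative; the 'input_ids' and 'prompt' keys are assumed examples, not fixed by this module):
#   batch = {'input_ids': torch.arange(8).reshape(4, 2), 'prompt': ['a', 'b', 'c', 'd']}
#   slice_and_move_batch_for_device(batch, rank=1, world_size=2, device='cpu')
#   # -> rows 2-3 of 'input_ids' (on 'cpu') and ['c', 'd'] for the non-tensor field
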
def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
    """Pad a tensor with pad_value along dim until it reaches the given length (no-op if already long enough)."""
    if tensor.size(dim) >= length:
        return tensor
    else:
        pad_size = list(tensor.shape)
        pad_size[dim] = length - tensor.size(dim)
        return torch.cat([tensor, pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device)], dim=dim)

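# Usage sketch (illustrative):
#   pad_to_length(torch.tensor([1, 2, 3]), length=5, pad_value=0)
#   # -> tensor([1, 2, 3, 0, 0])
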
def all_gather_if_needed(values: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Gather and stack/cat values from all processes, if there are multiple processes."""
    if world_size == 1:
        return values

    all_values = [torch.empty_like(values).to(rank) for _ in range(world_size)]
    dist.all_gather(all_values, values)
    cat_function = torch.cat if values.dim() > 0 else torch.stack
    return cat_function(all_values, dim=0)

def formatted_dict(d: Dict) -> Dict:
    """Format a dictionary for printing."""
    return {k: (f"{v:.5g}" if type(v) == float else v) for k, v in d.items()}

def disable_dropout(model: torch.nn.Module):
    """Disable dropout in a model."""
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0

def print_gpu_memory(rank: Optional[int] = None, message: str = ''):
    """Print the amount of GPU memory currently allocated for each GPU."""
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        for i in range(device_count):
            device = torch.device(f'cuda:{i}')
            allocated_bytes = torch.cuda.memory_allocated(device)
            if allocated_bytes == 0:
                continue
            print('*' * 40)
            print(f'[{message} rank {rank} ] GPU {i}: {allocated_bytes / 1024**2:.2f} MB')
        print('*' * 40)

def get_block_class_from_model(model: torch.nn.Module, block_class_name: str) -> Type[torch.nn.Module]:
    """Get the class of a block from a model, using the block's class name."""
    for module in model.modules():
        if module.__class__.__name__ == block_class_name:
            return module.__class__
    raise ValueError(f"Could not find block class {block_class_name} in model {model}")

def get_block_class_from_model_class_and_block_name(model_class: Type, block_class_name: str) -> Type:
    """Get a block class by name from the transformers source file that defines the given model class."""
    filepath = inspect.getfile(model_class)
    assert filepath.endswith('.py'), f"Expected a .py file, got {filepath}"
    assert os.path.exists(filepath), f"File {filepath} does not exist"
    assert "transformers" in filepath, f"Expected a transformers model, got {filepath}"

    module_name = filepath[filepath.find('transformers'):].replace('/', '.')[:-3]
    print(f"Searching in file {filepath}, module {module_name} for class {block_class_name}")

    # Load the module dynamically
    spec = importlib.util.spec_from_file_location(module_name, filepath)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Get the class dynamically
    class_ = getattr(module, block_class_name)
    print(f"Found class {class_} in module {module_name}")

    return class_

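# Usage sketch (illustrative; assumes the transformers package is installed and exposes GPT2LMHeadModel / GPT2Block):
#   from transformers import GPT2LMHeadModel
#   block_class = get_block_class_from_model_class_and_block_name(GPT2LMHeadModel, 'GPT2Block')
#   # block_class can then be passed to code that wraps transformer blocks (e.g. an FSDP auto-wrap policy)
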
def init_distributed(rank: int, world_size: int, master_addr: str = 'localhost', port: int = 12355, backend: str = 'nccl'):
    """Initialize the default process group for this rank and bind it to the corresponding CUDA device."""
    print(rank, 'initializing distributed')
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

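# Usage sketch (illustrative; 'worker' is a hypothetical per-process entry point, not defined in this file):
#   import torch.multiprocessing as mp
#   port = get_open_port()  # pick the port once in the parent so all ranks agree
#   mp.spawn(worker, args=(world_size, port), nprocs=world_size)
#   # inside worker(rank, world_size, port):
#   #     init_distributed(rank, world_size, port=port)
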
class TemporarilySeededRandom:
    def __init__(self, seed):
        """Temporarily set the random seed, and then restore it when exiting the context."""
        self.seed = seed
        self.stored_state = None
        self.stored_np_state = None

    def __enter__(self):
        # Store the current random state
        self.stored_state = random.getstate()
        self.stored_np_state = np.random.get_state()

        # Set the random seed
        random.seed(self.seed)
        np.random.seed(self.seed)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the random state
        random.setstate(self.stored_state)
        np.random.set_state(self.stored_np_state)
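
# Usage sketch (illustrative):
#   with TemporarilySeededRandom(42):
#       idx = random.randint(0, 9)  # reproducible draw
#   # the global random / numpy RNG state is restored on exit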