Spaces:
Runtime error
Runtime error
import os | |
import math | |
import wandb | |
import random | |
import logging | |
import inspect | |
import argparse | |
import datetime | |
import subprocess | |
from pathlib import Path | |
from tqdm.auto import tqdm | |
from einops import rearrange | |
from omegaconf import OmegaConf | |
from safetensors import safe_open | |
from typing import Dict, Optional, Tuple | |
import torch | |
import torchvision | |
import torch.nn.functional as F | |
import torch.distributed as dist | |
from torch.optim.swa_utils import AveragedModel | |
from torch.utils.data.distributed import DistributedSampler | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
import diffusers | |
from diffusers import AutoencoderKL, DDIMScheduler | |
from diffusers.models import UNet2DConditionModel | |
from diffusers.pipelines import StableDiffusionPipeline | |
from diffusers.optimization import get_scheduler | |
from diffusers.utils import check_min_version | |
from diffusers.utils.import_utils import is_xformers_available | |
from animatediff.models.resnet import InflatedConv3d | |
import transformers | |
from transformers import CLIPTextModel, CLIPTokenizer | |
from animatediff.data.dataset_web import WebVid10M | |
from animatediff.models.unet import UNet3DConditionModel | |
from animatediff.pipelines.pipeline_animation import AnimationPipeline | |
from animatediff.pipelines.validation_pipeline import ValidationPipeline | |
from animatediff.utils.util import save_videos_grid, zero_rank_print, prepare_mask_coef, prepare_mask_coef_by_score | |
def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
    """Initialize the torch.distributed process group and bind this process to a GPU.

    Args:
        launcher: How the job was started. ``'pytorch'`` reads ``RANK`` (and
            ``LOCAL_RANK`` when present) from the environment; ``'slurm'``
            derives rank, world size and master address from ``SLURM_PROCID``,
            ``SLURM_NTASKS`` and ``SLURM_NODELIST``.
        backend: Backend passed to ``dist.init_process_group``.
        port: Master port for the slurm launcher, overridden by the ``PORT``
            environment variable when it is set.
        **kwargs: Extra keyword arguments forwarded to
            ``dist.init_process_group`` (for both launchers).

    Returns:
        The local CUDA device index this process was bound to.

    Raises:
        NotImplementedError: If ``launcher`` is neither ``'pytorch'`` nor ``'slurm'``.
        RuntimeError: If no CUDA device is visible, or ``scontrol`` returns no
            hostname for ``SLURM_NODELIST``.
        subprocess.CalledProcessError: If ``scontrol`` exits with an error.
    """
    # Validate the launcher first so an unsupported value fails fast with a
    # clear message, before any GPU or environment access.
    if launcher not in ('pytorch', 'slurm'):
        raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        # Without this guard the modulo below raises an unhelpful ZeroDivisionError.
        raise RuntimeError('init_dist requires at least one visible CUDA device.')

    if launcher == 'pytorch':
        rank = int(os.environ['RANK'])
        # torchrun exports LOCAL_RANK; fall back to the rank-modulo mapping for
        # launchers that only set RANK.
        local_rank = int(os.environ.get('LOCAL_RANK', rank % num_gpus))
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend=backend, **kwargs)
        return local_rank

    # launcher == 'slurm'
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    local_rank = proc_id % num_gpus
    torch.cuda.set_device(local_rank)

    # The first host of the allocation is used as the rendezvous master. Run
    # scontrol without a shell so the node list is passed as one argument and
    # a failing command raises instead of becoming the master address.
    result = subprocess.run(
        ['scontrol', 'show', 'hostname', node_list],
        capture_output=True,
        text=True,
        check=True,
    )
    hosts = result.stdout.split()
    if not hosts:
        raise RuntimeError(f'scontrol returned no hostnames for SLURM_NODELIST={node_list!r}')
    addr = hosts[0]

    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    port = os.environ.get('PORT', port)
    os.environ['MASTER_PORT'] = str(port)
    dist.init_process_group(backend=backend, **kwargs)
    zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")
    return local_rank
def _to_plain_dict(cfg):
    """Convert an OmegaConf node, a plain mapping, or ``None`` into a ``dict``.

    ``OmegaConf.to_container`` only accepts OmegaConf objects, so the plain
    ``{}`` / ``None`` defaults of ``main`` would otherwise crash.
    """
    if cfg is None:
        return {}
    if OmegaConf.is_config(cfg):
        return OmegaConf.to_container(cfg)
    return dict(cfg)


def _load_unet_checkpoint(unet, checkpoint_path):
    """Load UNet weights from ``checkpoint_path`` with ``strict=False``.

    Accepts a raw state dict or a dict wrapping it under ``"state_dict"``. A
    ``"global_step"`` entry, if present, is printed. Missing and unexpected key
    counts are printed, not raised.
    """
    zero_rank_print(f"from checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if "global_step" in checkpoint:
        zero_rank_print(f"global_step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    missing, unexpected = unet.load_state_dict(state_dict, strict=False)
    zero_rank_print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")


def _inflate_conv_in(unet, in_channels=9):
    """Replace ``unet.conv_in`` with an ``InflatedConv3d`` taking ``in_channels`` inputs.

    The existing weights are kept for the original input channels. The added
    input channels get zero weights, so they contribute nothing at
    initialisation. Output channels and kernel geometry are taken from the old
    layer instead of being hard-coded. The bias is reused as-is.
    """
    old_conv = unet.conv_in
    old_weight = old_conv.weight
    old_bias = old_conv.bias
    out_channels, old_in_channels = old_weight.shape[0], old_weight.shape[1]
    new_conv = InflatedConv3d(
        in_channels,
        out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=old_bias is not None,
    )
    with torch.no_grad():
        extra = old_weight.new_zeros(
            (out_channels, in_channels - old_in_channels, *old_weight.shape[2:])
        )
        new_conv.weight = torch.nn.Parameter(torch.cat((old_weight, extra), dim=1))
    if old_bias is not None:
        new_conv.bias = old_bias
    unet.conv_in = new_conv
    unet.config["in_channels"] = in_channels


def _set_trainable_modules(unet, trainable_modules):
    """Freeze ``unet`` except parameters whose name contains one of ``trainable_modules``.

    ``None`` entries (``main``'s default is ``(None,)``) are skipped. Previously
    they raised ``TypeError`` from the substring test.
    """
    unet.requires_grad_(False)
    patterns = [pattern for pattern in trainable_modules if pattern is not None]
    for param_name, param in unet.named_parameters():
        if any(pattern in param_name for pattern in patterns):
            logging.info(f'{param_name} is trainable \n')
            param.requires_grad = True


def _load_motion_module(unet, motion_module_path):
    """Copy tensors from a motion-module checkpoint into the matching ``unet`` keys.

    Keys not present in the UNet are ignored. Nothing is loaded when
    ``motion_module_path`` is empty.
    """
    if not motion_module_path:
        return
    motion_state = torch.load(motion_module_path, map_location="cpu")
    # Build state_dict() once. Its tensors share storage with the live
    # parameters/buffers, so copy_ writes straight into the model.
    unet_state = unet.state_dict()
    copied = 0
    with torch.no_grad():
        for key, value in motion_state.items():
            if key in unet_state:
                unet_state[key].copy_(value)
                copied += 1
    zero_rank_print(f"motion module: copied {copied}/{len(motion_state)} tensors from {motion_module_path}")


def _build_inpaint_condition(latents, scores, cond_frames, statistic, uncond_prob=0.2):
    """Build the mask and masked-image latents that are fed to the 9-channel UNet.

    One ``random.random()`` draw decides the outcome. With probability
    ``1 - uncond_prob``, every frame of sample ``b`` receives the latent of
    frame ``cond_frames[b]``, and each frame's mask value comes from
    ``prepare_mask_coef_by_score``. Otherwise both tensors stay all-zero.

    Args:
        latents: Video latents indexed as ``(b, c, f, h, w)``.
        scores: Stacked ``batch['score']`` tensor. A leading axis is added
            before it is passed to ``prepare_mask_coef_by_score``.
        cond_frames: Per-sample conditioning frame indices (``batch['cond_frames']``).
        statistic: Forwarded to ``prepare_mask_coef_by_score``.
        uncond_prob: Probability of returning all-zero conditioning.

    Returns:
        ``(mask, masked_image)``. ``mask`` has shape ``(b, 1, f, h, w)`` and
        ``masked_image`` has ``latents.shape``; both live on ``latents.device``.
    """
    batch_size, video_length = latents.shape[0], latents.shape[2]
    # Create the mask on the latents' device: it is passed to the UNet next to
    # GPU tensors (previously it was left on the CPU).
    mask = torch.zeros(
        (batch_size, 1, *latents.shape[2:]), dtype=latents.dtype, device=latents.device,
    )
    masked_image = torch.zeros_like(latents)
    if random.random() > uncond_prob:
        # The arguments are the same for every sample, so the coefficients are
        # computed once. This assumes prepare_mask_coef_by_score is
        # deterministic (TODO confirm). A fresh copy of the scores is passed in
        # case the helper edits its input in place.
        mask_coef = prepare_mask_coef_by_score(
            [batch_size, video_length],
            cond_frame_idx=cond_frames,
            statistic=statistic,
            score=scores.detach().clone().unsqueeze(0),
        )
        for b in range(batch_size):
            for f in range(video_length):
                mask[b, :, f, :, :] = mask_coef[b, f]
            # Broadcast the conditioning frame's latent across the frame axis.
            masked_image[b] = latents[b, :, cond_frames[b]].unsqueeze(1)
    return mask, masked_image


def _grad_norm(parameters):
    """Return the global L2 norm of the existing gradients (0 when there are none)."""
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    return torch.norm(torch.stack([torch.norm(g, 2.0) for g in grads]), 2.0)


def _save_unet_checkpoint(unet, path):
    """Save ``unet.state_dict()`` to ``path`` with the leading DDP ``module.`` prefix removed."""
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in unet.state_dict().items()
    }
    torch.save(state_dict, path)


def _run_validation(validation_pipeline, validation_data, train_data, image_finetune,
                    global_step, global_seed, device, output_dir):
    """Generate validation samples and write them under ``<output_dir>/samples``."""
    samples = []
    generator = torch.Generator(device=device)
    generator.manual_seed(global_seed)

    sample_size = train_data.sample_size
    height = sample_size if isinstance(sample_size, int) else sample_size[0]
    width = sample_size if isinstance(sample_size, int) else sample_size[1]

    # Early in training only the first two prompts are sampled (video mode).
    if global_step < 1000 and not image_finetune:
        prompts = validation_data.prompts[:2]
    else:
        prompts = validation_data.prompts

    for idx, prompt in enumerate(prompts):
        if not image_finetune:
            # use_image is 1 / 2 for the first two prompts and False otherwise;
            # its meaning is defined by ValidationPipeline.
            use_image = idx + 1 if idx < 2 else False
            # NOTE(review): video validation is hard-coded to 512x512 rather
            # than the training sample_size -- confirm this is intended.
            sample = validation_pipeline(
                prompt,
                use_image=use_image,
                generator=generator,
                video_length=train_data.sample_n_frames,
                height=512,
                width=512,
                **validation_data,
            ).videos
            save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
        else:
            sample = validation_pipeline(
                prompt,
                generator=generator,
                height=height,
                width=width,
                num_inference_steps=validation_data.get("num_inference_steps", 25),
                guidance_scale=validation_data.get("guidance_scale", 8.),
            ).images[0]
            sample = torchvision.transforms.functional.to_tensor(sample)
        samples.append(sample)

    if not image_finetune:
        samples = torch.concat(samples)
        save_path = f"{output_dir}/samples/sample-{global_step}.gif"
        save_videos_grid(samples, save_path)
    else:
        samples = torch.stack(samples)
        save_path = f"{output_dir}/samples/sample-{global_step}.png"
        torchvision.utils.save_image(samples, save_path, nrow=4)
    logging.info(f"Saved samples to {save_path}")


def main(
    image_finetune: bool,
    name: str,
    use_wandb: bool,
    launcher: str,
    output_dir: str,
    pretrained_model_path: str,
    train_data: Dict,
    validation_data: Dict,
    cfg_random_null_text: bool = True,
    cfg_random_null_text_ratio: float = 0.1,
    unet_checkpoint_path: str = "",
    unet_additional_kwargs: Dict = {},  # never mutated; converted with _to_plain_dict
    ema_decay: float = 0.9999,
    noise_scheduler_kwargs=None,
    max_train_epoch: int = -1,
    max_train_steps: int = 100,
    validation_steps: int = 100,
    validation_steps_tuple: Tuple = (-1,),
    learning_rate: float = 3e-5,
    scale_lr: bool = False,
    lr_warmup_steps: int = 0,
    lr_scheduler: str = "constant",
    trainable_modules: Tuple[str] = (None, ),
    num_workers: int = 32,
    train_batch_size: int = 1,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 32,
    gradient_checkpointing: bool = False,
    checkpointing_epochs: int = 5,
    checkpointing_steps: int = -1,
    mixed_precision_training: bool = True,
    enable_xformers_memory_efficient_attention: bool = True,
    statistic: list = [1, 40],  # not mutated here; forwarded to prepare_mask_coef_by_score
    global_seed: int = 42,
    is_debug: bool = False,
    mask_frame: list = [0],
    pretrained_motion_module_path: str = '',
    pretrained_sd_path: str = '',
    mask_sim_range: list = [0.2, 1.0],
):
    """Fine-tune a UNet with a 9-channel ``conv_in`` on WebVid10M data under DDP.

    The UNet is set up in four steps:

    1. It is loaded from ``pretrained_model_path`` (3-D unless ``image_finetune``).
    2. It is optionally overwritten from ``unet_checkpoint_path``.
    3. Its ``conv_in`` is widened to 9 input channels.
    4. Matching keys from ``pretrained_motion_module_path`` are copied in.

    Only parameters whose name contains one of ``trainable_modules`` are
    optimized with AdamW.

    One optimizer step (``global_step``) is taken every
    ``gradient_accumulation_steps`` micro-batches, and at the end of each
    epoch. ``max_train_steps``, ``lr_warmup_steps``, ``validation_steps`` and
    ``validation_steps_tuple`` are all counted in optimizer steps.

    Outputs are written on the main process after an optimizer step whose
    ``global_step`` is a multiple of ``validation_steps`` or is listed in
    ``validation_steps_tuple``. Each such step writes validation samples and a
    UNet checkpoint (DDP ``module.`` prefix stripped) under
    ``<output_dir>/<run name>/``.

    ``checkpointing_epochs``/``checkpointing_steps``, ``ema_decay``,
    ``mask_frame``, ``mask_sim_range`` and ``pretrained_sd_path`` are accepted
    for config compatibility but are currently unused.
    """
    check_min_version("0.10.0.dev0")

    # Initialize distributed training
    local_rank = init_dist(launcher=launcher)
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    seed = global_seed + global_rank
    torch.manual_seed(seed)

    # Logging folder
    folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
    output_dir = os.path.join(output_dir, folder_name)

    # Snapshot all arguments plus the locals defined above as the run config
    # (W&B + config.yaml). Do not bind new locals before this line. dict()
    # copies the mapping: getargvalues returns frame.f_locals, which on
    # Python 3.13+ is a live proxy that would also pick up later locals.
    config = dict(inspect.getargvalues(inspect.currentframe()).locals)

    # A debug run starts from an empty folder. Only the main process deletes
    # it, so other ranks cannot remove directories that rank 0 has just created.
    if is_debug and is_main_process and os.path.exists(output_dir):
        import shutil
        shutil.rmtree(output_dir, ignore_errors=True)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        filemode='a',
        filename='train_v2_2.log',
    )

    if is_main_process and (not is_debug) and use_wandb:
        wandb.init(project="image2video", name=folder_name, config=config)

    # Handle the output folder creation
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        for subdir in ("samples", "sanity_check", "checkpoints"):
            os.makedirs(os.path.join(output_dir, subdir), exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDIMScheduler(**_to_plain_dict(noise_scheduler_kwargs))
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    if not image_finetune:
        unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_model_path, subfolder="unet",
            unet_additional_kwargs=_to_plain_dict(unet_additional_kwargs),
        )
    else:
        unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")

    # Load pretrained unet weights (before conv_in is widened).
    if unet_checkpoint_path:
        _load_unet_checkpoint(unet, unet_checkpoint_path)

    _inflate_conv_in(unet, in_channels=9)

    # Freeze vae and text_encoder; only the selected UNet modules train.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    _set_trainable_modules(unet, trainable_modules)

    # Load pre-trained motion module
    _load_motion_module(unet, pretrained_motion_module_path)

    # Scale the LR before the optimizer is built; it previously had no effect.
    if scale_lr:
        learning_rate = learning_rate * gradient_accumulation_steps * train_batch_size * num_processes

    trainable_params = [p for p in unet.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    if is_main_process:
        zero_rank_print(f"trainable params number: {len(trainable_params)}")
        zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    # Enable xformers
    if enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable gradient checkpointing
    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Move models to GPU
    vae.to(local_rank)
    text_encoder.to(local_rank)

    # Get the training dataset
    train_dataset = WebVid10M(**train_data, is_image=image_finetune)
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=global_seed,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # An optimizer step happens every gradient_accumulation_steps micro-batches
    # and at the end of each epoch.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    if num_update_steps_per_epoch == 0:
        raise ValueError("The training dataloader is empty; check dataset size, batch size and world size.")
    if max_train_steps == -1:
        if max_train_epoch == -1:
            raise ValueError("Either max_train_steps or max_train_epoch must be set.")
        # Counted in optimizer steps, matching how global_step is advanced.
        max_train_steps = max_train_epoch * num_update_steps_per_epoch
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Scheduler: stepped once per optimizer step, so lengths are in optimizer steps.
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=max_train_steps,
    )

    # Validation pipeline (the same class is used in video and image mode).
    validation_pipeline = ValidationPipeline(
        unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
    ).to(local_rank)
    validation_pipeline.enable_vae_slicing()

    # DDP wrapper
    unet.to(local_rank)
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)

    # Train!
    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
    if is_main_process:
        logging.info("***** Running training *****")
        logging.info(f"  Num examples = {len(train_dataset)}")
        logging.info(f"  Num Epochs = {num_train_epochs}")
        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
        logging.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
        logging.info(f"  Total optimization steps = {max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine; one tick per optimizer step.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
    progress_bar.set_description("Steps")

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None
    log_grad_norm = is_main_process and (not is_debug) and use_wandb
    steps_per_epoch = len(train_dataloader)

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            if cfg_random_null_text:
                batch['text'] = [text if random.random() > cfg_random_null_text_ratio else "" for text in batch['text']]

            ### >>>> Training >>>> ###
            pixel_values = batch["pixel_values"].to(local_rank)
            video_length = pixel_values.shape[1]
            scores = torch.stack(list(batch['score']))
            cond_frames = batch['cond_frames']

            # Convert videos to latent space, sampling from the VAE posterior.
            with torch.no_grad():
                if not image_finetune:
                    pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
                    latents = vae.encode(pixel_values).latent_dist.sample()
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
                else:
                    latents = vae.encode(pixel_values).latent_dist.sample()
                latents = latents * 0.18215

            # Mask and masked-image latents (indexes latents as b c f h w).
            mask, masked_image = _build_inpaint_condition(latents, scores, cond_frames, statistic)

            # Sample noise and a random timestep per video, then run the
            # forward diffusion process.
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            with torch.no_grad():
                prompt_ids = tokenizer(
                    batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
                ).input_ids.to(latents.device)
                encoder_hidden_states = text_encoder(prompt_ids)[0]

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss (optionally under autocast).
            with torch.cuda.amp.autocast(enabled=mixed_precision_training):
                model_pred = unet(noisy_latents, mask, masked_image, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                loss = loss / gradient_accumulation_steps

            # Backpropagate and accumulate gradients.
            if mixed_precision_training:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Step at every accumulation boundary, and at the end of the epoch,
            # so leftover micro-batches do not bleed into the next epoch.
            is_sync_step = (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == steps_per_epoch

            if is_sync_step:
                # Clip only the fully accumulated (and, with AMP, unscaled) gradients.
                if mixed_precision_training:
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                if log_grad_norm:
                    # Post-clipping gradient norm, as logged before.
                    total_norm = _grad_norm(unet.parameters())

                if mixed_precision_training:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                global_step += 1
                ### <<<< Training <<<< ###

                # Wandb logging
                if log_grad_norm:
                    wandb.log({"gradient_norm": total_norm.item()}, step=global_step)

                # Periodic validation and checkpointing, once per optimizer step
                # (previously repeated on every micro-step, including step 0).
                if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
                    _run_validation(
                        validation_pipeline, validation_data, train_data, image_finetune,
                        global_step, global_seed, latents.device, output_dir,
                    )
                    save_path = os.path.join(output_dir, "checkpoints")
                    if step == steps_per_epoch - 1:
                        checkpoint_name = f"checkpoint-epoch-{epoch+1}.ckpt"
                    else:
                        # Keyed by global_step so later epochs do not overwrite earlier checkpoints.
                        checkpoint_name = f"checkpoint{global_step}.ckpt"
                    _save_unet_checkpoint(unet, os.path.join(save_path, checkpoint_name))
                    logging.info(f"Saved state to {save_path} (global_step: {global_step})")

            loss_value = loss.detach().item()
            logging.info(f"(global_step: {global_step}) loss: {loss_value}")
            logs = {"step_loss": loss_value, "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break

        # Also leave the epoch loop; otherwise every remaining epoch would still
        # train on one more batch.
        if global_step >= max_train_steps:
            break

    dist.destroy_process_group()
def parse_args(argv=None):
    """Parse the training script's command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``config`` (path to the OmegaConf YAML),
        ``launcher`` (``'pytorch'`` or ``'slurm'``) and ``wandb`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="slurm")
    # W&B logging stays on by default. Previously --wandb was store_true with
    # default=True, so logging could never be turned off. --wandb is kept for
    # backward compatibility; --no-wandb disables logging.
    parser.add_argument("--wandb", dest="wandb", action="store_true", default=True)
    parser.add_argument("--no-wandb", dest="wandb", action="store_false")
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    # The run name is the config file's stem; the YAML supplies main()'s other arguments.
    name = Path(args.config).stem
    config = OmegaConf.load(args.config)
    main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)