import os

import click
import omegaconf
import ray
from pyprojroot.pyprojroot import here
from ray import tune
from ray.train import Trainer
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.suggest import ConcurrencyLimiter
from ray.tune.suggest.optuna import OptunaSearch

from ganime.trainer.ganime import TrainableGANime
|
# Use PCI bus IDs for device ordering and hide GPU 0 from this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4,5,6"
|
|
def get_metric_direction(metric: str):
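    """Return the Ray Tune optimization mode for ``metric``.

    Only loss metrics are currently supported, and they are always minimized.
    """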
|
    if "loss" in metric:
        return "min"
    else:
        raise ValueError(f"Unknown metric: {metric}")
|
|
def trial_name_id(trial):
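    """Use the trainable name as the trial name."""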
|
    return f"{trial.trainable_name}"
|
|
def trial_dirname_creator(trial):
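    """Use the trial id as the trial directory name."""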
|
    return f"{trial.trial_id}"
|
|
def get_search_space(model):
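    """Return the Ray Tune search space for the given model ("vqgan" or "gpt")."""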
|
    if model == "vqgan":
        return {
            "num_embeddings": tune.choice([64, 128, 256]),
            "embedding_dim": tune.choice([128, 256, 512, 1024]),
            "z_channels": tune.choice([64, 128, 256]),
            "channels": tune.choice([64, 128, 256]),
            "channels_multiplier": tune.choice(
                [
                    [1, 2, 4],
                    [1, 1, 2, 2],
                    [1, 2, 2, 4],
                    [1, 1, 2, 2, 4],
                ]
            ),
            "attention_resolution": tune.choice([[16], [32], [16, 32]]),
            "batch_size": tune.choice([8, 16]),
            "dropout": tune.choice([0.0, 0.1, 0.2]),
            "weight": tune.quniform(0.1, 1.0, 0.1),
            "codebook_weight": tune.quniform(0.2, 2.0, 0.2),
            "perceptual_weight": tune.quniform(0.5, 5.0, 0.5),
            "gen_lr": tune.qloguniform(1e-5, 1e-3, 1e-5),
            "disc_lr": tune.qloguniform(1e-5, 1e-3, 1e-5),
            "gen_beta_1": tune.quniform(0.5, 0.9, 0.1),
            "gen_beta_2": tune.quniform(0.9, 0.999, 0.001),
            "disc_beta_1": tune.quniform(0.5, 0.9, 0.1),
            "disc_beta_2": tune.quniform(0.9, 0.999, 0.001),
            "gen_clip_norm": tune.choice([1.0, None]),
            "disc_clip_norm": tune.choice([1.0, None]),
        }
    elif model == "gpt":
        return {
            "remaining_frames_method": tune.choice(
                ["concat", "token_type_ids", "own_embeddings"]
            ),
            "lr_max": tune.qloguniform(1e-5, 1e-3, 5e-5),
            "lr_start": tune.sample_from(lambda spec: spec.config.lr_max / 10),
            "perceptual_loss_weight": tune.quniform(0.0, 1.0, 0.1),
            "n_frames_before": tune.randint(1, 10),
        }
    else:
        raise ValueError(f"Unknown model: {model}")
|
|
def tune_ganime(
    experiment_name: str,
    dataset_name: str,
    config_file: str,
    model: str,
    metric: str,
    epochs: int,
    num_samples: int,
    num_cpus: int,
    num_gpus: int,
    max_concurrent_trials: int,
):
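    """Run a Ray Tune hyperparameter search over TrainableGANime.

    Trials are sampled with Optuna (at most ``max_concurrent_trials`` running at
    once) and early-stopped with AsyncHyperBand (grace period of 5 iterations,
    capped at ``epochs``). Each trial receives an even share of the requested
    CPUs and GPUs, and results are written under ``./ganime_results``.
    """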
|
|
|
    dataset_path = here("data")
    analysis = tune.run(
        TrainableGANime,
        name=experiment_name,
        search_alg=ConcurrencyLimiter(
            OptunaSearch(), max_concurrent=max_concurrent_trials
        ),
        scheduler=AsyncHyperBandScheduler(max_t=epochs, grace_period=5),
        metric=metric,
        mode=get_metric_direction(metric),
        num_samples=num_samples,
        stop={"training_iteration": epochs},
        local_dir="./ganime_results",
        config={
            "dataset_name": dataset_name,
            "dataset_path": dataset_path,
            "model": model,
            "config_file": config_file,
            "hyperparameters": get_search_space(model),
        },
        resources_per_trial={
            "cpu": num_cpus // max_concurrent_trials,
            "gpu": num_gpus / max_concurrent_trials,
        },
        trial_name_creator=trial_name_id,
        trial_dirname_creator=trial_dirname_creator,
    )
|
    best_config = analysis.get_best_config(
        metric=metric, mode=get_metric_direction(metric)
    )
    print(f"Best config: {best_config}")

    return analysis
|
|
@click.command()
@click.option(
    "--dataset",
    type=click.Choice(
        ["moving_mnist_images", "kny_images", "kny_images_light"], case_sensitive=False
    ),
    default="kny_images_light",
    help="Dataset to use",
)
@click.option(
    "--model",
    type=click.Choice(["vqgan", "gpt"], case_sensitive=False),
    default="vqgan",
    help="Model to use",
)
@click.option(
    "--epochs",
    default=500,
    help="Number of epochs to run",
)
@click.option(
    "--num_samples",
    default=100,
    help="Total number of trials to run",
)
@click.option(
    "--num_cpus",
    default=64,
    help="Number of CPUs to use",
)
@click.option(
    "--num_gpus",
    default=6,
    help="Number of GPUs to use",
)
@click.option(
    "--max_concurrent_trials",
    default=6,
    help="Maximum number of concurrent trials",
)
@click.option(
    "--metric",
    type=click.Choice(
        ["total_loss", "reconstruction_loss", "vq_loss", "disc_loss"],
        case_sensitive=False,
    ),
    default="total_loss",
    help="The metric used to select the best trial",
)
@click.option(
    "--experiment_name",
    default="kny_images_light_v2",
    help="The name of the experiment for logging in TensorBoard",
)
@click.option(
    "--config_file",
    default="kny_image.yaml",
    help="The name of the config file located inside ./configs",
)
def run(
    experiment_name: str,
    config_file: str,
    dataset: str,
    model: str,
    epochs: int,
    num_samples: int,
    num_cpus: int,
    num_gpus: int,
    max_concurrent_trials: int,
    metric: str,
):
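    """Resolve the config path, initialize Ray, and launch the tuning run.

    Example invocation (the script filename is illustrative):

        python tune.py --model vqgan --dataset kny_images_light --num_gpus 6
    """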
|
    config_file = here(os.path.join("configs", config_file))

    ray.init(num_cpus=num_cpus, num_gpus=num_gpus)
    tune_ganime(
        experiment_name=experiment_name,
        dataset_name=dataset,
        config_file=config_file,
        model=model,
        epochs=epochs,
        num_samples=num_samples,
        num_cpus=num_cpus,
        num_gpus=num_gpus,
        max_concurrent_trials=max_concurrent_trials,
        metric=metric,
    )


if __name__ == "__main__":
    run()
|
|