Spaces:
Sleeping
Sleeping
import sys | |
from pprint import pformat | |
from typing import Any | |
import os | |
import torch | |
import ignite.distributed as idist | |
import yaml | |
from ignite.engine import Events | |
from ignite.metrics import Accuracy, Loss | |
from ignite.utils import manual_seed | |
from torch import nn, optim | |
from modelguidedattacks.data.setup import setup_data | |
from modelguidedattacks.losses.boilerplate import BoilerplateLoss | |
from modelguidedattacks.losses.energy import Energy, EnergyLoss | |
from modelguidedattacks.metrics.topk_accuracy import TopKAccuracy | |
from modelguidedattacks.models import setup_model | |
from modelguidedattacks.trainers import setup_evaluator, setup_trainer | |
from modelguidedattacks.utils import setup_parser, setup_output_dir | |
from modelguidedattacks.utils import setup_logging, log_metrics, Engine | |
def run(local_rank: int, config: Any): | |
print ("Running ", local_rank) | |
# make a certain seed | |
rank = idist.get_rank() | |
manual_seed(config.seed + rank) | |
# create output folder | |
config.output_dir = setup_output_dir(config, rank) | |
# setup engines logger with python logging | |
# print training configurations | |
logger = setup_logging(config) | |
logger.info("Configuration: \n%s", pformat(vars(config))) | |
(config.output_dir / "config-lock.yaml").write_text(yaml.dump(config)) | |
# donwload datasets and create dataloaders | |
dataloader_train, dataloader_eval = setup_data(config, rank) | |
# model, optimizer, loss function, device | |
device = idist.device() | |
model = idist.auto_model(setup_model(config, idist.device())) | |
loss_fn = BoilerplateLoss().to(device=device) | |
l2_energy_loss = Energy(p=2).to(device) | |
l1_energy_loss = Energy(p=1).to(device) | |
l_inf_energy_loss = Energy(p=torch.inf).to(device) | |
evaluator = setup_evaluator(config, model, device) | |
evaluator.logger = logger | |
# attach metrics to evaluator | |
accuracy = TopKAccuracy(device=device) | |
metrics = { | |
"ASR": accuracy, | |
"L2 Energy": EnergyLoss(l2_energy_loss, device=device), | |
"L1 Energy": EnergyLoss(l1_energy_loss, device=device), | |
"L_inf Energy": EnergyLoss(l_inf_energy_loss, device=device), | |
"L2 Energy Min": EnergyLoss(l2_energy_loss, reduction="min", device=device), | |
"L1 Energy Min": EnergyLoss(l1_energy_loss, reduction="min", device=device), | |
"L_inf Energy Min": EnergyLoss(l_inf_energy_loss, reduction="min", device=device), | |
"L2 Energy Max": EnergyLoss(l2_energy_loss, reduction="max", device=device), | |
"L1 Energy Max": EnergyLoss(l1_energy_loss, reduction="max", device=device), | |
"L_inf Energy Max": EnergyLoss(l_inf_energy_loss, reduction="max", device=device) | |
} | |
for name, metric in metrics.items(): | |
metric.attach(evaluator, name) | |
if config.guide_model in ["unguided", "instance_guided"]: | |
first_batch_passed = False | |
early_stopped = False | |
def compute_metrics(engine: Engine, tag: str): | |
nonlocal first_batch_passed | |
nonlocal early_stopped | |
for name, metric in metrics.items(): | |
metric.completed(engine, name) | |
if not first_batch_passed: | |
if engine.state.metrics["ASR"] < 1e-3: | |
print ("Early stop, assuming no success throughout") | |
early_stopped = True | |
engine.terminate() | |
else: | |
first_batch_passed = True | |
evaluator.add_event_handler( | |
Events.ITERATION_COMPLETED(every=config.log_every_iters), | |
compute_metrics, | |
tag="eval", | |
) | |
evaluator.add_event_handler( | |
Events.ITERATION_COMPLETED(every=config.log_every_iters), | |
log_metrics, | |
tag="eval", | |
) | |
evaluator.run(dataloader_eval, epoch_length=config.eval_epoch_length) | |
log_metrics(evaluator, "eval") | |
if len(config.out_dir) > 0: | |
# Store results in out_dir | |
os.makedirs(config.out_dir, exist_ok=True) | |
metrics_dict = evaluator.state.metrics | |
metrics_dict["config"] = config | |
metrics_dict["early_stopped"] = early_stopped | |
metrics_file_path = os.path.join(config.out_dir, "results.save") | |
torch.save(metrics_dict, metrics_file_path) | |
# No need to train with an unguided model | |
return | |
assert False, "This code path is for the future" | |
# main entrypoint | |
def launch(config=None): | |
if config is None: | |
config_path = sys.argv[1] | |
config = setup_parser(config_path).parse_args(sys.argv[2:]) | |
backend = config.backend | |
nproc_per_node = config.nproc_per_node | |
if nproc_per_node == 0 or backend is None: | |
backend = None | |
nproc_per_node = None | |
with idist.Parallel(backend, nproc_per_node) as p: | |
p.run(run, config=config) | |
if __name__ == "__main__": | |
launch() | |