|
import logging |
|
import os |
|
|
|
import chainer |
|
import torch |
|
|
|
|
|
def set_deterministic_pytorch(args):
    """Configure PyTorch (and Chainer's type checking) for reproducible runs.

    The PyTorch RNG is seeded with ``args.seed``. cuDNN is first placed in
    deterministic mode with benchmark autotuning switched off.
    ``args.debugmode`` then selectively relaxes those settings:

    * ``debugmode < 2`` -- Chainer's runtime type checking is turned off.
    * ``debugmode < 1`` -- cuDNN determinism is turned off again and
      benchmark mode is turned on, giving up reproducibility.

    :param Namespace args: The program arguments; ``args.seed`` and
        ``args.debugmode`` are read.
    """
    torch.manual_seed(args.seed)

    # Start from the reproducible cuDNN configuration; lower debug levels
    # override it further down.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

    level = args.debugmode
    if level < 2:
        # NOTE: the flag toggled here belongs to Chainer, even though the
        # log message mentions torch.
        chainer.config.type_check = False
        logging.info("torch type check is disabled")

    if level < 1:
        # Least strict level: non-deterministic cuDNN with autotuning enabled.
        cudnn.deterministic, cudnn.benchmark = False, True
        logging.info("torch cudnn deterministic is disabled")
|
|
|
|
|
def set_deterministic_chainer(args):
    """Configure Chainer for reproducible runs according to ``args``.

    The seed is exported through the ``CHAINER_SEED`` environment variable.
    ``args.debugmode`` then controls the remaining switches:

    * ``debugmode < 2`` -- Chainer's runtime type checking is turned off.
    * ``debugmode < 1`` -- ``chainer.config.cudnn_deterministic`` is set to
      False; at every other level it is set to True.

    :param Namespace args: The program arguments; ``args.seed`` and
        ``args.debugmode`` are read.
    """
    seed_text = str(args.seed)
    os.environ["CHAINER_SEED"] = seed_text
    logging.info("chainer seed = " + seed_text)

    level = args.debugmode
    if level < 2:
        chainer.config.type_check = False
        logging.info("chainer type check is disabled")

    # cuDNN determinism is the default; only the lowest debug levels drop it.
    relax_cudnn = level < 1
    chainer.config.cudnn_deterministic = not relax_cudnn
    if relax_cudnn:
        logging.info("chainer cudnn deterministic is disabled")
|
|