|
import torch.cuda |
|
|
|
try: |
|
from torch._C import _cudnn |
|
except ImportError: |
|
|
|
|
|
_cudnn = None |
|
|
|
|
|
# Maps the accepted RNN mode names to the member names of `_cudnn.RNNMode`.
_RNN_MODE_ATTRS = {
    'RNN_RELU': 'rnn_relu',
    'RNN_TANH': 'rnn_tanh',
    'LSTM': 'lstm',
    'GRU': 'gru',
}


def get_cudnn_mode(mode):
    """Return the integer value of the cuDNN RNN mode named *mode*.

    Args:
        mode: one of ``'RNN_RELU'``, ``'RNN_TANH'``, ``'LSTM'`` or ``'GRU'``
            (case-sensitive).

    Returns:
        ``int`` of the matching ``_cudnn.RNNMode`` member.

    Raises:
        ValueError: if *mode* is not one of the names above. ``ValueError``
            subclasses ``Exception``, so callers that caught the previous
            bare ``Exception`` keep working.
        RuntimeError: if the cuDNN bindings could not be imported
            (``_cudnn`` is ``None``).
    """
    # Validate the name before touching `_cudnn`, so an unknown mode is always
    # reported as such, even when the bindings are unavailable.
    try:
        attr = _RNN_MODE_ATTRS[mode]
    except (KeyError, TypeError):  # TypeError: unhashable input such as a list
        attr = None
    if attr is None:
        raise ValueError("Unknown mode: {}".format(mode))

    if _cudnn is None:
        # Without this check the caller would see an opaque AttributeError on
        # `None.RNNMode`.
        raise RuntimeError(
            "cuDNN RNN mode {!r} requested, but the cuDNN bindings "
            "(torch._C._cudnn) could not be imported".format(mode))

    return int(getattr(_cudnn.RNNMode, attr))
|
|
|
|
|
|
|
|
|
|
|
class Unserializable(object):
    """Wrap an object whose value must not survive pickling or copying.

    The wrapped value is available via :meth:`get`. Pickling an instance
    stores only a placeholder string. Copying with :mod:`copy` goes through
    the same ``__getstate__``/``__setstate__`` protocol. Either way, the
    reconstructed instance wraps ``None``. Presumably this keeps
    device-resident state out of serialized/copied objects; confirm with the
    callers.
    """

    def __init__(self, inner):
        self.inner = inner

    def get(self):
        """Return the wrapped object (``None`` after unpickling or copying)."""
        return self.inner

    def __setstate__(self, state):
        """Rebuild with an empty payload; the placeholder *state* is ignored."""
        self.inner = None

    def __getstate__(self):
        """Emit a placeholder instead of the wrapped object.

        The value is a non-empty string, so it is truthy; per the pickle
        protocol, that means ``__setstate__`` is still invoked on load.
        """
        return '<unserializable>'
|
|
|
|
|
def init_dropout_state(dropout, train, dropout_seed, dropout_state):
    """Return the cached cuDNN dropout state for the current CUDA device.

    ``dropout_state`` is used as a per-device cache. Entries are stored under
    ``'desc_<current device index>'`` and hold :class:`Unserializable`
    wrappers. (Presumably it is a dict owned by the calling RNN; TODO confirm.)

    A new entry is created when the key is missing or its wrapper holds
    ``None``. An existing non-``None`` entry is returned unchanged, even if
    ``dropout``, ``train`` or ``dropout_seed`` differ from the values it was
    created with.

    Args:
        dropout: dropout probability. It only applies when ``train`` is
            truthy; otherwise the effective probability is 0.
        train: training flag, also forwarded to
            ``torch._cudnn_init_dropout_state``.
        dropout_seed: seed forwarded to ``torch._cudnn_init_dropout_state``.
        dropout_state: mutable mapping used as the cache (updated in place).

    Returns:
        Whatever ``torch._cudnn_init_dropout_state`` returned for this device,
        or ``None`` when the effective probability is 0. In the ``None`` case
        no state is allocated, and a later call will retry initialisation.
    """
    desc_key = 'desc_' + str(torch.cuda.current_device())
    effective_p = dropout if train else 0

    needs_init = (desc_key not in dropout_state
                  or dropout_state[desc_key].get() is None)
    if needs_init:
        if effective_p == 0:
            # Nothing to allocate; cache an empty placeholder.
            fresh = None
        else:
            fresh = torch._cudnn_init_dropout_state(
                effective_p,
                train,
                dropout_seed,
                self_ty=torch.uint8,
                device=torch.device('cuda'),
            )
        dropout_state[desc_key] = Unserializable(fresh)

    return dropout_state[desc_key].get()
|
|