|
import torch |
|
import glog |
|
import time |
|
|
|
def get_graph_wrapper(cls):
    """Return a subclass of ``cls`` whose forward pass is captured into a
    CUDA graph on first use and replayed thereafter.

    The first ``forward()`` call performs a warmup run on a side stream,
    then captures the forward pass into a ``torch.cuda.CUDAGraph``. Later
    calls copy tensor arguments in place into the captured static buffers
    and replay the graph, avoiding per-call kernel-launch overhead.

    NOTE(review): non-tensor arguments are baked in at capture time —
    passing different non-tensor values on later calls is silently
    ignored. Tensor arguments must keep the capture-time shapes/dtypes.
    Call ``reset()`` before feeding inputs that differ structurally.
    """

    class GraphWrapper(cls):

        def __init__(self, config):
            super().__init__(config)
            # Flipped to True once forward() has captured the graph.
            self.built_graph = False

        def forward(self, *args, **kwargs):
            if not self.built_graph:
                # Keep references to the capture-time inputs: the graph
                # replays against these exact tensors, so subsequent calls
                # must copy new data into them rather than rebind.
                self.static_args = args
                self.static_kwargs = kwargs

                # Warmup on a side stream (per the PyTorch CUDA-graphs
                # docs) so lazy allocations / autotuning happen outside
                # the capture.
                s = torch.cuda.Stream()
                s.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(s):
                    super().forward(*self.static_args, **self.static_kwargs)
                torch.cuda.current_stream().wait_stream(s)

                self.graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(self.graph):
                    self.static_output = super().forward(
                        *self.static_args, **self.static_kwargs)

                self.built_graph = True
                glog.info("Built CUDA graph of model.")

            # Copy current tensor inputs into the captured static buffers
            # in place; non-tensor inputs keep their capture-time values.
            for static_arg, arg in zip(self.static_args, args):
                if isinstance(arg, torch.Tensor):
                    static_arg.copy_(arg)
            for key, value in kwargs.items():
                if isinstance(value, torch.Tensor):
                    self.static_kwargs[key].copy_(value)

            self.graph.replay()
            return self.static_output

        def reset(self):
            """Discard the captured graph so the next forward() re-captures.

            Also releases ``self.graph`` (and its captured memory pool)
            and ``self.static_output`` — previously only the static
            inputs were dropped, keeping graph memory alive until the
            next capture.
            """
            if self.built_graph:
                del self.static_args, self.static_kwargs
                del self.graph, self.static_output
                self.built_graph = False

    return GraphWrapper
|
|