# quip-sharp/model/graph_wrapper.py
import torch
import glog
def get_graph_wrapper(cls):

    class GraphWrapper(cls):

        def __init__(self, config):
            super(GraphWrapper, self).__init__(config)
            self.built_graph = False

        def forward(self, *args, **kwargs):
            if not self.built_graph:
                # Keep references to the input tensors: CUDA graph replay
                # reads from these fixed buffers, so later calls copy new
                # data into them instead of passing fresh tensors.
                self.static_args = args
                self.static_kwargs = kwargs

                # Warm-up pass on a side stream, so lazy one-time setup
                # (e.g. cuBLAS workspaces) happens outside graph capture.
                s = torch.cuda.Stream()
                s.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(s):
                    super(GraphWrapper, self).forward(
                        *self.static_args, **self.static_kwargs)
                torch.cuda.current_stream().wait_stream(s)

                # Capture one forward pass into a CUDA graph; the output
                # is also a fixed buffer that each replay overwrites.
                self.graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(self.graph):
                    self.static_output = super(GraphWrapper, self).forward(
                        *self.static_args, **self.static_kwargs)
                self.built_graph = True
                glog.info('Built CUDA graph of model.')

            # Copy the new inputs into the captured static buffers.
            # These two loops take < 1e-4 seconds for llama2.
            for i in range(len(args)):
                if isinstance(args[i], torch.Tensor):
                    self.static_args[i].copy_(args[i])
            for kw in kwargs:
                if isinstance(kwargs[kw], torch.Tensor):
                    self.static_kwargs[kw].copy_(kwargs[kw])

            # Replay the captured kernels; results land in static_output.
            self.graph.replay()
            return self.static_output

        def reset(self):
            # Drop the captured buffers so the next forward call rebuilds
            # the graph (needed if the input shapes change).
            if self.built_graph:
                del self.static_args, self.static_kwargs
                self.built_graph = False

    return GraphWrapper
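
# Usage sketch (illustrative, not part of the original file): wrap a
# HuggingFace-style causal LM class so its forward pass replays through a
# captured CUDA graph. The class and config below are assumptions.
#
#     from transformers import LlamaForCausalLM
#
#     GraphLlama = get_graph_wrapper(LlamaForCausalLM)
#     model = GraphLlama(config).cuda().eval()
#     out = model(input_ids=ids)        # first call captures the graph
#     out = model(input_ids=next_ids)   # later calls replay it
#
# CUDA graphs fix tensor shapes and addresses at capture time: every call
# after the first must pass tensors with the same shapes as the originals,
# and reset() must be called before changing shapes so the graph rebuilds.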