Upload modeling_xtrimopglm.py
Browse files- modeling_xtrimopglm.py +1 -8
modeling_xtrimopglm.py
CHANGED
@@ -12,7 +12,7 @@ import random
|
|
12 |
import numpy as np
|
13 |
from tqdm.auto import tqdm
|
14 |
|
15 |
-
import torch
|
16 |
import torch.utils.checkpoint
|
17 |
import torch.nn.functional as F
|
18 |
from torch import nn
|
@@ -37,13 +37,6 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaL
|
|
37 |
from .configuration_xtrimopglm import xTrimoPGLMConfig
|
38 |
from .quantization import quantize
|
39 |
|
40 |
-
def get_checkpoint_fn():
|
41 |
-
if deepspeed.checkpointing.is_configured():
|
42 |
-
checkpoint = deepspeed.checkpointing.checkpoint
|
43 |
-
else:
|
44 |
-
checkpoint = torch.utils.checkpoint.checkpoint
|
45 |
-
return checkpoint
|
46 |
-
|
47 |
# flags required to enable jit fusion kernels
|
48 |
|
49 |
if sys.platform != 'darwin':
|
|
|
12 |
import numpy as np
|
13 |
from tqdm.auto import tqdm
|
14 |
|
15 |
+
import torch
|
16 |
import torch.utils.checkpoint
|
17 |
import torch.nn.functional as F
|
18 |
from torch import nn
|
|
|
37 |
from .configuration_xtrimopglm import xTrimoPGLMConfig
|
38 |
from .quantization import quantize
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
# flags required to enable jit fusion kernels
|
41 |
|
42 |
if sys.platform != 'darwin':
|