Update modeling_chatglm.py

modeling_chatglm.py (+2 -3)
@@ -36,7 +36,7 @@ if sys.platform != 'darwin':
     torch._C._jit_override_can_fuse_on_gpu(True)
 
 logger = logging.get_logger(__name__)
-
+
 _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
 _CONFIG_FOR_DOC = "ChatGLMConfig"
 
@@ -157,7 +157,6 @@ class RotaryEmbedding(nn.Module):
         )
 
 
-@torch.jit.script
 def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     # x: [sq, b, np, hn]
     sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
@@ -1297,4 +1296,4 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
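The substantive change here is dropping the @torch.jit.script decorator from apply_rotary_pos_emb; the changed lines at 39 and 1296/1299 appear to be whitespace-only. As a decorator, torch.jit.script compiles the function to TorchScript at import time, so a scripting failure would break loading the module; without it, the function runs as ordinary eager-mode Python with the same numerics. Below is a minimal sketch of that difference, not this repository's code: rotate_pairs is a hypothetical stand-in that follows the same [sq, b, np, hn] layout as the real function.

# Minimal sketch, assuming a hypothetical rotary-style helper (not the file's
# actual apply_rotary_pos_emb): shows that scripting the function, as the old
# @torch.jit.script decorator did, produces the same results as eager mode.
import torch

def rotate_pairs(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]; view the head dim as (hn // 2) pairs and rotate each pair.
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    xshaped = x.reshape(sq, b, np, hn // 2, 2)
    out = torch.stack(
        [
            xshaped[..., 0] * cos - xshaped[..., 1] * sin,
            xshaped[..., 1] * cos + xshaped[..., 0] * sin,
        ],
        -1,
    )
    return out.flatten(3)

scripted = torch.jit.script(rotate_pairs)  # what the old decorator form produced at import

x = torch.randn(4, 2, 8, 16)               # [sq, b, np, hn], as in the diff's comment
cos, sin = torch.rand(8), torch.rand(8)
print(torch.allclose(rotate_pairs(x, cos, sin), scripted(x, cos, sin)))  # True

A common reason to remove such a decorator is that TorchScript compilation can fail on some platforms or PyTorch builds; with the plain eager function, nothing is compiled when the module is imported.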