x54-729
commited on
Commit
•
9b8d955
1
Parent(s):
e2c47cf
update modeling file to newest
Browse files- configuration_internlm2.py +1 -1
- modeling_internlm2.py +10 -2
configuration_internlm2.py
CHANGED
@@ -177,4 +177,4 @@ class InternLM2Config(PretrainedConfig):
|
|
177 |
raise ValueError(
|
178 |
f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
|
179 |
f"of type {type(rope_scaling_factor)}"
|
180 |
-
)
|
|
|
177 |
raise ValueError(
|
178 |
f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
|
179 |
f"of type {type(rope_scaling_factor)}"
|
180 |
+
)
|
modeling_internlm2.py
CHANGED
@@ -13,7 +13,7 @@
|
|
13 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
# See the License for the specific language governing permissions and
|
15 |
# limitations under the License.
|
16 |
-
"""PyTorch InternLM2
|
17 |
import math
|
18 |
import queue
|
19 |
import threading
|
@@ -59,6 +59,10 @@ try:
|
|
59 |
except:
|
60 |
pass
|
61 |
|
|
|
|
|
|
|
|
|
62 |
|
63 |
logger = logging.get_logger(__name__)
|
64 |
|
@@ -1093,7 +1097,11 @@ class InternLM2Model(InternLM2PreTrainedModel):
|
|
1093 |
else:
|
1094 |
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
1095 |
if sequence_length != 1:
|
1096 |
-
|
|
|
|
|
|
|
|
|
1097 |
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
1098 |
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
1099 |
if attention_mask is not None:
|
|
|
13 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
# See the License for the specific language governing permissions and
|
15 |
# limitations under the License.
|
16 |
+
"""PyTorch InternLM2 model."""
|
17 |
import math
|
18 |
import queue
|
19 |
import threading
|
|
|
59 |
except:
|
60 |
pass
|
61 |
|
62 |
+
try:
|
63 |
+
support_bf16_triu = torch.__version__ >= "2.1.0"
|
64 |
+
except Exception:
|
65 |
+
support_bf16_triu = False
|
66 |
|
67 |
logger = logging.get_logger(__name__)
|
68 |
|
|
|
1097 |
else:
|
1098 |
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
1099 |
if sequence_length != 1:
|
1100 |
+
if support_bf16_triu or dtype == torch.float32:
|
1101 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
1102 |
+
else:
|
1103 |
+
triu_mask = torch.triu(torch.ones(causal_mask.size(), device=device), diagonal=1).bool()
|
1104 |
+
causal_mask.masked_fill_(~triu_mask, 0)
|
1105 |
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
1106 |
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
1107 |
if attention_mask is not None:
|