File size: 22,800 Bytes
e0be88b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 |
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
from transformers import LlamaConfig
from transformers.testing_utils import is_torch_available, require_torch, torch_device
if is_torch_available():
import torch
from transformers import ROPE_INIT_FUNCTIONS
from transformers.modeling_rope_utils import rope_config_validation
@require_torch
class RopeTest(unittest.TestCase):
    """Tests for `rope_config_validation` and for the frequency initializers in `ROPE_INIT_FUNCTIONS`.

    Every test starts from a default `LlamaConfig`. The numerical tests compare against hard-coded snapshots,
    which are only meaningful for the default hyperparameters verified by `_make_checked_config`.
    """

    def _make_checked_config(self):
        """Return a fresh `LlamaConfig` after asserting the defaults the numerical snapshots depend on.

        Input sanity checks: if any of these values change, the expected outputs will also change.
        """
        config = LlamaConfig()
        self.assertIsNone(config.rope_scaling)
        self.assertEqual(config.hidden_size, 4096)
        self.assertEqual(config.num_attention_heads, 32)
        self.assertEqual(config.rope_theta, 10000.0)
        self.assertFalse(hasattr(config, "partial_rotary_factor"))
        return config

    def _assert_config_and_kwargs_agree(self, rope_type, **scaling_kwargs):
        """Assert that calling a RoPE init function with `config=` or with explicit kwargs gives the same result.

        Args:
            rope_type: key into `ROPE_INIT_FUNCTIONS` selecting the function under test.
            **scaling_kwargs: extra scaling parameters (e.g. `factor`). When given, they are stored in
                `config.rope_scaling` for the config-based call and forwarded verbatim to the kwargs-based call.
                When empty, `config.rope_scaling` is left at its default.
        """
        config = LlamaConfig()
        if scaling_kwargs:
            config.rope_scaling = {"rope_type": rope_type, **scaling_kwargs}
        rope_kwargs = {
            "rope_type": rope_type,
            "dim": config.hidden_size // config.num_attention_heads,
            "max_position_embeddings": config.max_position_embeddings,
            "base": config.rope_theta,
            **scaling_kwargs,
        }

        rope_fn = ROPE_INIT_FUNCTIONS[rope_type]
        # Index 0 of the returned pair is the inverse-frequency tensor (index 1 is the attention scale)
        config_freqs = rope_fn(config=config, device=torch_device)[0]
        kwargs_freqs = rope_fn(**rope_kwargs, device=torch_device)[0]
        torch.testing.assert_close(config_freqs, kwargs_freqs)

    def test_rope_validation(self):
        """Check which `rope_scaling` dicts `rope_config_validation` accepts, rejects, or warns about."""
        config = LlamaConfig()
        non_default_rope_types = [rope_type for rope_type in ROPE_INIT_FUNCTIONS if rope_type != "default"]

        # The base config is always valid (default RoPE)
        rope_config_validation(config)

        # If we explicitly set the other RoPE types, then validation should fail
        for rope_type in non_default_rope_types:
            with self.subTest(rope_type=rope_type):
                config.rope_scaling = {"rope_type": rope_type}
                with self.assertRaises(KeyError):
                    rope_config_validation(config)

        # Parameters are exclusive to their own RoPE type, and should raise an exception if incorrectly passed
        valid_param_mapping = {
            "factor": ["linear", "dynamic", "yarn", "longrope"],
            "attention_factor": ["yarn", "longrope"],
            "beta_fast": ["yarn"],
            "beta_slow": ["yarn"],
            "short_factor": ["longrope"],
            "long_factor": ["longrope"],
        }
        for rope_type in non_default_rope_types:
            for param, valid_rope_types in valid_param_mapping.items():
                if rope_type in valid_rope_types:
                    continue
                # Set `param` with a dummy value -- we want to test the dict key
                config.rope_scaling = {"rope_type": rope_type, param: True}
                with self.subTest(rope_type=rope_type, param=param), self.assertRaises(KeyError):
                    rope_config_validation(config)

        # Any other parameters passed to RoPE will raise a warning that a particular key is not used
        # But sometimes we can have model-specific RoPE kwargs and bypass warning with `ignore_keys`
        model_specific_kwarg = "mrope_sections"  # e.g. in Qwen2-VL
        config.rope_scaling = {"rope_type": "default", model_specific_kwarg: True}
        rope_config_validation(config, ignore_keys={model_specific_kwarg})
        with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
            rope_config_validation(config)
        self.assertEqual(len(logs.output), 1)
        self.assertIn(model_specific_kwarg, logs.output[0])

    def test_default_rope_function_bc(self):
        """Default RoPE: the config-based and kwargs-based calls must return the same inverse frequencies.

        (The `bc` suffix presumably stands for backwards compatibility of the kwargs-based call.)
        """
        self._assert_config_and_kwargs_agree("default")

    def test_linear_rope_function_bc(self):
        """Linear RoPE: the config-based and kwargs-based calls must return the same inverse frequencies."""
        self._assert_config_and_kwargs_agree("linear", factor=10.0)

    def test_dynamic_rope_function_bc(self):
        """Dynamic RoPE: the config-based and kwargs-based calls must return the same inverse frequencies."""
        self._assert_config_and_kwargs_agree("dynamic", factor=10.0)

    def test_default_rope_numerically(self):
        """Numerical snapshot of the default RoPE inverse frequencies.

        Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails,
        then multiple RoPE strategies will fail.
        """
        # fmt: off
        EXPECTED_INV_FREQ = torch.tensor(
            [
                1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
                4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
                1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
                7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
                3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
                1.3335e-02, 1.1548e-02, 1.0000e-02, 8.6596e-03, 7.4989e-03, 6.4938e-03,
                5.6234e-03, 4.8697e-03, 4.2170e-03, 3.6517e-03, 3.1623e-03, 2.7384e-03,
                2.3714e-03, 2.0535e-03, 1.7783e-03, 1.5399e-03, 1.3335e-03, 1.1548e-03,
                1.0000e-03, 8.6596e-04, 7.4989e-04, 6.4938e-04, 5.6234e-04, 4.8697e-04,
                4.2170e-04, 3.6517e-04, 3.1623e-04, 2.7384e-04, 2.3714e-04, 2.0535e-04,
                1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04
            ], device=torch_device
        )
        # fmt: on

        config = self._make_checked_config()

        rope_fn = ROPE_INIT_FUNCTIONS["default"]
        inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
        self.assertEqual(attention_scale, 1.0)  # attention scale is always 1 for default RoPE
        torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)

    def test_linear_rope_numerically(self):
        """Linear RoPE must equal the default inverse frequencies divided by `factor`.

        This is a linear scaling strategy, the **frequencies** are scaled linearly with respect to the default
        frequencies (= the inverse frequencies are scaled **inversely**).
        """
        config = LlamaConfig()
        default_rope_fn = ROPE_INIT_FUNCTIONS["default"]
        default_inv_freq, _ = default_rope_fn(config=config, device=torch_device)

        rope_fn = ROPE_INIT_FUNCTIONS["linear"]
        for factor in (2.0, 10.0, 20.0):
            config.rope_scaling = {"rope_type": "linear", "factor": factor}
            inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
            self.assertEqual(attention_scale, 1.0)  # attention scale is always 1 for linear RoPE
            torch.testing.assert_close(inv_freq, default_inv_freq / factor)

    def test_dynamic_rope_numerically(self):
        """Dynamic RoPE: unchanged for short sequences, snapshot-checked for a long `seq_len`."""
        # fmt: off
        EXPECTED_INV_FREQ = torch.tensor(
            [
                1.0000e+00, 8.0931e-01, 6.5498e-01, 5.3008e-01, 4.2900e-01, 3.4720e-01,
                2.8099e-01, 2.2741e-01, 1.8404e-01, 1.4895e-01, 1.2055e-01, 9.7558e-02,
                7.8955e-02, 6.3899e-02, 5.1714e-02, 4.1853e-02, 3.3872e-02, 2.7413e-02,
                2.2185e-02, 1.7955e-02, 1.4531e-02, 1.1760e-02, 9.5176e-03, 7.7027e-03,
                6.2339e-03, 5.0451e-03, 4.0831e-03, 3.3045e-03, 2.6744e-03, 2.1644e-03,
                1.7517e-03, 1.4176e-03, 1.1473e-03, 9.2852e-04, 7.5146e-04, 6.0817e-04,
                4.9220e-04, 3.9834e-04, 3.2238e-04, 2.6091e-04, 2.1115e-04, 1.7089e-04,
                1.3830e-04, 1.1193e-04, 9.0585e-05, 7.3312e-05, 5.9332e-05, 4.8018e-05,
                3.8861e-05, 3.1451e-05, 2.5453e-05, 2.0600e-05, 1.6672e-05, 1.3492e-05,
                1.0920e-05, 8.8374e-06, 7.1522e-06, 5.7883e-06, 4.6845e-06, 3.7912e-06,
                3.0683e-06, 2.4832e-06, 2.0097e-06, 1.6265e-06
            ], device=torch_device
        )
        # fmt: on

        config = self._make_checked_config()
        default_inv_freq, _ = ROPE_INIT_FUNCTIONS["default"](config=config, device=torch_device)

        # Check 1: this is a dynamic scaling strategy, it will not scale unless we provide `seq_len` larger than the
        # model's original training sequence length
        rope_fn = ROPE_INIT_FUNCTIONS["dynamic"]
        for factor in (2.0, 10.0, 20.0):
            config.rope_scaling = {"rope_type": "dynamic", "factor": factor}
            inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
            self.assertEqual(attention_scale, 1.0)  # attention scale is always 1 for dynamic RoPE
            torch.testing.assert_close(inv_freq, default_inv_freq)

            inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1)
            torch.testing.assert_close(inv_freq, default_inv_freq)

        # Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies
        # will scale up (i.e., the inverse frequencies will scale down).
        factor = 10.0
        config.rope_scaling = {"rope_type": "dynamic", "factor": factor}
        inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=16384)
        with self.assertRaises(AssertionError):  # It is NOT a linear factor
            torch.testing.assert_close(inv_freq, default_inv_freq / factor)
        torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)

    def test_yarn_rope_numerically(self):
        """YaRN RoPE: default/explicit attention factor, bounds on the scaled frequencies, and a snapshot."""
        # fmt: off
        EXPECTED_INV_FREQ = torch.tensor(
            [
                1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
                4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
                1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.3479e-02,
                6.9590e-02, 5.7925e-02, 4.8136e-02, 3.9931e-02, 3.3061e-02, 2.7315e-02,
                2.2515e-02, 1.8512e-02, 1.5177e-02, 1.2403e-02, 1.0101e-02, 8.1924e-03,
                6.6143e-03, 5.3120e-03, 4.2400e-03, 3.3599e-03, 2.6396e-03, 2.0520e-03,
                1.5746e-03, 1.1882e-03, 8.7713e-04, 6.2810e-04, 4.3007e-04, 2.7384e-04,
                2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04,
                1.0000e-04, 8.6596e-05, 7.4989e-05, 6.4938e-05, 5.6234e-05, 4.8697e-05,
                4.2170e-05, 3.6517e-05, 3.1623e-05, 2.7384e-05, 2.3714e-05, 2.0535e-05,
                1.7783e-05, 1.5399e-05, 1.3335e-05, 1.1548e-05
            ], device=torch_device
        )
        # fmt: on

        config = self._make_checked_config()
        default_inv_freq, _ = ROPE_INIT_FUNCTIONS["default"](config=config, device=torch_device)

        # Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
        # `0.1 * math.log(factor) + 1.0`
        rope_fn = ROPE_INIT_FUNCTIONS["yarn"]
        for factor in (2.0, 10.0, 20.0):
            config.rope_scaling = {"rope_type": "yarn", "factor": factor}
            _, attention_scale = rope_fn(config=config, device=torch_device)
            self.assertEqual(attention_scale, 0.1 * math.log(factor) + 1.0)

            config.rope_scaling = {"rope_type": "yarn", "factor": factor, "attention_factor": 0.5}
            _, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
            self.assertEqual(attention_scale, 0.5)

        # Check 2: based on `beta_fast` and `beta_slow`, the frequencies will be scaled between 1 and `factor`.
        # Increasing `beta_fast` will make RoPE more interpolative (apply scaling), and the other way around.
        # `beta_slow` behaves the opposite way. Remember: `beta_fast` > `beta_slow`
        # (note: adds a margin to the test for numerical stability)
        factor = 10.0
        margin = 1e-8
        config.rope_scaling = {"rope_type": "yarn", "factor": factor, "beta_fast": 32, "beta_slow": 1}
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        lower_bound = default_inv_freq / factor - margin
        upper_bound = default_inv_freq + margin
        is_bounded_by_factor = (lower_bound <= inv_freq) & (inv_freq <= upper_bound)
        self.assertTrue(is_bounded_by_factor.all().item())

        # super high beta_fast = interpolation (i.e. scaling) in all but the first inverse frequency. The last ~20
        # values (empirically checked for `beta_fast` = 1000) should be very close to linear scaling
        config.rope_scaling = {"rope_type": "yarn", "factor": factor, "beta_fast": 1000, "beta_slow": 1}
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        is_interpolating = inv_freq < (default_inv_freq + margin)
        self.assertFalse(is_interpolating[0].item())
        self.assertTrue(is_interpolating[1:].all().item())
        torch.testing.assert_close(inv_freq[-20:], default_inv_freq[-20:] / factor)

        # Check 3: numerical snapshot to avoid regressions
        config.rope_scaling = {"rope_type": "yarn", "factor": factor, "beta_fast": 32, "beta_slow": 1}
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)

    def test_longrope_rope_numerically(self):
        """LongRoPE: default/explicit attention factor, and short vs. long factor selection by `seq_len`."""
        config = self._make_checked_config()

        # longrope applies scaling on EACH inv frequency, `short_factor` or `long_factor`, depending on the seq_len
        dim = config.hidden_size // config.num_attention_heads
        short_factor = [2.0] * (dim // 2)  # scaling applied when seq_len <= max_position_embeddings
        long_factor = torch.ones(dim // 2).cumsum(0).tolist()  # scaling applied when seq_len > max_position_embeddings

        def longrope_scaling(factor, **extra):
            """Build a longrope `rope_scaling` dict sharing the short/long factors defined above."""
            return {
                "rope_type": "longrope",
                "factor": factor,
                "short_factor": short_factor,
                "long_factor": long_factor,
                **extra,
            }

        default_inv_freq, _ = ROPE_INIT_FUNCTIONS["default"](config=config, device=torch_device)

        # Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
        # `math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))`
        rope_fn = ROPE_INIT_FUNCTIONS["longrope"]
        max_position_embeddings = config.max_position_embeddings
        for factor in (2.0, 10.0, 20.0):
            config.rope_scaling = longrope_scaling(factor)
            _, attention_scale = rope_fn(config=config, device=torch_device)
            self.assertEqual(attention_scale, math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)))

            config.rope_scaling = longrope_scaling(factor, attention_factor=0.5)
            _, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
            self.assertEqual(attention_scale, 0.5)

            config.rope_scaling = longrope_scaling(factor)
            self.assertIsNone(config.rope_scaling.get("attention_factor"))
            # Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'" is not raised.
            rope_config_validation(config)

        # Check 2: seq_len == 0 -> short factor is applied to the default frequencies
        config.rope_scaling = longrope_scaling(1.0)
        inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=0)
        torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(short_factor).to(torch_device))

        # Check 3: seq_len > max_position_embeddings -> long factor is applied to the default frequencies
        inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=config.max_position_embeddings + 1)
        torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(long_factor).to(torch_device))

    def test_llama3_rope_numerically(self):
        """Llama 3 RoPE: attention scale of 1, bounds on the scaled frequencies, and a snapshot."""
        # fmt: off
        EXPECTED_INV_FREQ = torch.tensor(
            [
                1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
                4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
                1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
                7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
                3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
                1.3335e-02, 1.0730e-02, 7.7785e-03, 5.6009e-03, 3.9991e-03, 2.8248e-03,
                1.9675e-03, 1.3449e-03, 8.9549e-04, 5.7363e-04, 3.4539e-04, 2.7384e-04,
                2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04,
                1.0000e-04, 8.6596e-05, 7.4989e-05, 6.4938e-05, 5.6234e-05, 4.8697e-05,
                4.2170e-05, 3.6517e-05, 3.1623e-05, 2.7384e-05, 2.3714e-05, 2.0535e-05,
                1.7783e-05, 1.5399e-05, 1.3335e-05, 1.1548e-05
            ], device=torch_device
        )
        # fmt: on

        def llama3_scaling(factor, high_freq_factor=4):
            """Build a llama3 `rope_scaling` dict; only `factor` and `high_freq_factor` vary in this test."""
            return {
                "rope_type": "llama3",
                "factor": factor,
                "original_max_position_embeddings": 2048,
                "low_freq_factor": 1,
                "high_freq_factor": high_freq_factor,
            }

        config = self._make_checked_config()
        default_inv_freq, _ = ROPE_INIT_FUNCTIONS["default"](config=config, device=torch_device)

        # Check 1: `attention_factor` is always 1
        rope_fn = ROPE_INIT_FUNCTIONS["llama3"]
        for factor in (2.0, 10.0, 20.0):
            config.rope_scaling = llama3_scaling(factor)
            _, attention_scale = rope_fn(config=config, device=torch_device)
            self.assertEqual(attention_scale, 1.0)

        # Check 2: based on `low_freq_factor` and `high_freq_factor`, the frequencies will be scaled between 1 and
        # `factor` (similar to yarn). Low frequencies get scaled by `factor`, high frequencies see no change, medium
        # frequencies are scaled by a value in between. Changing `low_freq_factor` and `high_freq_factor` changes what
        # is considered low, medium, and high frequencies.
        factor = 10.0
        config.rope_scaling = llama3_scaling(factor)
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        is_bounded_by_factor = (default_inv_freq / factor <= inv_freq) & (inv_freq <= default_inv_freq)
        self.assertTrue(is_bounded_by_factor.all().item())

        # if we change `high_freq_factor` to a very high value, none is considered high-frequency -> ALL values will be
        # scaled
        config.rope_scaling = llama3_scaling(factor, high_freq_factor=1000)
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        is_scaled = inv_freq < default_inv_freq
        self.assertTrue(is_scaled.all().item())

        # Check 3: numerical snapshot to avoid regressions
        config.rope_scaling = llama3_scaling(factor)
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
|