Llama-3.1-70B / README.md
osanseviero's picture
Create README.md
8e8975b verified
|
raw
history blame
2.2 kB

This repository contains the base Llama 3.1 70B model. The weights use the same format as previous Llama models, but the model applies rotary position embeddings (RoPE) with per-frequency scaling, so the inference code must be changed accordingly.

Here is a short-term patch to `transformers` (`src/transformers/models/llama/modeling_llama.py`) that makes the model generate properly:

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5c0c57f3e..f94a4cb37 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -73,6 +73,29 @@ class LlamaRMSNorm(nn.Module):
 
 ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
 
+def apply_scaling(freqs: torch.Tensor):
+    # Values obtained from grid search
+    scale_factor = 8
+    low_freq_factor = 1
+    high_freq_factor = 4
+    old_context_len = 8192  # original llama3 length
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scale_factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
 
 class LlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
@@ -82,6 +105,7 @@ class LlamaRotaryEmbedding(nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+        inv_freq = apply_scaling(inv_freq)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         # For BC we register cos and sin cached
         self.max_seq_len_cached = max_position_embeddings