Update model.py
Browse files
model.py
CHANGED
@@ -98,7 +98,7 @@ ALL_LAYERNORM_LAYERS.append(BharataiRMSNorm)
|
|
98 |
|
99 |
|
100 |
class BharataiRotaryEmbedding(nn.Module):
|
101 |
-
def __init__(self, dim, max_position_embeddings=
|
102 |
super().__init__()
|
103 |
|
104 |
self.dim = dim
|
@@ -136,7 +136,7 @@ class BharataiRotaryEmbedding(nn.Module):
|
|
136 |
class BharataiLinearScalingRotaryEmbedding(BharataiRotaryEmbedding):
|
137 |
"""BharataiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
138 |
|
139 |
-
def __init__(self, dim, max_position_embeddings=
|
140 |
self.scaling_factor = scaling_factor
|
141 |
super().__init__(dim, max_position_embeddings, base, device)
|
142 |
|
@@ -155,7 +155,7 @@ class BharataiLinearScalingRotaryEmbedding(BharataiRotaryEmbedding):
|
|
155 |
class BharataiDynamicNTKScalingRotaryEmbedding(BharataiRotaryEmbedding):
|
156 |
"""BharataiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
157 |
|
158 |
-
def __init__(self, dim, max_position_embeddings=
|
159 |
self.scaling_factor = scaling_factor
|
160 |
super().__init__(dim, max_position_embeddings, base, device)
|
161 |
|
|
|
98 |
|
99 |
|
100 |
class BharataiRotaryEmbedding(nn.Module):
|
101 |
+
def __init__(self, dim, max_position_embeddings=16384, base=10000, device=None):
|
102 |
super().__init__()
|
103 |
|
104 |
self.dim = dim
|
|
|
136 |
class BharataiLinearScalingRotaryEmbedding(BharataiRotaryEmbedding):
|
137 |
"""BharataiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
138 |
|
139 |
+
def __init__(self, dim, max_position_embeddings=16384, base=10000, device=None, scaling_factor=1.0):
|
140 |
self.scaling_factor = scaling_factor
|
141 |
super().__init__(dim, max_position_embeddings, base, device)
|
142 |
|
|
|
155 |
class BharataiDynamicNTKScalingRotaryEmbedding(BharataiRotaryEmbedding):
|
156 |
"""BharataiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
157 |
|
158 |
+
def __init__(self, dim, max_position_embeddings=16384, base=10000, device=None, scaling_factor=1.0):
|
159 |
self.scaling_factor = scaling_factor
|
160 |
super().__init__(dim, max_position_embeddings, base, device)
|
161 |
|