yairschiff
commited on
Upload Caduceus
Browse files- model.safetensors +2 -2
- modeling_caduceus.py +3 -3
- modeling_rcps.py +1 -1
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:16d3acb10a57ce482dd0799e59fd8616b83ce414143b04d68d95a9ab8cd8180e
|
3 |
+
size 2173880
|
modeling_caduceus.py
CHANGED
@@ -158,7 +158,7 @@ class CaduceusMixerModel(nn.Module):
|
|
158 |
self.rcps = config.rcps
|
159 |
self.residual_in_fp32 = config.residual_in_fp32
|
160 |
|
161 |
-
self.embeddings =
|
162 |
|
163 |
# Mamba changes the order of residual and layer norm:
|
164 |
# Instead of LN -> Attn / MLP -> Add, we do:
|
@@ -377,12 +377,12 @@ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
|
|
377 |
factory_kwargs = {"device": device, "dtype": dtype}
|
378 |
self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
|
379 |
if config.rcps:
|
380 |
-
self.lm_head =
|
381 |
complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
|
382 |
vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
|
383 |
true_dim=config.d_model,
|
384 |
dtype=dtype
|
385 |
-
)
|
386 |
else:
|
387 |
self.lm_head = nn.Linear(
|
388 |
config.d_model,
|
|
|
158 |
self.rcps = config.rcps
|
159 |
self.residual_in_fp32 = config.residual_in_fp32
|
160 |
|
161 |
+
self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)
|
162 |
|
163 |
# Mamba changes the order of residual and layer norm:
|
164 |
# Instead of LN -> Attn / MLP -> Add, we do:
|
|
|
377 |
factory_kwargs = {"device": device, "dtype": dtype}
|
378 |
self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
|
379 |
if config.rcps:
|
380 |
+
self.lm_head = RCPSLMHead(
|
381 |
complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
|
382 |
vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
|
383 |
true_dim=config.d_model,
|
384 |
dtype=dtype
|
385 |
+
)
|
386 |
else:
|
387 |
self.lm_head = nn.Linear(
|
388 |
config.d_model,
|
modeling_rcps.py
CHANGED
@@ -144,7 +144,7 @@ class RCPSMambaBlock(nn.Module):
|
|
144 |
super().__init__()
|
145 |
self.residual_in_fp32 = residual_in_fp32
|
146 |
self.fused_add_norm = fused_add_norm
|
147 |
-
self.mixer =
|
148 |
norm_f = norm_cls(dim)
|
149 |
self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
|
150 |
|
|
|
144 |
super().__init__()
|
145 |
self.residual_in_fp32 = residual_in_fp32
|
146 |
self.fused_add_norm = fused_add_norm
|
147 |
+
self.mixer = RCPSWrapper(mixer_cls(dim))
|
148 |
norm_f = norm_cls(dim)
|
149 |
self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
|
150 |
|