yairschiff committed
Enable mambav2 compat

modeling_caduceus.py CHANGED (+33 -21)
@@ -2,21 +2,29 @@
 
 """
 
+import inspect
 import math
 from functools import partial
 from typing import Optional, Tuple, Union
 
 import torch
-from mamba_ssm.modules.mamba_simple import Mamba, Block
+from mamba_ssm.modules.mamba_simple import Mamba
+try:
+    from mamba_ssm.modules.mamba_simple import Block  # Legacy mambav1 file structure
+except ImportError:
+    from mamba_ssm.modules.block import Block  # mambav2 file structure
 from torch import nn
 from torch.nn import functional as F
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
 
 try:
-    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
+    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn  # Legacy mambav1 file structure
 except ImportError:
-    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+    try:
+        from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn  # mambav2 file structure
+    except ImportError:
+        RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
 
 from .configuration_caduceus import CaduceusConfig
 from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
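The nested try/except imports above cover the mamba-ssm 2.x reorganization noted in the in-line comments (Block moved to mamba_ssm/modules/block.py, ops/triton/layernorm.py renamed to layer_norm.py). A small, hypothetical sketch for checking which layout an environment actually provides (the helper name is illustrative; assumes only the standard library plus an optional mamba-ssm install):

import importlib.util

def detect_mamba_layout():
    """Return which mamba-ssm module layout is importable, if any."""
    try:
        # mamba-ssm >= 2.x ships Block in mamba_ssm.modules.block;
        # 1.x keeps it in mamba_ssm.modules.mamba_simple.
        if importlib.util.find_spec("mamba_ssm.modules.block") is not None:
            return "mambav2"
        if importlib.util.find_spec("mamba_ssm.modules.mamba_simple") is not None:
            return "mambav1"
    except ModuleNotFoundError:
        pass
    return "mamba-ssm not installed"

print(detect_mamba_layout())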
@@ -54,13 +62,24 @@ def create_block(
         nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
     )
     block_cls = RCPSMambaBlock if rcps else Block
-    block = block_cls(
-        d_model,
-        mixer_cls,
-        norm_cls=norm_cls,
-        fused_add_norm=fused_add_norm,
-        residual_in_fp32=residual_in_fp32,
-    )
+    # mambav2 compatibility
+    if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
+        block = block_cls(
+            d_model,
+            mixer_cls,
+            mlp_cls=nn.Identity,
+            norm_cls=norm_cls,
+            fused_add_norm=fused_add_norm,
+            residual_in_fp32=residual_in_fp32,
+        )
+    else:
+        block = block_cls(
+            d_model,
+            mixer_cls,
+            norm_cls=norm_cls,
+            fused_add_norm=fused_add_norm,
+            residual_in_fp32=residual_in_fp32,
+        )
     block.layer_idx = layer_idx
     return block
 
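The signature probe in this hunk passes mlp_cls=nn.Identity only when Block.__init__ advertises that parameter, so the mambav2 block keeps the mambav1 behavior of having no MLP sublayer. A toy sketch of the same inspect.signature check, with made-up classes standing in for the two Block variants (not the real mamba-ssm classes):

import inspect

class BlockV1:
    # Stand-in for the mamba-ssm 1.x Block signature (toy class).
    def __init__(self, dim, mixer_cls, norm_cls=None, fused_add_norm=False, residual_in_fp32=False):
        pass

class BlockV2:
    # Stand-in for the mamba-ssm 2.x Block signature, which adds mlp_cls.
    def __init__(self, dim, mixer_cls, mlp_cls=None, norm_cls=None, fused_add_norm=False, residual_in_fp32=False):
        pass

def accepts_kwarg(cls, name):
    # Same probe as create_block: does __init__ declare this parameter?
    return name in inspect.signature(cls.__init__).parameters

assert accepts_kwarg(BlockV2, "mlp_cls")
assert not accepts_kwarg(BlockV1, "mlp_cls")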
@@ -264,15 +283,15 @@ def cross_entropy(logits, y, ignore_index=-100):
     return F.cross_entropy(logits, y, ignore_index=ignore_index)
 
 
-def weighted_cross_entropy(logits, y,
+def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
     """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
     logits = logits.view(-1, logits.shape[-1])
     y = y.view(-1)
     ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
-
-
+    loss_weights = loss_weights.view(-1)
+    loss_weights[y == ignore_index] = 0.0
     # TODO: Follows GPN implementation, but should we remove weight normalization?
-    return (ce * (
+    return (ce * (loss_weights / loss_weights.sum())).sum()
 
 
 class CaduceusPreTrainedModel(PreTrainedModel):
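A minimal usage sketch for the weighted_cross_entropy added above (illustrative shapes and weights; assumes the function from this file is in scope, and note that it zeroes the weights of ignored positions in place):

import torch

batch, seq_len, vocab = 2, 4, 10
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
labels[0, 0] = -100                       # position to ignore
loss_weights = torch.ones(batch, seq_len)
loss_weights[:, 1] = 0.1                  # e.g. discount repeated base pairs

loss = weighted_cross_entropy(logits, labels, loss_weights)
print(loss.item())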
@@ -497,12 +516,6 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
 
         # Initialize weights and apply final processing
         self.post_init()
-        self.init_scorer()
-
-    def init_scorer(self, initializer_range=0.02):
-        initializer_range = self.config.initializer_cfg.get("initializer_range", initializer_range) \
-            if self.config.initializer_cfg is not None else initializer_range
-        self.score.weight.data.normal_(std=initializer_range)
 
     def get_input_embeddings(self):
         return self.caduceus.backbone.embeddings.word_embeddings
@@ -530,7 +543,6 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,
     ) -> Union[Tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
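With these fallbacks, the remote-code file should import under either mamba-ssm 1.x or 2.x. A hedged loading sketch; the checkpoint id below is a placeholder, not a real Hub id:

from transformers import AutoModelForMaskedLM, AutoTokenizer

model_id = "<caduceus-checkpoint-id>"  # placeholder: substitute a real Caduceus model id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)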