Markus28 commited on
Commit
e209593
1 Parent(s): 2b23340

Try to subclass PretrainedModel

Browse files
Files changed (1) hide show
  1. modeling_bert.py +1 -2
modeling_bert.py CHANGED
@@ -22,7 +22,7 @@ import torch
22
  import torch.nn as nn
23
  import torch.nn.functional as F
24
  from einops import rearrange
25
- from transformers import PretrainedModel
26
  from .configuration_bert import JinaBertConfig
27
  from transformers.models.bert.modeling_bert import (
28
  BaseModelOutputWithPoolingAndCrossAttentions,
@@ -39,7 +39,6 @@ from flash_attn.modules.block import Block
39
  from flash_attn.modules.embedding import BertEmbeddings
40
  from flash_attn.modules.mha import MHA
41
  from flash_attn.modules.mlp import FusedMLP, Mlp
42
- from flash_attn.utils.pretrained import state_dict_from_pretrained
43
 
44
  try:
45
  from flash_attn.ops.fused_dense import FusedDense
 
22
  import torch.nn as nn
23
  import torch.nn.functional as F
24
  from einops import rearrange
25
+ from transformers.modeling_utils import PretrainedModel
26
  from .configuration_bert import JinaBertConfig
27
  from transformers.models.bert.modeling_bert import (
28
  BaseModelOutputWithPoolingAndCrossAttentions,
 
39
  from flash_attn.modules.embedding import BertEmbeddings
40
  from flash_attn.modules.mha import MHA
41
  from flash_attn.modules.mlp import FusedMLP, Mlp
 
42
 
43
  try:
44
  from flash_attn.ops.fused_dense import FusedDense