mp committed on
Commit
19886f5
1 Parent(s): 0cb216c

adapted modeling code for HF with AutoModel and AutoConfig

Browse files
Files changed (1) hide show
  1. modeling_pharia.py +17 -1
modeling_pharia.py CHANGED
@@ -13,6 +13,8 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
13
  from transformers.modeling_outputs import (
14
  BaseModelOutputWithPast,
15
  )
 
 
16
  from transformers.modeling_utils import PreTrainedModel
17
  try:
18
  from flash_attn.flash_attn_interface import flash_attn_func
@@ -22,6 +24,7 @@ except Exception as e:
22
  )
23
  flash_attn_func = None
24
 
 
25
 
26
  class RotaryConfig():
27
  def __init__(
@@ -1005,4 +1008,17 @@ class PhariaForEmbedding(PhariaPreTrainedModel):
1005
  if input_was_string:
1006
  all_embeddings = all_embeddings[0]
1007
 
1008
- return all_embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  from transformers.modeling_outputs import (
14
  BaseModelOutputWithPast,
15
  )
16
+ from transformers import AutoConfig
17
+ from transformers import AutoModel
18
  from transformers.modeling_utils import PreTrainedModel
19
  try:
20
  from flash_attn.flash_attn_interface import flash_attn_func
 
24
  )
25
  flash_attn_func = None
26
 
27
+ PHARIAEMBED_TYPE = "phariaembed"
28
 
29
  class RotaryConfig():
30
  def __init__(
 
1008
  if input_was_string:
1009
  all_embeddings = all_embeddings[0]
1010
 
1011
+ return all_embeddings
1012
+
1013
+
1014
+ # registration for Autoconfig and auto class
1015
+
1016
+ AutoConfig.register(PHARIAEMBED_TYPE, PhariaConfig)
1017
+
1018
+ PhariaConfig.register_for_auto_class()
1019
+
1020
+ # registration for AutoModel and auto class
1021
+
1022
+ AutoModel.register(PhariaConfig, PhariaForEmbedding)
1023
+
1024
+ PhariaForEmbedding.register_for_auto_class("AutoModel")