mp
commited on
Commit
•
19886f5
1
Parent(s):
0cb216c
adapted modeling code for HF with AutoModel and AutoConfig
Browse files- modeling_pharia.py +17 -1
modeling_pharia.py
CHANGED
@@ -13,6 +13,8 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
|
13 |
from transformers.modeling_outputs import (
|
14 |
BaseModelOutputWithPast,
|
15 |
)
|
|
|
|
|
16 |
from transformers.modeling_utils import PreTrainedModel
|
17 |
try:
|
18 |
from flash_attn.flash_attn_interface import flash_attn_func
|
@@ -22,6 +24,7 @@ except Exception as e:
|
|
22 |
)
|
23 |
flash_attn_func = None
|
24 |
|
|
|
25 |
|
26 |
class RotaryConfig():
|
27 |
def __init__(
|
@@ -1005,4 +1008,17 @@ class PhariaForEmbedding(PhariaPreTrainedModel):
|
|
1005 |
if input_was_string:
|
1006 |
all_embeddings = all_embeddings[0]
|
1007 |
|
1008 |
-
return all_embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
from transformers.modeling_outputs import (
|
14 |
BaseModelOutputWithPast,
|
15 |
)
|
16 |
+
from transformers import AutoConfig
|
17 |
+
from transformers import AutoModel
|
18 |
from transformers.modeling_utils import PreTrainedModel
|
19 |
try:
|
20 |
from flash_attn.flash_attn_interface import flash_attn_func
|
|
|
24 |
)
|
25 |
flash_attn_func = None
|
26 |
|
27 |
+
PHARIAEMBED_TYPE = "phariaembed"
|
28 |
|
29 |
class RotaryConfig():
|
30 |
def __init__(
|
|
|
1008 |
if input_was_string:
|
1009 |
all_embeddings = all_embeddings[0]
|
1010 |
|
1011 |
+
return all_embeddings
|
1012 |
+
|
1013 |
+
|
1014 |
+
# registration for Autoconfig and auto class
|
1015 |
+
|
1016 |
+
AutoConfig.register(PHARIAEMBED_TYPE, PhariaConfig)
|
1017 |
+
|
1018 |
+
PhariaConfig.register_for_auto_class()
|
1019 |
+
|
1020 |
+
# registration for AutoModel and auto class
|
1021 |
+
|
1022 |
+
AutoModel.register(PhariaConfig, PhariaForEmbedding)
|
1023 |
+
|
1024 |
+
PhariaForEmbedding.register_for_auto_class("AutoModel")
|