File size: 856 Bytes
0e956f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from transformers import PretrainedConfig
class SimilarityModelConfig(PretrainedConfig):
    """Configuration holding the three sub-configurations of a similarity model.

    The sub-configurations are stored verbatim as attributes; this class does
    not validate or interpret their contents.

    Args:
        embedding_model_config: Settings for the embedding model, or ``None``.
        score_model_config: Settings for the score model, or ``None``.
        weighting_function_config: Settings for the weighting function, or
            ``None``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # NOTE(review): 'roberta' is also the model_type of transformers'
    # built-in RobertaConfig; a custom config normally needs a unique
    # model_type to be registered with AutoConfig. Left unchanged because it
    # is the value recorded for this config class — confirm before renaming.
    model_type = 'roberta'

    def __init__(
        self,
        embedding_model_config=None,
        score_model_config=None,
        weighting_function_config=None,
        **kwargs,
    ):
        # Named parameters (instead of kwargs.get) make the contract explicit
        # and keep these keys out of the generic kwargs passed to the base
        # class, so each attribute is assigned exactly once.
        super().__init__(**kwargs)
        self.embedding_model_config = embedding_model_config
        self.score_model_config = score_model_config
        self.weighting_function_config = weighting_function_config
# Module-level SimilarityModelConfig instance for the "nama_base" setup.
nama_base = SimilarityModelConfig(
    embedding_model_config={
        "model_class": "roberta",
        "model_name": "roberta-base",
        "pooling": "pooler",
        "normalize": True,
        "d": 128,
        "prompt": "",
        "device": "cpu",
        "add_upper": True,
        "upper_case": False,
    },
    score_model_config={"alpha": 50},
    weighting_function_config={"weighting_exponent": 0.5},
    # Not one of the three sub-configs: passed through **kwargs to
    # PretrainedConfig.__init__.
    device="cpu",
)
|