nama-test4 / configuration.py
beny2000's picture
Upload model
0e956f2
raw
history blame contribute delete
856 Bytes
from transformers import PretrainedConfig
class SimilarityModelConfig(PretrainedConfig):
    """Configuration object bundling three sub-configurations.

    All keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``. Afterwards the three sub-configuration
    attributes below are set from the keyword argument of the same name;
    any that was not supplied is stored as ``None``.

    Attributes:
        embedding_model_config: Value passed as ``embedding_model_config``.
        score_model_config: Value passed as ``score_model_config``.
        weighting_function_config: Value passed as
            ``weighting_function_config``.
    """

    # NOTE(review): 'roberta' is also the model_type used by transformers'
    # built-in RoBERTa configuration - confirm that reusing it is intended.
    model_type = 'roberta'

    # Sub-configuration keyword names, in the order they are assigned.
    _SECTION_NAMES = (
        "embedding_model_config",
        "score_model_config",
        "weighting_function_config",
    )

    def __init__(self, **kwargs):
        """Initialise the base config, then store the sub-configurations."""
        super().__init__(**kwargs)
        # Assigned after the base initialiser so each attribute always
        # exists, falling back to None when the keyword was omitted.
        for section in (
            "embedding_model_config",
            "score_model_config",
            "weighting_function_config",
        ):
            setattr(self, section, kwargs.get(section))
# Module-level configuration instance built on roberta-base.
nama_base = SimilarityModelConfig(
    # Settings for the embedding model.
    embedding_model_config={
        "model_class": "roberta",
        "model_name": "roberta-base",
        "pooling": "pooler",
        "normalize": True,
        "d": 128,  # presumably the embedding dimension - TODO confirm
        "prompt": "",
        "device": "cpu",
        "add_upper": True,
        "upper_case": False,
    },
    # Settings for the score model.
    score_model_config={"alpha": 50},
    # Settings for the weighting function.
    weighting_function_config={"weighting_exponent": 0.5},
    # Passed through as an extra keyword argument to the config.
    device="cpu",
)