[CONFIG] - LaBSE model fine-tuned on ESCO data (100K samples) and Vacancy data - vector_match_202404221015_202404221015
da968a9
from sentence_transformers import SentenceTransformer

from prime_radiant.configs.pr_config import PRConfig
from prime_radiant.models import MODEL_CLASSES
from prime_radiant.models.pr_model import PRModel

class LaBSeModelConfig(PRConfig):
    """
    Configuration class for LaBSeModel.
    """

    model_type = "LaBSe"

    def __init__(self, **kwargs):
        # Validate and initialize the config
        super().__init__(**kwargs)

    def validate(self, **kwargs: dict):
        # Perform validation of the parent
        super().validate(**kwargs)
        # No further validation, since the base model does not take any
        # additional custom configuration

class LaBSeModel(PRModel):
    """
    LaBSeModel class for using LaBSe-based models.
    """

    config_class = LaBSeModelConfig

    def __init__(self, config: LaBSeModelConfig):
        # Initialize the PRModel superclass
        super().__init__(config=config)

        # Get the appropriate model class based on the configuration
        model_class = MODEL_CLASSES[config.pr_model_name]

        # Load the SentenceTransformer model for the configured checkpoint.
        # This assumes `model_class.checkpoint` is a model name or local path
        # that sentence-transformers can load directly.
        self.model = SentenceTransformer(model_class.checkpoint)

        # Set the maximum sequence length
        self.model.max_seq_length = 512

        # Get the tokenizer associated with the model
        self.tokenizer = self.model.tokenizer
    def predict(self, text: str):
        # Encode a single text into its embedding vector
        return self.model.encode(text)

    def predict_batch(self, text_list: list):
        # Encode a list of texts into embedding vectors
        return self.model.encode(text_list)
    def push_to_hub(self, only_base_model_flag=True, *args, **kwargs):
        """
        Push the base model to the hub.

        We push it in a different way so that the base model can be used to
        instantiate other models, such as ROC.
        """
        # Remove only_base_model from kwargs so it is not passed twice
        kwargs.pop("only_base_model", None)

        # Call the superclass method with the modified arguments
        return super().push_to_hub(*args, only_base_model=only_base_model_flag, **kwargs)
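

if __name__ == "__main__":
    # Usage sketch (illustrative only). The config key and value below are
    # hypothetical placeholders; the actual keys accepted by PRConfig and the
    # entries registered in MODEL_CLASSES live elsewhere in prime_radiant.
    config = LaBSeModelConfig(pr_model_name="labse")  # hypothetical model name
    model = LaBSeModel(config=config)

    # Encode a single text into one embedding vector
    vector = model.predict("machine learning engineer")
    print(vector.shape)

    # Encode a batch of texts into one embedding vector per text
    vectors = model.predict_batch(["data scientist", "software developer"])
    print(vectors.shape)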