# LaBSe_GCP / estimator.py
# billyhines's picture
# [CONFIG] - LaBSE model fine-tuned on ESCO data, 100K sample and Vacancy data - vector_match_202404221015_202404221015
# da968a9 verified
from prime_radiant.configs.pr_config import PRConfig
from prime_radiant.models import MODEL_CLASSES
from prime_radiant.models.pr_model import PRModel
class LaBSeModelConfig(PRConfig):
    """
    Configuration for ``LaBSeModel``.

    This subclass adds no options of its own: both construction and
    validation are forwarded unchanged to ``PRConfig``. Its only
    distinguishing feature is the ``model_type`` identifier below.
    """

    # Identifier string for this configuration type.
    model_type = "LaBSe"

    def __init__(self, **kwargs):
        """Forward every keyword argument to ``PRConfig.__init__``."""
        super().__init__(**kwargs)

    def validate(self, **kwargs: dict):
        """
        Run the parent's validation with the given keyword arguments.

        Nothing is checked beyond what ``PRConfig.validate`` does, because
        this config introduces no custom settings.
        """
        # NOTE(review): `**kwargs: dict` declares every keyword *value* to be
        # a dict, which is probably not what was meant; left unchanged so the
        # signature stays identical.
        super().validate(**kwargs)
class LaBSeModel(PRModel):
    """
    ``PRModel`` wrapper for LaBSe-based encoders.

    The underlying encoder is looked up in ``MODEL_CLASSES`` using
    ``config.pr_model_name`` and stored on ``self.model``; its tokenizer is
    exposed as ``self.tokenizer``.
    """

    config_class = LaBSeModelConfig

    def __init__(self, config: LaBSeModelConfig):
        """
        Build the wrapper and load the encoder named by the configuration.

        Args:
            config: Configuration whose ``pr_model_name`` selects the entry
                of ``MODEL_CLASSES`` to load.
        """
        super().__init__(config=config)

        # The registry entry supplies both the loader and the checkpoint.
        registry_entry = MODEL_CLASSES[config.pr_model_name]

        # TODO (carried over from the original author): load the actual model
        # exposed by the registry entry rather than going through `.sbert()`,
        # which the original comment reports as not existing.
        # Presumably this yields a SentenceTransformer - confirm.
        self.model = registry_entry.model.sbert(registry_entry.checkpoint)
        encoder = self.model

        # Maximum sequence length is fixed at 512.
        encoder.max_seq_length = 512

        # Expose the encoder's own tokenizer on the wrapper.
        self.tokenizer = encoder.tokenizer

    def predict(self, text: str):
        """Encode a single text with the underlying model and return the result."""
        encoded = self.model.encode(text)
        return encoded

    def predict_batch(self, text_list: list):
        """Encode a list of texts with the underlying model and return the result."""
        encoded = self.model.encode(text_list)
        return encoded

    def push_to_hub(self, only_base_model_flag=True, *args, **kwargs):
        """
        Push the base model to the hub.

        A different push path is used here so that the base model can later
        be used to instantiate other models like ROC.

        Args:
            only_base_model_flag: Value forwarded to the parent as
                ``only_base_model``. It is the first positional parameter,
                so a leading positional argument binds to it.
            *args: Extra positional arguments passed through to the parent.
            **kwargs: Extra keyword arguments passed through to the parent.
                An ``only_base_model`` entry is discarded, not honoured.

        Returns:
            Whatever the parent's ``push_to_hub`` returns.
        """
        # Discard a caller-supplied `only_base_model` so it cannot collide
        # with the explicit keyword passed below.
        kwargs.pop("only_base_model", None)

        return super().push_to_hub(
            *args, only_base_model=only_base_model_flag, **kwargs
        )