from prime_radiant.configs.pr_config import PRConfig
from prime_radiant.models import MODEL_CLASSES
from prime_radiant.models.pr_model import PRModel


class LaBSeModelConfig(PRConfig):
    """Configuration class for :class:`LaBSeModel`.

    Adds no fields beyond :class:`PRConfig`; it exists so the model can
    declare a dedicated ``config_class`` and hook into config validation.
    """

    model_type = "LaBSe"

    def __init__(self, **kwargs):
        # Delegate validation and initialization to the base config.
        super().__init__(**kwargs)

    def validate(self, **kwargs: dict) -> None:
        """Run the parent validation.

        The base model takes no additional custom configuration, so no
        further checks are needed here.
        """
        super().validate(**kwargs)


class LaBSeModel(PRModel):
    """Wrapper for using LaBSe-based sentence-embedding models."""

    config_class = LaBSeModelConfig

    def __init__(self, config: LaBSeModelConfig):
        # Initialize the PRModel superclass.
        super().__init__(config=config)

        # Resolve the concrete model class from the registry by name.
        model_class = MODEL_CLASSES[config.pr_model_name]

        # Load the SentenceTransformer model.
        # TODO(review): `.sbert()` reportedly does not exist on the model
        # attribute — this should load the actual model from `model_class`
        # directly; confirm against the `prime_radiant.models` registry.
        self.model = model_class.model.sbert(model_class.checkpoint)

        # Cap input length; longer sequences are truncated by the encoder.
        self.model.max_seq_length = 512

        # Expose the tokenizer that ships with the loaded model.
        self.tokenizer = self.model.tokenizer

    def predict(self, text: str):
        """Encode a single text and return its embedding."""
        return self.model.encode(text)

    def predict_batch(self, text_list: list):
        """Encode a list of texts and return their embeddings."""
        return self.model.encode(text_list)

    def push_to_hub(self, only_base_model_flag=True, *args, **kwargs):
        """Push the base model to the hub.

        We need to use a different way of pushing it so that we can use
        the base model to instantiate other models like ROC.  Any caller-
        supplied ``only_base_model`` kwarg is discarded and replaced by
        ``only_base_model_flag`` so the override always controls it.
        """
        # Drop any caller-provided value; the flag parameter wins.
        kwargs.pop("only_base_model", None)

        # Forward positionals first, then the controlled keyword — same
        # call semantics as before, but unambiguous argument ordering.
        return super().push_to_hub(*args, only_base_model=only_base_model_flag, **kwargs)