from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast


def get_tokenlizer(text_encoder_type):
    """Build a tokenizer for the given text encoder.

    Args:
        text_encoder_type: Either the encoder name as a ``str``, or a
            config-like object that carries the name under
            ``text_encoder_type`` -- as an attribute, or as a key reachable
            through a ``.get`` method (e.g. a ``dict``).

    Returns:
        The object returned by ``AutoTokenizer.from_pretrained`` for the
        resolved encoder name.

    Raises:
        ValueError: If ``text_encoder_type`` is not a ``str`` and no truthy
            encoder name can be resolved from it.
    """
    if not isinstance(text_encoder_type, str):
        if hasattr(text_encoder_type, "text_encoder_type"):
            text_encoder_type = text_encoder_type.text_encoder_type
        else:
            # Only call ``.get`` when the object actually provides it, so
            # that unsupported inputs (None, int, list, ...) end in the
            # ValueError below rather than an unrelated AttributeError.
            getter = getattr(text_encoder_type, "get", None)
            resolved = getter("text_encoder_type", False) if callable(getter) else False
            if not resolved:
                raise ValueError(
                    "Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
                )
            text_encoder_type = resolved
    print("final text_encoder_type: {}".format(text_encoder_type))

    tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
    return tokenizer


def get_pretrained_language_model(text_encoder_type):
    """Load the pretrained language model matching ``text_encoder_type``.

    Args:
        text_encoder_type: Encoder name; ``"bert-base-uncased"`` and
            ``"roberta-base"`` are the only values handled.

    Returns:
        ``BertModel.from_pretrained(...)`` for ``"bert-base-uncased"`` or
        ``RobertaModel.from_pretrained(...)`` for ``"roberta-base"``.

    Raises:
        ValueError: If ``text_encoder_type`` is any other value.
    """
    if text_encoder_type == "bert-base-uncased":
        return BertModel.from_pretrained(text_encoder_type)
    if text_encoder_type == "roberta-base":
        return RobertaModel.from_pretrained(text_encoder_type)

    raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))