# DNABERT_save/examples/load_model_test.py
# Uploaded by nancyH via huggingface_hub ("Upload folder using huggingface_hub", commit ab6c03c).
import os
import torch
from transformers import BertConfig, BertModel, BertForMaskedLM, DNATokenizer
import argparse
# Registry mapping a model-type key to its (config, model, tokenizer) classes.
# loadmodel() looks up the "dna" entry; other upstream entries are omitted here.
MODEL_CLASSES = dict(
    dna=(BertConfig, BertForMaskedLM, DNATokenizer),
)
def loadmodel(model_dir, model_type="dna", device=None):
    """Load a config, base BERT encoder and tokenizer from ``model_dir``.

    Args:
        model_dir: Path (or identifier accepted by ``from_pretrained``) that
            holds the config, weights and vocabulary. A value containing
            ``".ckpt"`` is loaded with ``from_tf=True``.
        model_type: Key into ``MODEL_CLASSES`` selecting the config / model /
            tokenizer classes. Defaults to ``"dna"``.
        device: Device to place the model on. ``None`` selects CUDA when
            available, otherwise CPU.

    Returns:
        ``(config, model, tokenizer)``. The model is in eval mode on ``device``.

    Raises:
        ValueError: If ``model_type`` is not a key of ``MODEL_CLASSES``.
    """
    try:
        config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(MODEL_CLASSES)}"
        ) from None
    print(f"Loading using: {config_class.__name__}, {model_class.__name__}, {tokenizer_class.__name__}")

    # 1. Configuration.
    config = config_class.from_pretrained(model_dir, cache_dir=None)

    # 2. Weights. Only embeddings / hidden states are needed, so load the bare
    # BertModel encoder instead of BertForMaskedLM (which adds the MLM head).
    base_model_class = BertModel if model_class is BertForMaskedLM else model_class
    model = base_model_class.from_pretrained(
        model_dir,
        from_tf=".ckpt" in model_dir,
        config=config,
        cache_dir=None,
    )

    # 3. Device placement; eval() switches off training-only behaviour such as dropout.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)
    model.to(device)
    model.eval()
    print(f"Model loaded onto device: {device}")

    # 4. Tokenizer: the vocabulary is read from model_dir itself.
    tokenizer = tokenizer_class.from_pretrained(model_dir)
    print(f"Tokenizer vocabulary size: {len(tokenizer)}")
    return config, model, tokenizer
# --- Main Call ---
def seq_to_kmers(seq, k=6):
    """Split ``seq`` into overlapping k-mers with stride 1.

    The smoke test below joins these with spaces before tokenizing.
    Sequences shorter than ``k`` yield an empty list.

    >>> seq_to_kmers("ACGTACG", k=6)
    ['ACGTAC', 'CGTACG']

    Raises:
        ValueError: If ``k`` is not a positive integer.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    return [seq[i:i + k] for i in range(len(seq) - k + 1)]


def main(argv=None):
    """Load the model given by ``--MODEL_DIR`` and run a small tokenization smoke test.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Smoke-test loading a DNABERT model.")
    # Callers typically forward a shell variable, e.g. --MODEL_DIR "$MODEL_DIR".
    parser.add_argument("--MODEL_DIR", type=str, required=True)
    args = parser.parse_args(argv)
    model_dir = args.MODEL_DIR

    # "/path/to/default" is a placeholder meaning "not configured": fail with a
    # nonzero exit status instead of printing and exiting successfully.
    if model_dir == "/path/to/default":
        parser.error("MODEL_DIR was not set (got the placeholder '/path/to/default').")

    _, model, tokenizer = loadmodel(model_dir)
    print("Model and Tokenizer loaded successfully.")
    embedding_layer = model.get_input_embeddings()
    print(embedding_layer.weight.shape)
    seq = "ACGTACGTACGT"
    tokens = tokenizer.tokenize(" ".join(seq_to_kmers(seq, k=6)))
    print(tokens[:10])


if __name__ == "__main__":
    main()