| import os |
| import torch |
| from transformers import BertConfig, BertModel, BertForMaskedLM, DNATokenizer |
| import argparse |
|
|
| |
# Registry mapping a model-family key to its (config, model, tokenizer)
# classes. "dna" selects the DNABERT masked-LM stack; DNATokenizer is the
# k-mer tokenizer shipped with the DNABERT fork of transformers.
MODEL_CLASSES = {
    "dna": (BertConfig, BertForMaskedLM, DNATokenizer),
}
|
|
def loadmodel(model_dir):
    """Load a pretrained DNA BERT config, model, and tokenizer.

    Args:
        model_dir: Path to a pretrained-model directory. If the path
            contains ".ckpt" it is treated as a TensorFlow checkpoint.

    Returns:
        A ``(config, model, tokenizer)`` tuple. The model is the bare
        encoder (``BertModel``), moved to GPU when available and set to
        eval mode.
    """
    config_class, model_class, tokenizer_class = MODEL_CLASSES["dna"]
    print(f"Loading using: {config_class.__name__}, {model_class.__name__}, {tokenizer_class.__name__}")

    config = config_class.from_pretrained(model_dir, cache_dir=None)

    # Swap the masked-LM wrapper for the bare encoder so downstream code can
    # read embeddings/hidden states directly instead of MLM logits.
    base_model_class = BertModel if model_class == BertForMaskedLM else model_class

    model = base_model_class.from_pretrained(
        model_dir,
        from_tf=".ckpt" in model_dir,  # `in` already yields a bool
        config=config,
        cache_dir=None,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # inference mode: disables dropout etc.
    print(f"Model loaded onto device: {device}")

    tokenizer = tokenizer_class.from_pretrained(model_dir)
    print(f"Tokenizer vocabulary size: {len(tokenizer)}")

    return config, model, tokenizer
|
|
| |
| |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--MODEL_DIR",
        type=str,
        required=True,
        help="Directory containing the pretrained DNABERT model files.",
    )
    args = parser.parse_args()

    model_dir = args.MODEL_DIR

    # argparse already enforces --MODEL_DIR (required=True); this guard only
    # rejects the documentation placeholder path.
    if model_dir != "/path/to/default":
        config, model, tokenizer = loadmodel(model_dir)
        print("Model and Tokenizer loaded successfully.")

        embedding_layer = model.get_input_embeddings()
        print(embedding_layer.weight.shape)

        # Demo: tokenize a sequence as overlapping 6-mers (DNABERT's
        # expected whitespace-separated k-mer input format).
        seq = "ACGTACGTACGT"
        kmers = " ".join(seq[i:i + 6] for i in range(len(seq) - 5))
        tokens = tokenizer.tokenize(kmers)
        print(tokens[:10])
    else:
        # Original message wrongly blamed an environment variable; MODEL_DIR
        # is a required command-line argument.
        print("Error: --MODEL_DIR points at the placeholder path; supply a real model directory.")
|
|
|
|