|
|
|
import os

import fuson_plm.training.config as config

# Credentials and GPU visibility are set from config BEFORE importing torch —
# presumably so CUDA_VISIBLE_DEVICES is in place when CUDA initializes
# (TODO confirm this ordering is why the env lines sit between import groups).
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY

os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES



import torch

import tqdm

import numpy as np

import pandas as pd

import logging

from transformers import AutoModelForMaskedLM, AutoTokenizer

from fuson_plm.utils.logging import log_update, open_logfile, print_configpy

# NOTE(review): `get_dataloader` and `check_dataloaders` are imported twice —
# the caid.utils versions on the next line are immediately shadowed by the
# training.utils versions imported after it. Confirm which pair is intended.
from fuson_plm.benchmarking.caid.utils import DisorderDataset, get_dataloader, check_dataloaders

from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders

from fuson_plm.training.train import test
|
|
|
def load_esm2_maskedlm(esm_type, device=None):
    """
    Load a pretrained ESM-2 masked-LM checkpoint from the HuggingFace hub.

    Args:
        esm_type: checkpoint name, e.g. "esm2_t33_650M_UR50D".
        device: torch device to place the model on. When None, picks CUDA if
            available, else CPU (and prints the choice).

    Returns:
        (model, tokenizer, device) — model is on `device` and in eval mode.
    """
    # Suppress the load-time warnings transformers emits from modeling_utils.
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

    repo_id = f"facebook/{esm_type}"
    model = AutoModelForMaskedLM.from_pretrained(repo_id)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    model.to(device)
    model.eval()

    return model, tokenizer, device
|
|
|
|
|
def val(model, tokenizer, val_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Evaluate masked-LM loss on the validation set.

    Mirrors fuson_plm.training.train.test, but runs the validation loader and
    writes a one-row summary to {checkpoint_dir}/val_results.csv.
    (The original docstring said "Same method as val" — copy-paste slip;
    it mirrors `test`.)

    Args:
        model: masked-LM model (e.g. ESM-2); moved to `device` and put in eval mode.
        tokenizer: tokenizer exposing `mask_token_id`.
        val_loader: DataLoader yielding (inputs, prob) batches, where `inputs`
            is a dict of tensors and `prob` holds per-token masking probabilities.
        mask_percentage (float): fraction of tokens to mask per sequence.
        device: torch device to evaluate on.
        checkpoint_dir (str): directory where val_results.csv is written
            (assumed to already exist — TODO confirm callers create it).

    Returns:
        pd.DataFrame: the one-row stats frame that was written to disk.
    """
    model.to(device)
    model.eval()
    total_val_loss = 0
    total_weighted_val_loss = 0   # sum over batches of (batch loss * masked tokens)
    total_val_masked_tokens = 0

    with torch.no_grad():
        with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Val Batch', leave=True, position=0) as tbar:
            for batch_idx, (inputs, prob) in tbar:
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Apply probability-weighted masking before the forward pass.
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=mask_percentage)

                outputs = model(**masked_inputs)
                val_loss = outputs.loss

                # Count actually-masked positions so losses can be token-weighted.
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                total_val_loss += val_loss.item()
                total_weighted_val_loss += val_loss.item() * num_masked_tokens
                total_val_masked_tokens += num_masked_tokens

    # Guard both divisions: an empty loader or a batch stream with zero masked
    # tokens previously raised ZeroDivisionError; report NaN instead.
    n_val_batches = len(val_loader)
    avg_val_loss = total_val_loss / n_val_batches if n_val_batches > 0 else float('nan')
    if total_val_masked_tokens > 0:
        avg_weighted_val_loss = total_weighted_val_loss / total_val_masked_tokens
    else:
        avg_weighted_val_loss = float('nan')
    val_perplexity = np.exp(avg_weighted_val_loss)

    log_update(f"\nval results:\nTotal batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")

    val_stats_df = pd.DataFrame(data={
        "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
        "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
        "val_perplexity": [val_perplexity]
    })
    val_stats_df.to_csv(f"{checkpoint_dir}/val_results.csv",index=False)
    return val_stats_df
|
|
|
def main():
    """Evaluate pretrained ESM-2 650M on the configured val and test splits."""
    esm_name = "esm2_t33_650M_UR50D"
    model, tokenizer, device = load_esm2_maskedlm(esm_name)

    # Results land under a directory keyed by model + masking configuration.
    checkpoint_dir = f"checkpoints/{esm_name}_{config.PROBABILITY_TYPE}_mask{config.MASK_PERCENTAGE}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    with open_logfile(f"{checkpoint_dir}/evaluate_val_test_esm.txt"):
        print_configpy(config)

        # Both loaders share the same construction settings; only the split differs.
        loader_kwargs = dict(
            probability_type=config.PROBABILITY_TYPE,
            batch_size=config.BATCH_SIZE,
            max_length=config.MAX_LENGTH,
            shuffle=False,
        )

        val_loader = get_dataloader(config.VAL_PATH, tokenizer, **loader_kwargs)
        val(model, tokenizer, val_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)

        test_loader = get_dataloader(config.TEST_PATH, tokenizer, **loader_kwargs)
        test(model, tokenizer, test_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)
|
|
|
# Script entry point: run the val/test evaluation only when executed directly.
if __name__ == "__main__":

    main()