import sys
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
import argparse
import torch
import re
import json
import warnings
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from torchmetrics.classification import Accuracy, Recall, Precision, MatthewsCorrCoef, AUROC, F1Score
from torchmetrics.classification import BinaryAccuracy, BinaryRecall, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryMatthewsCorrCoef
from torchmetrics.regression import SpearmanCorrCoef
from transformers import EsmTokenizer, EsmModel, BertModel, BertTokenizer
from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer, AutoModelForMaskedLM, AutoModel
from transformers import logging
from datasets import load_dataset
from torch.utils.data import DataLoader
# from utils.data_utils import BatchSampler
# from utils.metrics import MultilabelF1Max
# from models.adapter_mdoel import AdapterModel
from data.batch_sampler import BatchSampler
from training.metrics import MultilabelF1Max
from models.adapter_model import AdapterModel
from models.lora_model import LoraModel
from peft import PeftModel
from typing import Dict, Any, Union, Tuple
from data.dataloader import prepare_dataloaders
from datetime import datetime

# ignore warning information
logging.set_verbosity_error()
warnings.filterwarnings("ignore")

def evaluate(model, plm_model, metrics, dataloader, loss_function, device=None):
    total_loss = 0
    total_samples = len(dataloader.dataset)
    print(f"Total samples: {total_samples}")
    epoch_iterator = tqdm(dataloader)
    pred_labels = []
    for i, batch in enumerate(epoch_iterator, 1):
        for k, v in batch.items():
            batch[k] = v.to(device)
        label = batch["label"]
        logits = model(plm_model, batch)
        pred_labels.extend(logits.argmax(dim=1).cpu().numpy())
        # Compute the loss once per batch, then update every requested metric.
        if args.problem_type == 'regression' and args.num_labels == 1:
            loss = loss_function(logits.squeeze(), label.squeeze())
            for metric in metrics.values():
                metric(logits.squeeze(), label.squeeze())
        elif args.problem_type == 'multi_label_classification':
            loss = loss_function(logits, label.float())
            for metric in metrics.values():
                metric(logits, label)
        else:
            loss = loss_function(logits, label)
            for metric in metrics.values():
                metric(torch.argmax(logits, 1), label)
        total_loss += loss.item() * len(label)
        epoch_iterator.set_postfix(eval_loss=loss.item())
    epoch_loss = total_loss / total_samples
    # Store each metric as a one-element list so the dict converts cleanly to a DataFrame.
    for k, v in metrics.items():
        metrics[k] = [v.compute().item()]
        print(f"{k}: {metrics[k][0]}")
    metrics['loss'] = [epoch_loss]
    return metrics, pred_labels

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # model params
    parser.add_argument('--eval_method', type=str, default=None, help='evaluation method')
    parser.add_argument('--hidden_size', type=int, default=None, help='embedding hidden size of the model')
    parser.add_argument('--num_attention_head', type=int, default=8, help='number of attention heads')
    parser.add_argument('--attention_probs_dropout', type=float, default=0, help='attention probs dropout prob')
    parser.add_argument('--plm_model', type=str, default='facebook/esm2_t33_650M_UR50D', help='esm model name')
    parser.add_argument('--num_labels', type=int, default=2, help='number of labels')
    parser.add_argument('--pooling_method', type=str, default='mean', help='pooling method')
    parser.add_argument('--pooling_dropout', type=float, default=0.25, help='pooling dropout')
    # dataset
    parser.add_argument('--dataset', type=str, default=None, help='dataset name')
    parser.add_argument('--problem_type', type=str, default=None, help='problem type')
    parser.add_argument('--test_file', type=str, default=None, help='test file')
    parser.add_argument('--split', type=str, default=None, help='split name in Huggingface')
    parser.add_argument('--test_result_dir', type=str, default=None, help='test result directory')
    parser.add_argument('--metrics', type=str, default=None, help='computation metrics')
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers')
    parser.add_argument('--max_seq_len', type=int, default=None, help='max sequence length')
    parser.add_argument('--batch_size', type=int, default=None, help='batch size for fixed batch size')
    parser.add_argument('--batch_token', type=int, default=10000, help='max number of tokens per batch')
    parser.add_argument('--use_foldseek', action='store_true', help='use foldseek')
    parser.add_argument('--use_ss8', action='store_true', help='use ss8')
    # model path
    parser.add_argument('--output_model_name', type=str, default=None, help='model name')
    parser.add_argument('--output_root', default="result", help='root directory to save trained models')
    parser.add_argument('--output_dir', default=None, help='directory to save trained models')
    parser.add_argument('--model_path', default=None, help='model path directly')
    parser.add_argument('--structure_seq', type=str, default="", help='structure sequence')
    parser.add_argument('--training_method', type=str, default="freeze", help='training method')
    args = parser.parse_args()

    if 'foldseek_seq' in args.structure_seq:
        args.use_foldseek = True
        print("Enabled foldseek_seq based on structure_seq parameter")
    if 'ss8_seq' in args.structure_seq:
        args.use_ss8 = True
        print("Enabled ss8_seq based on structure_seq parameter")
device = "cuda" if torch.cuda.is_available() else "cpu" | |
os.makedirs(args.test_result_dir, exist_ok=True) | |
# build tokenizer and protein language model | |
if "esm" in args.plm_model: | |
tokenizer = EsmTokenizer.from_pretrained(args.plm_model) | |
plm_model = EsmModel.from_pretrained(args.plm_model) | |
args.hidden_size = plm_model.config.hidden_size | |
elif "bert" in args.plm_model: | |
tokenizer = BertTokenizer.from_pretrained(args.plm_model, do_lower_case=False) | |
plm_model = BertModel.from_pretrained(args.plm_model) | |
args.hidden_size = plm_model.config.hidden_size | |
elif "prot_t5" in args.plm_model: | |
tokenizer = T5Tokenizer.from_pretrained(args.plm_model, do_lower_case=False) | |
plm_model = T5EncoderModel.from_pretrained(args.plm_model) | |
args.hidden_size = plm_model.config.d_model | |
elif "ankh" in args.plm_model: | |
tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False) | |
plm_model = T5EncoderModel.from_pretrained(args.plm_model) | |
args.hidden_size = plm_model.config.d_model | |
elif "ProSST" in args.plm_model: | |
tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False) | |
plm_model = AutoModelForMaskedLM.from_pretrained(args.plm_model) | |
args.hidden_size = plm_model.config.hidden_size | |
elif "Prime" in args.plm_model: | |
tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False) | |
plm_model = AutoModelForMaskedLM.from_pretrained(args.plm_model) | |
args.hidden_size = plm_model.config.hidden_size | |
else: | |
tokenizer = AutoTokenizer.from_pretrained(args.plm_model) | |
plm_model = AutoModel.from_pretrained(args.plm_model).to(device).eval() | |
args.hidden_size = plm_model.config.hidden_size | |
    args.vocab_size = plm_model.config.vocab_size

    # Define metric configurations
    metric_configs = {
        'accuracy': {
            'binary': BinaryAccuracy,
            'multi': lambda: Accuracy(task="multiclass", num_classes=args.num_labels)
        },
        'recall': {
            'binary': BinaryRecall,
            'multi': lambda: Recall(task="multiclass", num_classes=args.num_labels)
        },
        'precision': {
            'binary': BinaryPrecision,
            'multi': lambda: Precision(task="multiclass", num_classes=args.num_labels)
        },
        'f1': {
            'binary': BinaryF1Score,
            'multi': lambda: F1Score(task="multiclass", num_classes=args.num_labels)
        },
        'mcc': {
            'binary': BinaryMatthewsCorrCoef,
            'multi': lambda: MatthewsCorrCoef(task="multiclass", num_classes=args.num_labels)
        },
        'auroc': {
            'binary': BinaryAUROC,
            'multi': lambda: AUROC(task="multiclass", num_classes=args.num_labels)
        },
        'f1_max': {
            'any': lambda: MultilabelF1Max(num_labels=args.num_labels)
        },
        'spearman_corr': {
            'any': SpearmanCorrCoef
        }
    }

    # Initialize metrics dictionary
    metrics_dict = {}
    args.metrics = args.metrics.split(',')
    # Create metrics based on configurations
    for metric_name in args.metrics:
        if metric_name not in metric_configs:
            raise ValueError(f"Invalid metric: {metric_name}")
        config = metric_configs[metric_name]
        if 'any' in config:
            metrics_dict[metric_name] = config['any']()
        else:
            metrics_dict[metric_name] = (config['binary']() if args.num_labels == 2
                                         else config['multi']())
        # Move metric to device
        metrics_dict[metric_name].to(device)
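
    # Note: --metrics is a comma-separated subset of the keys in metric_configs above,
    # e.g. "accuracy,recall,precision,f1,mcc,auroc" for classification tasks; "f1_max"
    # is typically used for multi-label problems and "spearman_corr" for regression.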

    # load adapter model
    print("---------- Load Model ----------")
    if args.eval_method in ["full", "ses-adapter", "freeze"]:
        model = AdapterModel(args)
    elif args.eval_method in ['plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
        model = LoraModel(args)
    else:
        raise ValueError(f"Invalid eval_method: {args.eval_method}")

    if args.model_path is not None:
        model_path = args.model_path
    else:
        model_path = f"{args.output_root}/{args.output_dir}/{args.output_model_name}"

    if args.eval_method == "full":
        model_weights = torch.load(model_path)
        model.load_state_dict(model_weights['model_state_dict'])
        plm_model.load_state_dict(model_weights['plm_state_dict'])
    else:
        model.load_state_dict(torch.load(model_path))
    model.to(device).eval()

    # For PEFT-based methods the adapter weights are stored next to the classifier
    # checkpoint with the method name appended to the file stem; merge them into the base model.
    peft_suffixes = {
        'plm-lora': '_lora',
        'plm-qlora': '_qlora',
        'plm-dora': '_dora',
        'plm-adalora': '_adalora',
        'plm-ia3': '_ia3',
    }
    if args.eval_method in peft_suffixes:
        peft_path = model_path.replace(".pt", peft_suffixes[args.eval_method])
        plm_model = PeftModel.from_pretrained(plm_model, peft_path)
        plm_model = plm_model.merge_and_unload()
    plm_model.to(device).eval()

    def param_num(model):
        total = sum(param.numel() for param in model.parameters() if param.requires_grad)
        num_M = total / 1e6
        if num_M >= 1000:
            return "Number of parameters: %.2fB" % (num_M / 1e3)
        else:
            return "Number of parameters: %.2fM" % num_M
    print(param_num(model))

    def collate_fn(examples):
        aa_seqs, labels = [], []
        if args.use_foldseek:
            foldseek_seqs = []
        if args.use_ss8:
            ss8_seqs = []
        prosst_stru_tokens = [] if "ProSST" in args.plm_model else None

        for e in examples:
            aa_seq = e["aa_seq"]
            if args.use_foldseek:
                foldseek_seq = e["foldseek_seq"]
            if args.use_ss8:
                ss8_seq = e["ss8_seq"]

            if "ProSST" in args.plm_model and "prosst_stru_token" in e:
                stru_token = e["prosst_stru_token"]
                if isinstance(stru_token, str):
                    seq_clean = stru_token.strip("[]").replace(" ", "")
                    tokens = list(map(int, seq_clean.split(','))) if seq_clean else []
                elif isinstance(stru_token, (list, tuple)):
                    tokens = [int(x) for x in stru_token]
                else:
                    tokens = []
                prosst_stru_tokens.append(torch.tensor(tokens))

            if 'prot_bert' in args.plm_model or "prot_t5" in args.plm_model:
                aa_seq = " ".join(list(aa_seq))
                aa_seq = re.sub(r"[UZOB]", "X", aa_seq)
                if args.use_foldseek:
                    foldseek_seq = " ".join(list(foldseek_seq))
                if args.use_ss8:
                    ss8_seq = " ".join(list(ss8_seq))
            elif 'ankh' in args.plm_model:
                aa_seq = list(aa_seq)
                if args.use_foldseek:
                    foldseek_seq = list(foldseek_seq)
                if args.use_ss8:
                    ss8_seq = list(ss8_seq)

            aa_seqs.append(aa_seq)
            if args.use_foldseek:
                foldseek_seqs.append(foldseek_seq)
            if args.use_ss8:
                ss8_seqs.append(ss8_seq)
            labels.append(e["label"])

        if 'ankh' in args.plm_model:
            aa_inputs = tokenizer.batch_encode_plus(aa_seqs, add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")
            if args.use_foldseek:
                foldseek_input_ids = tokenizer.batch_encode_plus(foldseek_seqs, add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")["input_ids"]
            if args.use_ss8:
                ss8_input_ids = tokenizer.batch_encode_plus(ss8_seqs, add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")["input_ids"]
        else:
            aa_inputs = tokenizer(aa_seqs, return_tensors="pt", padding=True, truncation=True)
            if args.use_foldseek:
                foldseek_input_ids = tokenizer(foldseek_seqs, return_tensors="pt", padding=True, truncation=True)["input_ids"]
            if args.use_ss8:
                ss8_input_ids = tokenizer(ss8_seqs, return_tensors="pt", padding=True, truncation=True)["input_ids"]

        aa_input_ids = aa_inputs["input_ids"]
        attention_mask = aa_inputs["attention_mask"]
        if args.problem_type == 'regression':
            labels = torch.as_tensor(labels, dtype=torch.float)
        else:
            labels = torch.as_tensor(labels, dtype=torch.long)

        data_dict = {
            "aa_seq_input_ids": aa_input_ids,
            "aa_seq_attention_mask": attention_mask,
            "label": labels
        }
if "ProSST" in args.plm_model and prosst_stru_tokens: | |
aa_max_length = len(aa_input_ids[0]) | |
padded_tokens = [] | |
for tokens in prosst_stru_tokens: | |
if tokens is None or len(tokens) == 0: | |
padded_tokens.append([0] * aa_max_length) | |
else: | |
struct_sequence = tokens.tolist() | |
padded_tokens.append(struct_sequence + [0] * (aa_max_length - len(struct_sequence))) | |
data_dict["aa_seq_stru_tokens"] = torch.tensor(padded_tokens, dtype=torch.long) | |
if args.use_foldseek: | |
data_dict["foldseek_seq_input_ids"] = foldseek_input_ids | |
if args.use_ss8: | |
data_dict["ss8_seq_input_ids"] = ss8_input_ids | |
return data_dict | |
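
    # The collated batch always carries "aa_seq_input_ids", "aa_seq_attention_mask" and
    # "label"; "aa_seq_stru_tokens", "foldseek_seq_input_ids" and "ss8_seq_input_ids" are
    # only added for ProSST / foldseek / ss8 runs. model(plm_model, batch) in evaluate()
    # is expected to consume exactly these keys.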

    # Pick a loss that matches the problem type so the reported eval loss is meaningful.
    if args.problem_type == 'regression':
        loss_function = nn.MSELoss()
    elif args.problem_type == 'multi_label_classification':
        loss_function = nn.BCEWithLogitsLoss()
    else:
        loss_function = nn.CrossEntropyLoss()

    def process_data_line(data):
        if args.problem_type == 'multi_label_classification':
            label_list = data['label'].split(',')
            data['label'] = [int(l) for l in label_list]
            binary_list = [0] * args.num_labels
            for index in data['label']:
                binary_list[index] = 1
            data['label'] = binary_list
        if args.max_seq_len is not None:
            data["aa_seq"] = data["aa_seq"][:args.max_seq_len]
            if args.use_foldseek:
                data["foldseek_seq"] = data["foldseek_seq"][:args.max_seq_len]
            if args.use_ss8:
                data["ss8_seq"] = data["ss8_seq"][:args.max_seq_len]
            # For ProSST models the structure tokens also need truncating.
            if "ProSST" in args.plm_model and "prosst_stru_token" in data:
                # The structure tokens may come as a string or as a list.
                if isinstance(data["prosst_stru_token"], str):
                    pass
                elif isinstance(data["prosst_stru_token"], (list, tuple)):
                    data["prosst_stru_token"] = data["prosst_stru_token"][:args.max_seq_len]
            token_num = min(len(data["aa_seq"]), args.max_seq_len)
        else:
            token_num = len(data["aa_seq"])
        return data, token_num

    # process dataset from json file
    def process_dataset_from_json(file):
        dataset, token_nums = [], []
        for l in open(file):
            data = json.loads(l)
            data, token_num = process_data_line(data)
            dataset.append(data)
            token_nums.append(token_num)
        return dataset, token_nums

    # process dataset from list
    def process_dataset_from_list(data_list):
        dataset, token_nums = [], []
        for l in data_list:
            data, token_num = process_data_line(l)
            dataset.append(data)
            token_nums.append(token_num)
        return dataset, token_nums

    if args.test_file.endswith('json'):
        test_dataset, test_token_num = process_dataset_from_json(args.test_file)
    elif args.test_file.endswith('csv'):
        test_dataset, test_token_num = process_dataset_from_list(load_dataset("csv", data_files=args.test_file)['train'])
        if args.test_result_dir:
            test_result_df = pd.read_csv(args.test_file)
    elif '/' in args.test_file:  # Hugging Face dataset (csv-backed only for now)
        raw_dataset = load_dataset(args.test_file)
        # Prefer the requested split, then fall back to common split names.
        if args.split and args.split in raw_dataset:
            split = args.split
        elif 'test' in raw_dataset:
            split = 'test'
        elif 'validation' in raw_dataset:
            split = 'validation'
        elif 'train' in raw_dataset:
            split = 'train'
        else:
            split = list(raw_dataset.keys())[0]
        test_dataset, test_token_num = process_dataset_from_list(raw_dataset[split])
        if args.test_result_dir:
            test_result_df = pd.DataFrame(raw_dataset[split])
    else:
        raise ValueError("Invalid file format")

    if args.batch_size is None:
        if args.batch_token is None:
            raise ValueError("batch_size or batch_token must be specified")
        test_loader = DataLoader(
            test_dataset,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
            batch_sampler=BatchSampler(test_token_num, args.batch_token, False)
        )
    else:
        test_loader = DataLoader(
            test_dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
            shuffle=False
        )

    print("---------- Start Eval ----------")
    with torch.no_grad():
        metric, pred_labels = evaluate(model, plm_model, metrics_dict, test_loader, loss_function, device)

    if args.test_result_dir:
        pd.DataFrame(metric).to_csv(f"{args.test_result_dir}/evaluation_metrics.csv", index=False)
        test_result_df["pred_label"] = pred_labels
        test_result_df.to_csv(f"{args.test_result_dir}/evaluation_result.csv", index=False)