# RLAnOxPeptide / feature_extract.py
# chshan's picture
# Update feature_extract.py
# 02b6e86 verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import numpy as np
import random
import pandas as pd
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler # 导入 RobustScaler
import torch
from transformers import T5EncoderModel, T5Tokenizer
class ProtT5Model:
    """Load a ProtT5 encoder and turn amino-acid sequences into token embeddings.

    The tokenizer and encoder are first loaded strictly from local files. If
    that raises ``OSError`` a second attempt is made without
    ``local_files_only`` (which may download from the Hugging Face Hub). When
    ``finetuned_model_file`` names an existing file, its state dict is applied
    on top of the encoder with ``strict=False``.
    """

    # Width of the all-zero placeholder embedding returned when a sequence
    # cannot be encoded. The code assumes a ProtT5 hidden size of 1024.
    HIDDEN_DIM = 1024

    def __init__(self, model_path, finetuned_model_file=None):
        """Load tokenizer, encoder and (optionally) fine-tuned weights.

        Args:
            model_path: Local directory or Hugging Face model identifier.
            finetuned_model_file: Optional path to a state-dict file. Ignored
                when ``None`` or when the file does not exist; a load failure
                is printed and otherwise ignored (best effort).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            self.tokenizer = T5Tokenizer.from_pretrained(model_path, do_lower_case=False, local_files_only=True)
            self.model = T5EncoderModel.from_pretrained(model_path, local_files_only=True)
        except OSError:
            print(f"警告: 无法从本地路径 {model_path} 加载ProtT5模型/分词器。尝试从HuggingFace Hub下载(如果transformers配置允许)。")
            hub_name = self._hub_fallback_name(model_path)
            self.tokenizer = T5Tokenizer.from_pretrained(hub_name, do_lower_case=False)
            self.model = T5EncoderModel.from_pretrained(hub_name)

        if finetuned_model_file is not None and os.path.exists(finetuned_model_file):
            try:
                # NOTE(security): torch.load unpickles the file; only load
                # checkpoints that come from a trusted source.
                state_dict = torch.load(finetuned_model_file, map_location=self.device)
                missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
                print(f"加载微调权重 {finetuned_model_file}:缺失键 {missing_keys}, 意外键 {unexpected_keys}")
            except Exception as e:
                # Deliberate best effort: keep the base weights if the
                # fine-tuned checkpoint cannot be applied.
                print(f"加载微调权重 {finetuned_model_file} 失败: {e}")

        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _hub_fallback_name(model_path):
        """Return the identifier used for the second (non-local) load attempt.

        A ``namespace/name`` string that is not an existing directory is
        returned unchanged, so the namespace is no longer stripped from a Hub
        repository ID. Anything else falls back to the last path component,
        with trailing slashes ignored so the result is never empty for a path
        such as ``"models/prot_t5/"``.
        """
        trimmed = model_path.rstrip("/")
        looks_like_repo_id = (
            not os.path.isdir(model_path)
            and trimmed.count("/") == 1
            and not trimmed.startswith((".", "/", "~"))
        )
        if looks_like_repo_id:
            return model_path
        return trimmed.split("/")[-1] if "/" in trimmed else trimmed

    def _zero_embedding(self):
        """Return the ``(1, HIDDEN_DIM)`` float32 zero placeholder embedding."""
        return np.zeros((1, self.HIDDEN_DIM), dtype=np.float32)

    def encode(self, sequence):
        """Embed one sequence with the encoder.

        Args:
            sequence: Amino-acid string. Residues are separated by spaces
                before tokenization; input is truncated to ``max_length=1022``
                tokens.

        Returns:
            A 2-D NumPy array taken from the encoder's ``last_hidden_state``
            with the batch axis removed (one row per token). For an empty or
            non-string input, or when tokenization or inference fails, a
            ``(1, 1024)`` float32 zero array is returned and a message is
            printed.
        """
        if not sequence or not isinstance(sequence, str):
            print(f"警告: ProtT5Model.encode 接收到无效序列: {sequence}")
            return self._zero_embedding()

        seq_spaced = " ".join(sequence)
        try:
            encoded_input = self.tokenizer(seq_spaced, return_tensors='pt', padding=True, truncation=True, max_length=1022)
        except Exception as e:
            print(f"分词失败序列 '{sequence[:30]}...': {e}")
            return self._zero_embedding()

        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
        with torch.no_grad():
            try:
                embedding = self.model(**encoded_input).last_hidden_state
            except Exception as e:
                print(f"ProtT5模型推理失败序列 '{sequence[:30]}...': {e}")
                return self._zero_embedding()

        emb = embedding.squeeze(0).cpu().numpy()
        if emb.shape[0] == 0:
            # Guard against an unexpectedly empty token axis.
            return self._zero_embedding()
        return emb
def load_fasta(fasta_file):
    """Read every sequence from a FASTA file.

    Lines starting with ">" only act as record separators; their text is
    discarded. Blank lines are skipped, and the remaining lines of a record are
    stripped and concatenated.

    Args:
        fasta_file: Path of the FASTA file to read.

    Returns:
        A list of sequence strings in file order. If the file does not exist,
        a message is printed and an empty list is returned.
    """
    records = []
    chunks = []
    try:
        with open(fasta_file, 'r') as handle:
            for raw in handle:
                text = raw.strip()
                if not text:
                    continue
                if not text.startswith(">"):
                    chunks.append(text)
                    continue
                # Header line: close the record collected so far, if any.
                if chunks:
                    records.append("".join(chunks))
                chunks = []
            # The last record has no following header to close it.
            if chunks:
                records.append("".join(chunks))
    except FileNotFoundError:
        print(f"文件未找到: {fasta_file}")
        return []
    return records
def load_fasta_with_labels(fasta_file):
    """Read sequences and binary labels from a FASTA file.

    The label of a record is the character immediately after ">" in its
    header when that character is "0" or "1"; any other header yields label 0.
    Sequence lines that appear before the first header are also labelled 0.

    Args:
        fasta_file: Path of the FASTA file to read.

    Returns:
        A ``(sequences, labels)`` pair of equally long lists. If the file does
        not exist, a message is printed and ``([], [])`` is returned.
    """
    seqs, labs = [], []
    try:
        with open(fasta_file, 'r') as handle:
            parts = []
            label = None
            for raw in handle:
                text = raw.strip()
                if not text:
                    continue
                if not text.startswith(">"):
                    parts.append(text)
                    continue
                # Header line: store the finished record before starting anew.
                if parts:
                    seqs.append("".join(parts))
                    labs.append(0 if label is None else label)
                parts = []
                # text[1:2] is "" for a bare ">" header, which falls through to 0.
                label = int(text[1]) if text[1:2] in ("0", "1") else 0
            if parts:
                seqs.append("".join(parts))
                labs.append(0 if label is None else label)
    except FileNotFoundError:
        print(f"文件未找到: {fasta_file}")
        return [], []
    return seqs, labs
def compute_amino_acid_composition(seq):
    """Return the fraction of each of the 20 standard amino acids in ``seq``.

    Counting is case-insensitive. The denominator is the full length of
    ``seq``, so characters outside the 20-letter alphabet still count towards
    the length.

    Args:
        seq: Amino-acid string; may be empty.

    Returns:
        A dict keyed by the letters of "ACDEFGHIKLMNPQRSTVWY" (in that order)
        with float fractions; all values are 0.0 for an empty input.
    """
    alphabet = "ACDEFGHIKLMNPQRSTVWY"
    if not seq:
        return dict.fromkeys(alphabet, 0.0)
    upper = seq.upper()
    length = len(seq)
    return {residue: upper.count(residue) / length for residue in alphabet}
def compute_reducing_aa_ratio(seq):
    """Return the fraction of residues in ``seq`` that are C, M or W.

    Counting is case-insensitive.

    Args:
        seq: Amino-acid string; may be empty.

    Returns:
        ``(count of C + M + W) / len(seq)`` as a float, or 0.0 for an empty
        input.
    """
    if not seq:
        return 0.0
    upper = seq.upper()
    hits = 0
    for residue in ('C', 'M', 'W'):
        hits += upper.count(residue)
    return hits / len(seq)
def compute_physicochemical_properties(seq):
    """Return ``(gravy, isoelectric_point, molecular_weight)`` for ``seq``.

    The values come from ``ProteinAnalysis`` after upper-casing the sequence
    and substituting the non-standard letters X->A, U->C, B->N and Z->Q.

    Args:
        seq: Amino-acid string.

    Returns:
        A 3-tuple of floats. ``(0.0, 0.0, 0.0)`` is returned for an empty
        sequence or one containing a character outside
        "ACDEFGHIKLMNPQRSTVWYXUBZ" (case-insensitive). If the analysis itself
        raises, the rough defaults ``(0.0, 7.0, 110.0 * len(seq))`` are
        returned.
    """
    accepted = "ACDEFGHIKLMNPQRSTVWYXUBZ"
    if not seq:
        return 0.0, 0.0, 0.0
    for ch in seq:
        if ch.upper() not in accepted:
            # ProteinAnalysis might fail on unexpected characters.
            return 0.0, 0.0, 0.0

    substitutions = str.maketrans({'X': 'A', 'U': 'C', 'B': 'N', 'Z': 'Q'})
    try:
        cleaned = str(seq).upper().translate(substitutions)
        analysis = ProteinAnalysis(cleaned)
        return analysis.gravy(), analysis.isoelectric_point(), analysis.molecular_weight()
    except Exception:
        return 0.0, 7.0, 110.0 * len(seq)
def compute_electronic_features(seq):
    """Return two values derived from a per-residue lookup table.

    Each residue is mapped (case-insensitively) to a number from the table
    below, with 2.5 used for residues that are not listed. The mean of those
    numbers is then shifted by +0.1 and -0.1.

    Args:
        seq: Amino-acid string; may be empty.

    Returns:
        ``(mean + 0.1, mean - 0.1)``, or ``(0.0, 0.0)`` for an empty input.
    """
    if not seq:
        return 0.0, 0.0
    scale = {
        'A': 1.8, 'C': 2.5, 'D': 3.0, 'E': 3.2, 'F': 2.8,
        'G': 1.6, 'H': 2.4, 'I': 4.5, 'K': 3.0, 'L': 4.2,
        'M': 4.5, 'N': 2.0, 'P': 3.5, 'Q': 3.5, 'R': 2.5,
        'S': 1.8, 'T': 2.5, 'V': 4.0, 'W': 5.0, 'Y': 4.0,
    }
    scores = [scale.get(residue.upper(), 2.5) for residue in seq]
    if scores:
        mean_score = sum(scores) / len(scores)
    else:
        mean_score = 2.5
    return mean_score + 0.1, mean_score - 0.1
def compute_dimer_frequency(seq):
    """Return the frequency of every ordered pair of standard amino acids.

    Overlapping adjacent pairs are counted case-insensitively; pairs that
    contain a character outside "ACDEFGHIKLMNPQRSTVWY" are not counted but
    still contribute to the denominator ``len(seq) - 1``.

    Args:
        seq: Amino-acid string.

    Returns:
        A NumPy array of length 400 ordered alphabetically by pair
        ("AA", "AC", ..., "YY"); all zeros when ``seq`` has fewer than two
        characters.
    """
    if len(seq) < 2:
        return np.zeros(400)  # 20 * 20 possible pairs

    alphabet = "ACDEFGHIKLMNPQRSTVWY"
    # The alphabet is already in ascending order, so this nested enumeration
    # visits the pairs in sorted order.
    pairs = [first + second for first in alphabet for second in alphabet]
    slot_of = {pair: slot for slot, pair in enumerate(pairs)}

    tallies = [0] * len(pairs)
    for start in range(len(seq) - 1):
        slot = slot_of.get(seq[start:start + 2].upper())
        if slot is not None:
            tallies[slot] += 1

    denominator = max(len(seq) - 1, 1)
    return np.array([tally / denominator for tally in tallies])
def positional_encoding(seq_len_actual, L_fixed=29, d_model=16):
    """Return a flattened sinusoidal positional-encoding table.

    Even columns hold ``sin(pos / 10000 ** (2 * (i // 2) / d_model))`` and odd
    columns hold the matching cosine.

    Args:
        seq_len_actual: Accepted for interface compatibility; the value is not
            used, so the result is the same for every sequence length.
        L_fixed: Number of positions (rows) in the table.
        d_model: Number of encoding dimensions (columns) per position.

    Returns:
        A 1-D NumPy array of length ``L_fixed * d_model`` in row-major order.
    """
    table = np.zeros((L_fixed, d_model))
    for col in range(d_model):
        # Columns 2k and 2k + 1 share one wavelength.
        wavelength = 10000 ** (2 * (col // 2) / d_model)
        wave = np.sin if col % 2 == 0 else np.cos
        for row in range(L_fixed):
            table[row, col] = wave(row / wavelength)
    return table.flatten()
def perturb_sequence(seq, perturb_rate=0.1, critical=('C', 'M', 'W')):
    """Return a copy of ``seq`` with random point substitutions.

    Each residue whose upper-case form is not in ``critical`` is replaced,
    with probability ``perturb_rate``, by a different residue drawn uniformly
    from the 20 standard amino acids. Replacements are always upper-case.
    Randomness comes from the module-level ``random`` generator.

    Args:
        seq: Amino-acid string; an empty value yields "".
        perturb_rate: Per-residue substitution probability.
        critical: Container of upper-case residues that must never be
            changed. The default is an immutable tuple so that it cannot be
            modified by accident between calls.

    Returns:
        The perturbed sequence, with the same length as ``seq``.
    """
    if not seq:
        return ""
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    residues = list(seq)
    for i, aa in enumerate(residues):
        upper = aa.upper()
        # The random draw only happens for non-critical residues.
        if upper not in critical and random.random() < perturb_rate:
            residues[i] = random.choice([x for x in amino_acids if x != upper])
    return "".join(residues)
def extract_features(seq, prott5_model_instance, L_fixed=29, d_model_pe=16):
    """Build the full feature vector for one peptide sequence.

    The vector is the concatenation, in this order, of:
    the mean-pooled ProtT5 embedding (1024), amino-acid composition (20),
    reducing-residue ratio (1), gravy / pI / molecular weight (3), the two
    electronic features (2), dimer frequencies (400) and the flattened
    positional encoding (``L_fixed * d_model_pe``).

    Args:
        seq: Amino-acid string.
        prott5_model_instance: Object whose ``encode(seq)`` returns a 2-D
            array of token embeddings.
        L_fixed: Number of positions in the positional encoding.
        d_model_pe: Number of dimensions per position in the encoding.

    Returns:
        A 1-D NumPy array of length ``1450 + L_fixed * d_model_pe``. For an
        empty or non-string ``seq`` a warning is printed and an all-zero
        vector of that length is returned.
    """
    if not seq or not isinstance(seq, str):
        print("警告: extract_features 接收到空或无效序列。返回零特征。")
        return np.zeros(1024 + 20 + 1 + 3 + 2 + 400 + (L_fixed * d_model_pe))

    token_embeddings = prott5_model_instance.encode(seq)
    if token_embeddings.shape[0] > 0:
        pooled = np.mean(token_embeddings, axis=0)
    else:
        # Empty embedding: fall back to zeros of the reported width.
        width = token_embeddings.shape[1] if token_embeddings.ndim > 1 else 1024
        pooled = np.zeros(width)
    if pooled.shape[0] != 1024:
        # Keep the ProtT5 part of the vector at a fixed width of 1024.
        pooled = np.zeros(1024)

    composition = compute_amino_acid_composition(seq)
    composition_part = np.array([composition[aa] for aa in "ACDEFGHIKLMNPQRSTVWY"])
    reducing_part = np.array([compute_reducing_aa_ratio(seq)])
    gravy, pI, mol_weight = compute_physicochemical_properties(seq)
    physchem_part = np.array([gravy, pI, mol_weight])
    HOMO, LUMO = compute_electronic_features(seq)
    electronic_part = np.array([HOMO, LUMO])
    dimer_part = compute_dimer_frequency(seq)
    # The actual length is passed along, although the encoding only depends
    # on L_fixed and d_model_pe.
    position_part = positional_encoding(len(seq), L_fixed, d_model_pe)

    return np.concatenate([
        pooled,
        composition_part,
        reducing_part,
        physchem_part,
        electronic_part,
        dimer_part,
        position_part,
    ])
##############################################
# Main entry point: prepare_features
##############################################
def prepare_features(neg_fasta, pos_fasta, prott5_model_path, additional_params=None):
    """Load both FASTA files and return scaled train/validation features.

    Steps: read negative (label 0) and positive (label 1) sequences, shuffle
    them, hold out 10% for validation (stratified when both classes are
    present), optionally augment the training set with perturbed copies,
    extract features for every sequence, and scale them with a
    ``RobustScaler`` fitted on the training features only.

    Args:
        neg_fasta: Path of the FASTA file with negative sequences.
        pos_fasta: Path of the FASTA file with positive sequences.
        prott5_model_path: Path or identifier passed to ``ProtT5Model``.
        additional_params: Optional dict. Recognised keys are ``"augment"``
            (bool), ``"perturb_rate"`` (float, default 0.1) and
            ``"finetuned_model_file"`` (path or ``None``).

    Returns:
        ``(X_train_scaled, X_val_scaled, y_train, y_val, scaler)`` where the
        labels are NumPy arrays and ``scaler`` is the fitted ``RobustScaler``.

    Raises:
        ValueError: If no sequence could be loaded, or if either feature
            matrix is empty after extraction.
    """
    negatives = load_fasta(neg_fasta)
    positives = load_fasta(pos_fasta)
    if not (negatives or positives):
        raise ValueError("未能从FASTA文件加载任何序列。请检查文件路径和内容。")

    paired = [(s, 0) for s in negatives] + [(s, 1) for s in positives]
    random.shuffle(paired)
    sequences = [s for s, _ in paired]
    labels = [y for _, y in paired]

    # Stratify only when both classes are present.
    stratify_on = labels if len(np.unique(labels)) > 1 else None
    train_seqs, val_seqs, train_labels, val_labels = train_test_split(
        sequences, labels, test_size=0.1, random_state=42, stratify=stratify_on
    )
    print("训练集原始样本数:", len(train_seqs))
    print("验证集原始样本数:", len(val_seqs))

    options = additional_params if additional_params else {}
    if options.get("augment", False):
        # Augment the training split only, so validation data stays untouched.
        rate = options.get("perturb_rate", 0.1)
        perturbed = [perturb_sequence(s, perturb_rate=rate) for s in train_seqs]
        train_labels.extend(train_labels[:])
        train_seqs.extend(perturbed)
        print("数据增强后训练集样本数:", len(train_seqs))

    encoder = ProtT5Model(
        prott5_model_path,
        finetuned_model_file=options.get("finetuned_model_file"),
    )

    X_train = np.array([extract_features(s, encoder) for s in train_seqs])
    X_val = np.array([extract_features(s, encoder) for s in val_seqs])
    if X_train.shape[0] == 0 or X_val.shape[0] == 0:
        raise ValueError("特征提取后训练集或验证集为空。请检查序列数据和特征提取过程。")

    scaler = RobustScaler()
    print("使用 RobustScaler 进行特征归一化。")
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)
    return X_train_scaled, X_val_scaled, np.array(train_labels), np.array(val_labels), scaler
if __name__ == "__main__":
    # Smoke test for prepare_features. Dummy FASTA files are created when they
    # are missing; existing files are left untouched.
    neg_fasta_test = "dummy_data/test_neg.fasta"
    pos_fasta_test = "dummy_data/test_pos.fasta"
    # A usable model directory needs config.json, pytorch_model.bin, etc.
    prott5_path_test = "dummy_prott5_model/"
    os.makedirs("dummy_data", exist_ok=True)
    os.makedirs(prott5_path_test, exist_ok=True)

    # Ten sequences per class: prepare_features holds out 10% with a
    # stratified split, and the held-out set must contain at least one sample
    # per class. With only two sequences per class the split raises
    # ValueError before the model is ever loaded.
    dummy_neg_seqs = [
        "KALKALKALK", "PEPTPEPT", "GAGAGAGAGA", "LLKKLLKKLL", "SSTTSSTTSS",
        "VAVAVAVAVA", "NQNQNQNQ", "DEDEDEDEDE", "GPGPGPGPGP", "ILILILILIL",
    ]
    dummy_pos_seqs = [
        "AOPPEPTIDE", "TRYTRYTRY", "CMWCMWCMW", "YHWYHWYHW", "PYPWPHPC",
        "MCHWYKLP", "WYGCMHLP", "HHCCWWYY", "LPYWCMHG", "GCWYMHPL",
    ]
    if not os.path.exists(neg_fasta_test):
        with open(neg_fasta_test, "w") as f:
            f.write("".join(f">neg{i}\n{s}\n" for i, s in enumerate(dummy_neg_seqs, start=1)))
    if not os.path.exists(pos_fasta_test):
        with open(pos_fasta_test, "w") as f:
            f.write("".join(f">pos{i}\n{s}\n" for i, s in enumerate(dummy_pos_seqs, start=1)))

    if not os.listdir(prott5_path_test):
        print(f"警告: {prott5_path_test} 为空。ProtT5Model可能尝试从HuggingFace Hub下载模型。")
        print(f"请确保您已下载Rostlab/ProstT5-XL-UniRef50或类似模型到该路径,或使用其HuggingFace名称。")
        # To run fully offline, populate the directory with a real model.

    additional_params_test = {
        "augment": False,
        "perturb_rate": 0.1,
        "finetuned_model_file": None,  # no fine-tuned weights in this test
    }

    print("开始测试 prepare_features (使用RobustScaler)...")
    try:
        # NOTE(review): a Hub model name is passed here instead of
        # prott5_path_test; verify that this repository ID exists on the Hub.
        X_train_t, X_val_t, y_train_t, y_val_t, scaler_t = prepare_features(
            neg_fasta_test,
            pos_fasta_test,
            "Rostlab/ProstT5-XL-UniRef50",
            additional_params_test
        )
        print("prepare_features 测试完成。")
        print("训练集样本数:", X_train_t.shape[0])
        print("验证集样本数:", X_val_t.shape[0])
        if X_train_t.shape[0] > 0:
            print("训练集特征维度:", X_train_t.shape[1])
            print("一个缩放后的训练样本 (前5个特征):", X_train_t[0, :5])
        if scaler_t:
            print("Scaler类型:", type(scaler_t))
    except Exception as e:
        # Top-level boundary of the demo: report the failure and exit normally.
        print(f"prepare_features 测试失败: {e}")
        print("这可能是由于无法加载ProtT5模型或FASTA文件处理问题。请检查路径和文件内容。")