Xianfish9's picture
Update Feature_extraction_algorithms/PSTAAP.py
f808770 verified
import os
import numpy as np
from typing import List
import scipy.io
def load_precomputed_fr_matrix(mat_file_path: str):
    """Load a precomputed Fr matrix from a MATLAB ``.mat`` file and cache it.

    The file must contain a variable named ``Fr``. The matrix is validated
    before it is stored so that a malformed file can never leave the module
    cache half-updated: either both module globals are replaced, or neither.

    Args:
        mat_file_path: Path of the ``.mat`` file to read.

    Raises:
        KeyError: The file has no variable named ``Fr``.
        ValueError: ``Fr`` is not a 2-D matrix with ``20 ** 3`` rows (one row
            per tripeptide, as indexed by ``PSTAAP_feature``).
        Exception: Anything raised by ``scipy.io.loadmat`` is reported on
            stdout and re-raised unchanged.
    """
    global _cached_fr_matrix, _expected_length_after_processing
    print(f"正在从 {mat_file_path} 加载预计算的 Fr 矩阵...")
    try:
        mat_data = scipy.io.loadmat(mat_file_path)
        matrix_key = 'Fr'  # variable name expected inside the .mat file
        if matrix_key not in mat_data:
            raise KeyError(f"在 {mat_file_path} 中未找到变量名 '{matrix_key}'。 "
                           f"文件中可用的变量有: {list(mat_data.keys())}")

        fr_matrix = mat_data[matrix_key]
        # Rows are addressed as 400*a + 20*b + c (0..7999) during feature
        # extraction, so any other layout would only fail later and obscurely.
        if fr_matrix.ndim != 2 or fr_matrix.shape[0] != 20 ** 3:
            raise ValueError(f"变量 '{matrix_key}' 的形状 {fr_matrix.shape} 无效:"
                             f"期望行数为 {20 ** 3} 的二维矩阵。")

        # Commit both globals together, only after validation succeeded.
        _cached_fr_matrix = fr_matrix
        # One column per tripeptide window, so the sequence is 2 longer.
        _expected_length_after_processing = fr_matrix.shape[1] + 2
        print(f"Fr 矩阵加载并缓存成功。形状: {_cached_fr_matrix.shape}")
        print(f"推断出的序列期望长度 (处理后): {_expected_length_after_processing}")
    except Exception as e:
        print(f"❌ 加载 Fr 矩阵失败: {e}")
        raise
# --- Module-level cache ---
# Holds the computed Fr matrix in memory so it is not recomputed or re-read
# from disk on every call. None until one of the initialisers has run.
_cached_fr_matrix = None
# Sequence length after preprocessing; used later to validate input sequences.
_expected_length_after_processing = None
# --- Internal helper functions (extracted from the .mat computation code) ---
def _read_fasta_sequences(filename: str) -> List[str]:
"""
(内部函数) 读取FASTA格式文件,返回序列列表。
"""
sequences = []
current_seq = []
try:
with open(filename, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line: continue
if line.startswith('>'):
if current_seq:
sequences.append(''.join(current_seq))
current_seq = []
else:
# 确保序列为大写,以匹配氨基酸字典
current_seq.append(line.upper())
if current_seq:
sequences.append(''.join(current_seq))
except FileNotFoundError:
raise FileNotFoundError(f"错误:文件 '{filename}' 未找到。")
return sequences
def _process_sequence(sequence: str) -> str:
"""
(内部函数) 对单条序列进行预处理:移除正中间的氨基酸。
这个函数统一了训练和提取时的预处理逻辑。
"""
middle_index = (len(sequence) - 1) // 2
return sequence[:middle_index] + sequence[middle_index + 1:]
def _calculate_frequency_matrix(sequences: List[str], aa_map: dict) -> np.ndarray:
"""
(内部函数) 为一组序列计算标准化的三肽频率矩阵。
"""
if not sequences:
return np.zeros((20 ** 3, 0))
num_sequences = len(sequences)
seq_length = len(sequences[0])
freq_matrix = np.zeros((20 ** 3, seq_length - 2))
for seq in sequences:
for j in range(seq_length - 2):
k1 = aa_map.get(seq[j], -1)
k2 = aa_map.get(seq[j + 1], -1)
k3 = aa_map.get(seq[j + 2], -1)
if -1 not in {k1, k2, k3}:
index = 400 * k1 + 20 * k2 + k3
freq_matrix[index, j] += 1
return freq_matrix / num_sequences if num_sequences > 0 else freq_matrix
# --- Public API functions ---
def initialize_fr_matrix(fasta_files: List[str]):
    """Compute the Fr matrix from FASTA files and cache it at module level.

    Each file is treated as one class. For every class, a tripeptide
    frequency matrix is computed for its own sequences (F) and for the
    sequences of all other classes combined (FF). The cached result is
    ``mean(F) - mean(FF)`` over the classes. ``PSTAAP_feature`` needs this
    cache to be populated before it can be used.

    The module cache is only replaced once the whole computation has
    succeeded, so a failed call leaves any previously cached matrix intact.

    Args:
        fasta_files: Paths of the FASTA files, one per class. At least two
            files are needed, and every file must contain at least one
            sequence; all sequences must share the same length.

    Raises:
        ValueError: The list is empty, holds fewer than two files, a file
            has no sequences, or sequence lengths differ.
        FileNotFoundError: One of the files does not exist.
    """
    global _cached_fr_matrix, _expected_length_after_processing
    print("正在初始化PSTAAP特征提取器...")
    AA_MAP = {char: i for i, char in enumerate('ACDEFGHIKLMNPQRSTVWY')}

    # 1. Read and validate all sequences.
    fasta_files = list(fasta_files)
    all_sequences_by_file = [_read_fasta_sequences(f) for f in fasta_files]
    if not all_sequences_by_file or not any(all_sequences_by_file):
        raise ValueError("输入的文件列表为空或所有文件均不包含序列。")
    # With a single class there are no "other" sequences, so FF (and hence
    # Fr = F - FF) is undefined; fail clearly instead of on a shape mismatch.
    if len(all_sequences_by_file) < 2:
        raise ValueError("至少需要两个FASTA文件(每个文件代表一个类别)才能计算 Fr 矩阵。")
    for path, seqs in zip(fasta_files, all_sequences_by_file):
        if not seqs:
            # An empty class would yield an (8000, 0) matrix that cannot be
            # averaged with the others.
            raise ValueError(f"文件 '{path}' 不包含任何序列。")

    first_len = len(all_sequences_by_file[0][0])
    for i, seqs in enumerate(all_sequences_by_file):
        if not all(len(s) == first_len for s in seqs):
            raise ValueError(f"文件 '{fasta_files[i]}' 中的序列长度不一致或与其他文件不同。")

    # 2. Preprocess every sequence the same way PSTAAP_feature does.
    processed_sequences_list = [
        [_process_sequence(seq) for seq in seqs] for seqs in all_sequences_by_file
    ]
    expected_length = len(processed_sequences_list[0][0])

    # 3. Per class: frequencies of its own sequences vs. all other classes.
    f_matrices, ff_matrices = [], []
    for i, current_seqs in enumerate(processed_sequences_list):
        other_seqs_combined = [
            seq
            for idx, lst in enumerate(processed_sequences_list)
            if idx != i
            for seq in lst
        ]
        f_matrices.append(_calculate_frequency_matrix(current_seqs, AA_MAP))
        ff_matrices.append(_calculate_frequency_matrix(other_seqs_combined, AA_MAP))
    F_avg = np.mean(f_matrices, axis=0)
    FF_avg = np.mean(ff_matrices, axis=0)

    # 4. Commit both cached values together, only after everything succeeded.
    _cached_fr_matrix = F_avg - FF_avg
    _expected_length_after_processing = expected_length
    print(f"Fr 矩阵计算完成并已缓存。形状: {_cached_fr_matrix.shape}")
def PSTAAP_feature(protein_sequences: List[str]) -> np.ndarray:
    """Extract PSTAAP features from protein sequences.

    The cached Fr matrix must have been populated first, by either
    ``initialize_fr_matrix()`` or ``load_precomputed_fr_matrix()``.

    Each sequence is preprocessed with ``_process_sequence`` and upper-cased
    (the FASTA reader upper-cases the sequences the matrix is built from).
    Feature ``j`` of a sequence is the Fr value of the tripeptide starting at
    position ``j``. Windows containing a non-standard residue keep the value 0.

    Args:
        protein_sequences: Sequences to encode. After preprocessing, every
            sequence must have the length the Fr matrix was built for.

    Returns:
        Array of shape ``(len(protein_sequences), expected_length - 2)``.
        An empty input list yields an array with zero rows.

    Raises:
        RuntimeError: The Fr matrix has not been initialised.
        ValueError: A preprocessed sequence has the wrong length.
    """
    if _cached_fr_matrix is None:
        raise RuntimeError(
            "Fr矩阵尚未初始化。请在使用此函数前,先调用 initialize_fr_matrix(fasta_files) 函数。"
        )

    # Same preprocessing as when the matrix was built; upper-case so that
    # lower-case input does not silently turn into all-zero features.
    processed_sequences = [_process_sequence(seq).upper() for seq in protein_sequences]

    # Validate every sequence, not just the first: a shorter one would raise
    # IndexError below and a longer one would be silently truncated.
    for seq in processed_sequences:
        if len(seq) != _expected_length_after_processing:
            raise ValueError(f"输入序列处理后的长度 ({len(seq)}) 与训练时"
                             f"的期望长度 ({_expected_length_after_processing}) 不匹配。")

    aa_index = {aa: pos for pos, aa in enumerate('ACDEFGHIKLMNPQRSTVWY')}
    num_seqs = len(processed_sequences)
    feature_dim = _expected_length_after_processing - 2
    PSTAAP = np.zeros((num_seqs, feature_dim))
    for i, seq in enumerate(processed_sequences):
        # Translate residues once per sequence; -1 marks a non-standard one.
        codes = [aa_index.get(residue, -1) for residue in seq]
        for j in range(feature_dim):
            k1, k2, k3 = codes[j], codes[j + 1], codes[j + 2]
            if k1 < 0 or k2 < 0 or k3 < 0:
                # Non-standard residue in this window: leave the feature at 0.
                continue
            PSTAAP[i, j] = _cached_fr_matrix[400 * k1 + 20 * k2 + k3, j]
    return PSTAAP