#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
class AntioxidantPredictor(nn.Module):
    """Binary classifier over concatenated ProtT5 + handcrafted features.

    The first 1024 dimensions of the input are treated as a flat ProtT5
    embedding, reshaped into a pseudo-sequence of 16 tokens x 64 features,
    refined by a Transformer encoder and mean-pooled. The remaining
    ``input_dim - 1024`` dimensions are handcrafted features, fused with
    the refined embedding by an MLP. The head emits one logit per sample,
    divided by a temperature ``T`` (1.0 until calibrated, i.e. a no-op).

    Args:
        input_dim: Total feature dimension; must be >= 1024
            (ProtT5 part) so the handcrafted part is non-negative.
        transformer_layers: Number of encoder layers.
        transformer_heads: Attention heads per layer (must divide 64).
        transformer_dropout: Dropout inside the encoder.

    Raises:
        ValueError: If ``input_dim`` is smaller than the 1024-dim
            ProtT5 block.
    """

    def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
        super(AntioxidantPredictor, self).__init__()
        self.prott5_dim = 1024
        self.handcrafted_dim = input_dim - self.prott5_dim
        if self.handcrafted_dim < 0:
            raise ValueError(
                f"input_dim must be >= {self.prott5_dim} (ProtT5 block), got {input_dim}"
            )
        # 16 tokens x 64 features == 1024 == self.prott5_dim
        self.seq_len = 16
        self.prott5_feature_dim = 64

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.prott5_feature_dim,
            nhead=transformer_heads,
            dropout=transformer_dropout,
            batch_first=True,  # inputs are (batch, seq, feature)
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)

        fused_dim = self.prott5_feature_dim + self.handcrafted_dim
        self.fusion_fc = nn.Sequential(
            nn.Linear(fused_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )
        # Temperature-scaling parameter T, initialised to 1.0 so logits are
        # unchanged before calibration. requires_grad=False because T is
        # normally optimised separately after training; kept as nn.Parameter
        # (not a buffer) so existing state_dict checkpoints stay compatible.
        self.temperature = nn.Parameter(torch.ones(1), requires_grad=False)

    def forward(self, x, *args):
        """Return temperature-scaled logits of shape (batch, 1).

        Args:
            x: Tensor of shape (batch, input_dim); the first 1024 columns
                are the ProtT5 embedding, the rest handcrafted features.
            *args: Ignored; accepted for call-site compatibility.
        """
        batch_size = x.size(0)
        prot_t5_features = x[:, :self.prott5_dim]
        handcrafted_features = x[:, self.prott5_dim:]
        # Unflatten the 1024-dim embedding into a (batch, 16, 64) sequence.
        prot_t5_seq = prot_t5_features.view(batch_size, self.seq_len, self.prott5_feature_dim)
        encoded_seq = self.transformer_encoder(prot_t5_seq)
        refined_prott5 = encoded_seq.mean(dim=1)  # pool tokens back to (batch, 64)
        fused_features = torch.cat([refined_prott5, handcrafted_features], dim=1)
        fused_features = self.fusion_fc(fused_features)
        logits = self.classifier(fused_features)
        # Apply temperature scaling (logits / T). Callers apply sigmoid
        # themselves; while T == 1.0 this is a no-op, and once T has been
        # optimised the returned logits are already calibrated.
        logits_scaled = logits / self.temperature
        return logits_scaled

    def set_temperature(self, temp_value, device):
        """Install an externally optimised temperature value (non-trainable)."""
        self.temperature = nn.Parameter(torch.tensor([temp_value], device=device), requires_grad=False)
        print(f"模型温度 T 设置为: {self.temperature.item()}")

    def get_temperature(self):
        """Return the current temperature as a Python float."""
        return self.temperature.item()
if __name__ == "__main__":
    # Smoke test: build the model, run one forward pass, then show the
    # effect of temperature scaling on the same batch.
    features = torch.randn(8, 1914)
    predictor = AntioxidantPredictor(input_dim=1914)
    print(f"初始温度: {predictor.get_temperature()}")

    raw_logits = predictor(features)
    print("初始 logits shape:", raw_logits.shape)
    raw_probs = torch.sigmoid(raw_logits)
    print("初始概率 (T=1.0):", raw_probs.detach().cpu().numpy()[:2])

    # Pretend calibration produced T = 1.5 and install it.
    predictor.set_temperature(1.5, device='cpu')
    print(f"设置后温度: {predictor.get_temperature()}")

    calibrated_logits = predictor(features)  # model divides by T internally
    print("缩放后 logits shape:", calibrated_logits.shape)
    calibrated_probs = torch.sigmoid(calibrated_logits)  # sigmoid still applied outside
    print("缩放后概率 (T=1.5):", calibrated_probs.detach().cpu().numpy()[:2])

    # NOTE: dividing `raw_logits` by 1.5 manually and applying sigmoid
    # reproduces `calibrated_probs` exactly (torch.allclose holds).