Spaces:
Runtime error
Runtime error
File size: 834 Bytes
753e275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
import torch.nn.functional as F
def get_consecutive_flag(chain_nb, res_nb, mask):
    """Flag which adjacent residue pairs are consecutive in sequence.

    Residue ``i`` is considered connected to residue ``i + 1`` when all of
    the following hold:

    * both residues are valid under ``mask``,
    * they share the same chain index, and
    * their residue numbers differ by exactly one (the absolute difference
      is used, so a descending step of one also counts).

    Args:
        chain_nb: Chain index of every residue, shape (B, L).
        res_nb: Residue number of every residue, shape (B, L).
        mask: Per-residue validity mask, shape (B, L); nonzero / True marks
            a residue that should take part in connectivity.

    Returns:
        consec: A flag tensor indicating whether residue-i is connected to
        residue-(i+1), BoolTensor, (B, L-1)[b, i].
    """
    d_res_nb = (res_nb[:, 1:] - res_nb[:, :-1]).abs()  # (B, L-1)
    same_chain = chain_nb[:, 1:] == chain_nb[:, :-1]
    consec = torch.logical_and(d_res_nb == 1, same_chain)
    # A link needs BOTH endpoints to exist. Checking only residue i would let
    # a valid residue followed by a masked-out one count as connected, so it
    # would never be marked as a C-terminus (while the mirrored case does
    # produce an N-terminus).
    both_valid = torch.logical_and(mask[:, :-1], mask[:, 1:])
    return torch.logical_and(consec, both_valid)
def get_terminus_flag(chain_nb, res_nb, mask):
    """Mark residues that start (N-terminus) or end (C-terminus) a segment.

    Wherever ``get_consecutive_flag`` reports that residue ``i`` is not
    connected to residue ``i + 1``, residue ``i`` ends a segment and residue
    ``i + 1`` starts one.

    Args:
        chain_nb, res_nb, mask: Forwarded unchanged to get_consecutive_flag.

    Returns:
        (N_term_flag, C_term_flag): the negated consecutive flag padded by
        one True element along the last dimension, at the front for the
        N-terminus flag and at the back for the C-terminus flag.
    """
    is_break = torch.logical_not(get_consecutive_flag(chain_nb, res_nb, mask))
    # The first residue has no predecessor, so it always starts a segment.
    N_term_flag = F.pad(is_break, pad=(1, 0), value=1)
    # The last residue has no successor, so it always ends a segment.
    C_term_flag = F.pad(is_break, pad=(0, 1), value=1)
    return N_term_flag, C_term_flag
|