| | |
| | |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from utils.nn_utils import graph_to_batch |
| |
|
| | from .backbone import FrameBuilder |
| |
|
| |
|
class BackboneModel(nn.Module):
    """Rebuild the backbone atoms of predicted all-atom coordinates.

    The first four atom slots of every residue are converted to rigid
    transforms with ``FrameBuilder.inverse`` and then regenerated by the
    ``FrameBuilder`` forward pass. All remaining atom slots are copied from
    the input unchanged.
    """

    def __init__(self) -> None:
        super().__init__()
        self.backbone_builder = FrameBuilder()

    def forward(self, X: torch.Tensor, batch_ids: torch.Tensor) -> torch.Tensor:
        """Replace the backbone atoms of ``X`` with rebuilt ones.

        Args:
            X: [N, 14, 3], predicted all-atom coordinates (obviously with a
                lot of invalidities) for all residues in the batch, flattened.
                The first 4 atom slots are assumed to be N, CA, C, O.
            batch_ids: per-residue batch assignment passed straight to
                ``graph_to_batch`` (presumably [N] graph indices -- verify
                against graph_to_batch).

        Returns:
            Flat per-residue coordinates selected with the padding mask:
            rebuilt backbone atoms followed by input atom slots 4 onward.
            Assuming the builder emits 4 atoms per residue, this is [N, 14, 3].
        """
        # Densify the flat residue list into a padded batch. With
        # mask_is_pad=False, ``mask`` is presumably True on real residues.
        X, mask = graph_to_batch(X, batch_ids, mask_is_pad=False)
        # Chain map for the builder: 1 where mask is set, 0 elsewhere
        # (every residue of a sample is treated as one chain).
        C = mask.long()

        # Frames depend only on the backbone atoms, so pass just slots 0-3.
        # The third output of inverse is not needed here.
        R, t, _ = self.backbone_builder.inverse(X[:, :, :4], C)
        X_bb = self.backbone_builder(R, t, C)
        # Keep the non-backbone atom slots exactly as predicted.
        X = torch.cat([X_bb, X[:, :, 4:]], dim=-2)

        # Drop padded positions to return to the flat per-residue layout.
        return X[mask]