| | |
| |
|
| | from __future__ import annotations |
| |
|
| | from math import sqrt |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| | from torch import nn |
| | from torch.nn import Module |
| |
|
| | from . import vb_const as const |
| | from . import vb_layers_initialize as init |
| | from .vb_loss_diffusionv2 import ( |
| | smooth_lddt_loss, |
| | weighted_rigid_align, |
| | ) |
| | from .vb_modules_encodersv2 import ( |
| | AtomAttentionDecoder, |
| | AtomAttentionEncoder, |
| | SingleConditioning, |
| | ) |
| | from .vb_modules_transformersv2 import ( |
| | DiffusionTransformer, |
| | ) |
| | from .vb_modules_utils import ( |
| | LinearNoBias, |
| | center_random_augmentation, |
| | compute_random_augmentation, |
| | default, |
| | log, |
| | ) |
| | from .vb_potentials_potentials import get_potentials |
| |
|
| |
|
class DiffusionModule(Module):
    """Denoising network for the diffusion process.

    Conditions on trunk/input single representations plus the diffusion
    time, runs an atom encoder -> token transformer -> atom decoder stack,
    and predicts a per-atom coordinate update for the noised coordinates.
    """

    def __init__(
        self,
        token_s: int,
        atom_s: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        sigma_data: int = 16,
        dim_fourier: int = 256,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        conditioning_transition_layers: int = 2,
        activation_checkpointing: bool = False,
        transformer_post_ln: bool = False,
    ) -> None:
        """Build the conditioning, encoder, transformer and decoder modules.

        Args:
            token_s: Width of the single (token) representation; the internal
                token stream is ``2 * token_s`` wide.
            atom_s: Width of the atom representation.
            atoms_per_window_queries: Query window size of the windowed
                atom attention.
            atoms_per_window_keys: Key window size of the windowed atom
                attention.
            sigma_data: Data scale used by the single conditioner.
            dim_fourier: Width of the Fourier time embedding.
            atom_encoder_depth / atom_encoder_heads: Atom encoder size.
            token_transformer_depth / token_transformer_heads: Token
                transformer size.
            atom_decoder_depth / atom_decoder_heads: Atom decoder size.
            conditioning_transition_layers: Transition blocks in the
                single conditioner.
            activation_checkpointing: Trade compute for memory during
                training.
            transformer_post_ln: Use post-layer-norm in the atom encoder
                transformer.
        """
        super().__init__()

        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys
        self.sigma_data = sigma_data
        self.activation_checkpointing = activation_checkpointing

        # Width of the per-token activation stream.
        token_dim = 2 * token_s

        # Fuses diffusion time with the trunk / input single representations.
        self.single_conditioner = SingleConditioning(
            sigma_data=sigma_data,
            token_s=token_s,
            dim_fourier=dim_fourier,
            num_transitions=conditioning_transition_layers,
        )

        # Windowed atom-level encoder producing token activations and the
        # skip tensors consumed by the decoder.
        self.atom_attention_encoder = AtomAttentionEncoder(
            atom_s=atom_s,
            token_s=token_s,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_encoder_depth=atom_encoder_depth,
            atom_encoder_heads=atom_encoder_heads,
            structure_prediction=True,
            activation_checkpointing=activation_checkpointing,
            transformer_post_layer_norm=transformer_post_ln,
        )

        # Projects the conditioning single stream into the token stream.
        # Zero-initialized so the injection starts as a no-op.
        self.s_to_a_linear = nn.Sequential(
            nn.LayerNorm(token_dim), LinearNoBias(token_dim, token_dim)
        )
        init.final_init_(self.s_to_a_linear[1].weight)

        self.token_transformer = DiffusionTransformer(
            dim=token_dim,
            dim_single_cond=token_dim,
            depth=token_transformer_depth,
            heads=token_transformer_heads,
            activation_checkpointing=activation_checkpointing,
        )

        self.a_norm = nn.LayerNorm(token_dim)

        # Broadcasts token activations back to atoms and predicts the
        # coordinate update.
        self.atom_attention_decoder = AtomAttentionDecoder(
            atom_s=atom_s,
            token_s=token_s,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            atom_decoder_depth=atom_decoder_depth,
            atom_decoder_heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        s_inputs,
        s_trunk,
        r_noisy,
        times,
        feats,
        diffusion_conditioning,
        multiplicity=1,
    ):
        """Predict the coordinate update for the noised coordinates.

        Args:
            s_inputs: Input single representation (repeated per sample).
            s_trunk: Trunk single representation (repeated per sample).
            r_noisy: Noised, preconditioned atom coordinates.
            times: Noise-level embedding input (c_noise of sigma).
            feats: Feature dict; ``token_pad_mask`` is read here.
            diffusion_conditioning: Precomputed conditioning tensors
                (``q``, ``c``, attention biases, ``to_keys``).
            multiplicity: Number of diffusion samples per input.

        Returns:
            Per-atom coordinate update tensor.
        """
        # Diffusion-time conditioning of the single representation; the
        # normalized Fourier embedding returned alongside is unused here.
        conditioner_args = (
            times,
            s_trunk.repeat_interleave(multiplicity, 0),
            s_inputs.repeat_interleave(multiplicity, 0),
        )
        if self.activation_checkpointing and self.training:
            single_cond, _ = torch.utils.checkpoint.checkpoint(
                self.single_conditioner, *conditioner_args
            )
        else:
            single_cond, _ = self.single_conditioner(*conditioner_args)

        # Encode atoms into the token stream (plus decoder skip tensors).
        tokens, q_skip, c_skip, to_keys = self.atom_attention_encoder(
            feats=feats,
            q=diffusion_conditioning["q"].float(),
            c=diffusion_conditioning["c"].float(),
            atom_enc_bias=diffusion_conditioning["atom_enc_bias"].float(),
            to_keys=diffusion_conditioning["to_keys"],
            r=r_noisy,
            multiplicity=multiplicity,
        )

        # Inject the conditioning signal into the token activations.
        tokens = tokens + self.s_to_a_linear(single_cond)

        token_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        tokens = self.token_transformer(
            tokens,
            mask=token_mask.float(),
            s=single_cond,
            bias=diffusion_conditioning["token_trans_bias"].float(),
            multiplicity=multiplicity,
        )
        tokens = self.a_norm(tokens)

        # Decode back to per-atom coordinate updates.
        r_update = self.atom_attention_decoder(
            a=tokens,
            q=q_skip,
            c=c_skip,
            atom_dec_bias=diffusion_conditioning["atom_dec_bias"].float(),
            feats=feats,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )

        return r_update
| |
|
| |
|
class AtomDiffusion(Module):
    """EDM-style diffusion wrapper around :class:`DiffusionModule`.

    Handles network preconditioning (c_skip/c_out/c_in/c_noise), the noise
    schedule, reverse-diffusion sampling (optionally with Feynman-Kac
    steering and potential-based guidance), and the training loss
    (weighted MSE plus smooth lDDT).
    """

    def __init__(
        self,
        score_model_args,
        num_sampling_steps: int = 5,
        sigma_min: float = 0.0004,
        sigma_max: float = 160.0,
        sigma_data: float = 16.0,
        rho: float = 7,
        P_mean: float = -1.2,
        P_std: float = 1.5,
        gamma_0: float = 0.8,
        gamma_min: float = 1.0,
        noise_scale: float = 1.003,
        step_scale: float = 1.5,
        step_scale_random: list = None,
        coordinate_augmentation: bool = True,
        coordinate_augmentation_inference=None,
        compile_score: bool = False,
        alignment_reverse_diff: bool = False,
        synchronize_sigmas: bool = False,
    ):
        """Configure the score model and EDM hyper-parameters.

        Args:
            score_model_args: Kwargs forwarded to ``DiffusionModule``.
            num_sampling_steps: Default number of reverse-diffusion steps.
            sigma_min / sigma_max: Schedule endpoints (before sigma_data
                scaling).
            sigma_data: Data standard deviation used by the preconditioner.
            rho: Karras schedule exponent.
            P_mean / P_std: Log-normal training noise distribution.
            gamma_0 / gamma_min: Churn amount and the sigma threshold above
                which churn is applied.
            noise_scale: Multiplier on the churn noise std.
            step_scale: Multiplier on the integrator step.
            step_scale_random: Optional list to sample ``step_scale`` from
                during training.
            coordinate_augmentation: Random rigid augmentation at train time.
            coordinate_augmentation_inference: Override for inference;
                falls back to ``coordinate_augmentation`` when None.
            compile_score: Wrap the score model in ``torch.compile``.
            alignment_reverse_diff: Rigid-align the noisy state onto the
                denoised prediction before each integrator step.
            synchronize_sigmas: Use the same training sigma across the
                multiplicity copies of each sample.
        """
        super().__init__()
        self.score_model = DiffusionModule(
            **score_model_args,
        )
        if compile_score:
            self.score_model = torch.compile(
                self.score_model, dynamic=False, fullgraph=False
            )

        # EDM hyper-parameters.
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.rho = rho
        self.P_mean = P_mean
        self.P_std = P_std
        self.num_sampling_steps = num_sampling_steps
        self.gamma_0 = gamma_0
        self.gamma_min = gamma_min
        self.noise_scale = noise_scale
        self.step_scale = step_scale
        self.step_scale_random = step_scale_random
        self.coordinate_augmentation = coordinate_augmentation
        self.coordinate_augmentation_inference = (
            coordinate_augmentation_inference
            if coordinate_augmentation_inference is not None
            else coordinate_augmentation
        )
        self.alignment_reverse_diff = alignment_reverse_diff
        self.synchronize_sigmas = synchronize_sigmas

        self.token_s = score_model_args["token_s"]
        self.register_buffer("zero", torch.tensor(0.0), persistent=False)

    @property
    def device(self):
        """Device of the score model parameters."""
        return next(self.score_model.parameters()).device

    def c_skip(self, sigma):
        """EDM skip-connection scaling."""
        return (self.sigma_data**2) / (sigma**2 + self.sigma_data**2)

    def c_out(self, sigma):
        """EDM output scaling."""
        return sigma * self.sigma_data / torch.sqrt(self.sigma_data**2 + sigma**2)

    def c_in(self, sigma):
        """EDM input scaling."""
        return 1 / torch.sqrt(sigma**2 + self.sigma_data**2)

    def c_noise(self, sigma):
        """EDM noise-level embedding input."""
        return log(sigma / self.sigma_data) * 0.25

    def preconditioned_network_forward(
        self,
        noised_atom_coords,
        sigma,
        network_condition_kwargs: dict,
    ):
        """Run the score model with EDM preconditioning.

        Args:
            noised_atom_coords: Noisy coordinates, shape (batch, atoms, 3).
            sigma: Noise level, either a float (broadcast over the batch)
                or a (batch,) tensor.
            network_condition_kwargs: Extra kwargs for the score model.

        Returns:
            Denoised coordinates of the same shape as the input.
        """
        batch, device = noised_atom_coords.shape[0], noised_atom_coords.device

        if isinstance(sigma, float):
            sigma = torch.full((batch,), sigma, device=device)

        padded_sigma = rearrange(sigma, "b -> b 1 1")

        r_update = self.score_model(
            r_noisy=self.c_in(padded_sigma) * noised_atom_coords,
            times=self.c_noise(sigma),
            **network_condition_kwargs,
        )

        # Combine skip connection and network output (EDM denoiser form).
        denoised_coords = (
            self.c_skip(padded_sigma) * noised_atom_coords
            + self.c_out(padded_sigma) * r_update
        )
        return denoised_coords

    def sample_schedule(self, num_sampling_steps=None):
        """Karras noise schedule, descending, with a trailing sigma of 0.

        Returns a tensor of ``num_sampling_steps + 1`` sigmas scaled by
        ``sigma_data``.
        """
        num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
        inv_rho = 1 / self.rho

        steps = torch.arange(
            num_sampling_steps, device=self.device, dtype=torch.float32
        )
        sigmas = (
            self.sigma_max**inv_rho
            + steps
            / (num_sampling_steps - 1)
            * (self.sigma_min**inv_rho - self.sigma_max**inv_rho)
        ) ** self.rho

        sigmas = sigmas * self.sigma_data

        # Final step denoises all the way to sigma = 0.
        sigmas = F.pad(sigmas, (0, 1), value=0.0)
        return sigmas

    def sample(
        self,
        atom_mask,
        num_sampling_steps=None,
        multiplicity=1,
        max_parallel_samples=None,
        steering_args=None,
        **network_condition_kwargs,
    ):
        """Reverse-diffusion sampling of atom coordinates.

        Args:
            atom_mask: (batch, atoms) padding mask.
            num_sampling_steps: Overrides ``self.num_sampling_steps``.
            multiplicity: Samples per input (multiplied by the particle
                count when FK steering is on).
            max_parallel_samples: Cap on samples denoised per network call.
            steering_args: Dict with ``fk_steering``,
                ``physical_guidance_update``, ``contact_guidance_update``
                flags and their parameters; ``None`` disables all steering.
            **network_condition_kwargs: Forwarded to the score model
                (must include ``feats`` when steering is enabled).

        Returns:
            Dict with ``sample_atom_coords`` and ``diff_token_repr``.
        """
        # Normalize the steering config so the documented default
        # (steering_args=None) works; previously the unconditional
        # steering_args[...] lookups below raised on None.
        if steering_args is None:
            steering_args = {
                "fk_steering": False,
                "physical_guidance_update": False,
                "contact_guidance_update": False,
            }
        if (
            steering_args["fk_steering"]
            or steering_args["physical_guidance_update"]
            or steering_args["contact_guidance_update"]
        ):
            potentials = get_potentials(steering_args, boltz2=True)

        if steering_args["fk_steering"]:
            # Each sample becomes a group of particles; resampling weights
            # are normalized within each group.
            multiplicity = multiplicity * steering_args["num_particles"]
            energy_traj = torch.empty((multiplicity, 0), device=self.device)
            resample_weights = torch.ones(multiplicity, device=self.device).reshape(
                -1, steering_args["num_particles"]
            )
        if (
            steering_args["physical_guidance_update"]
            or steering_args["contact_guidance_update"]
        ):
            scaled_guidance_update = torch.zeros(
                (multiplicity, *atom_mask.shape[1:], 3),
                dtype=torch.float32,
                device=self.device,
            )
        if max_parallel_samples is None:
            max_parallel_samples = multiplicity

        num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        shape = (*atom_mask.shape, 3)

        # Schedule: (sigma_tm, sigma_t) pairs with per-step churn gamma.
        sigmas = self.sample_schedule(num_sampling_steps)
        gammas = torch.where(sigmas > self.gamma_min, self.gamma_0, 0.0)
        sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[1:]))
        if self.training and self.step_scale_random is not None:
            step_scale = np.random.choice(self.step_scale_random)
        else:
            step_scale = self.step_scale

        # Start from pure noise at the largest sigma.
        init_sigma = sigmas[0]
        atom_coords = init_sigma * torch.randn(shape, device=self.device)
        token_repr = None
        atom_coords_denoised = None

        for step_idx, (sigma_tm, sigma_t, gamma) in enumerate(sigmas_and_gammas):
            # Random rigid augmentation of the current state.
            random_R, random_tr = compute_random_augmentation(
                multiplicity, device=atom_coords.device, dtype=atom_coords.dtype
            )
            atom_coords = atom_coords - atom_coords.mean(dim=-2, keepdims=True)
            atom_coords = (
                torch.einsum("bmd,bds->bms", atom_coords, random_R) + random_tr
            )
            if atom_coords_denoised is not None:
                atom_coords_denoised -= atom_coords_denoised.mean(dim=-2, keepdims=True)
                atom_coords_denoised = (
                    torch.einsum("bmd,bds->bms", atom_coords_denoised, random_R)
                    + random_tr
                )
            if (
                steering_args["physical_guidance_update"]
                or steering_args["contact_guidance_update"]
            ) and scaled_guidance_update is not None:
                # Rotate only (no translation): this is a displacement field.
                scaled_guidance_update = torch.einsum(
                    "bmd,bds->bms", scaled_guidance_update, random_R
                )

            sigma_tm, sigma_t, gamma = sigma_tm.item(), sigma_t.item(), gamma.item()

            # Churn: raise the noise level from sigma_tm to t_hat by adding
            # the matching amount of fresh noise.
            t_hat = sigma_tm * (1 + gamma)
            steering_t = 1.0 - (step_idx / num_sampling_steps)
            noise_var = self.noise_scale**2 * (t_hat**2 - sigma_tm**2)
            eps = sqrt(noise_var) * torch.randn(shape, device=self.device)
            atom_coords_noisy = atom_coords + eps

            with torch.no_grad():
                atom_coords_denoised = torch.zeros_like(atom_coords_noisy)
                sample_ids = torch.arange(multiplicity).to(atom_coords_noisy.device)
                # Denoise in chunks of at most max_parallel_samples samples.
                # (Fixes the previous `chunk(multiplicity % max_parallel_samples
                # + 1)`, which did not actually bound the chunk size.)
                sample_ids_chunks = sample_ids.split(max_parallel_samples)

                for sample_ids_chunk in sample_ids_chunks:
                    atom_coords_denoised_chunk = self.preconditioned_network_forward(
                        atom_coords_noisy[sample_ids_chunk],
                        t_hat,
                        network_condition_kwargs=dict(
                            multiplicity=sample_ids_chunk.numel(),
                            **network_condition_kwargs,
                        ),
                    )
                    atom_coords_denoised[sample_ids_chunk] = atom_coords_denoised_chunk

                if steering_args["fk_steering"] and (
                    (
                        step_idx % steering_args["fk_resampling_interval"] == 0
                        and noise_var > 0
                    )
                    or step_idx == num_sampling_steps - 1
                ):
                    # Energy of the denoised prediction under the potentials.
                    energy = torch.zeros(multiplicity, device=self.device)
                    for potential in potentials:
                        parameters = potential.compute_parameters(steering_t)
                        if parameters["resampling_weight"] > 0:
                            component_energy = potential.compute(
                                atom_coords_denoised,
                                network_condition_kwargs["feats"],
                                parameters,
                            )
                            energy += parameters["resampling_weight"] * component_energy
                    energy_traj = torch.cat((energy_traj, energy.unsqueeze(1)), dim=1)

                    # FK weight: energy decrease since the last checkpoint
                    # (raw negative energy at the first step).
                    if step_idx == 0:
                        log_G = -1 * energy
                    else:
                        log_G = energy_traj[:, -2] - energy_traj[:, -1]

                    # Likelihood correction for guidance applied at the
                    # previous step.
                    if (
                        steering_args["physical_guidance_update"]
                        or steering_args["contact_guidance_update"]
                    ) and noise_var > 0:
                        ll_difference = (
                            eps**2 - (eps + scaled_guidance_update) ** 2
                        ).sum(dim=(-1, -2)) / (2 * noise_var)
                    else:
                        ll_difference = torch.zeros_like(energy)

                    # Resampling weights, normalized per particle group.
                    resample_weights = F.softmax(
                        (ll_difference + steering_args["fk_lambda"] * log_G).reshape(
                            -1, steering_args["num_particles"]
                        ),
                        dim=1,
                    )

                # Gradient-descent guidance on the denoised coordinates.
                if (
                    steering_args["physical_guidance_update"]
                    or steering_args["contact_guidance_update"]
                ) and step_idx < num_sampling_steps - 1:
                    guidance_update = torch.zeros_like(atom_coords_denoised)
                    for guidance_step in range(steering_args["num_gd_steps"]):
                        energy_gradient = torch.zeros_like(atom_coords_denoised)
                        for potential in potentials:
                            parameters = potential.compute_parameters(steering_t)
                            if (
                                parameters["guidance_weight"] > 0
                                and (guidance_step) % parameters["guidance_interval"]
                                == 0
                            ):
                                energy_gradient += parameters[
                                    "guidance_weight"
                                ] * potential.compute_gradient(
                                    atom_coords_denoised + guidance_update,
                                    network_condition_kwargs["feats"],
                                    parameters,
                                )
                        guidance_update -= energy_gradient
                    atom_coords_denoised += guidance_update
                    # NOTE(review): uses self.step_scale rather than the
                    # (possibly randomized) local step_scale used by the
                    # integrator below — confirm the asymmetry is intended.
                    scaled_guidance_update = (
                        guidance_update
                        * -1
                        * self.step_scale
                        * (sigma_t - t_hat)
                        / t_hat
                    )

                # FK resampling: draw particle indices within each group
                # (a single survivor per group at the final step).
                if steering_args["fk_steering"] and (
                    (
                        step_idx % steering_args["fk_resampling_interval"] == 0
                        and noise_var > 0
                    )
                    or step_idx == num_sampling_steps - 1
                ):
                    resample_indices = (
                        torch.multinomial(
                            resample_weights,
                            resample_weights.shape[1]
                            if step_idx < num_sampling_steps - 1
                            else 1,
                            replacement=True,
                        )
                        + resample_weights.shape[1]
                        * torch.arange(
                            resample_weights.shape[0], device=resample_weights.device
                        ).unsqueeze(-1)
                    ).flatten()

                    atom_coords = atom_coords[resample_indices]
                    atom_coords_noisy = atom_coords_noisy[resample_indices]
                    atom_mask = atom_mask[resample_indices]
                    if atom_coords_denoised is not None:
                        atom_coords_denoised = atom_coords_denoised[resample_indices]
                    energy_traj = energy_traj[resample_indices]
                    if (
                        steering_args["physical_guidance_update"]
                        or steering_args["contact_guidance_update"]
                    ):
                        scaled_guidance_update = scaled_guidance_update[
                            resample_indices
                        ]
                    if token_repr is not None:
                        token_repr = token_repr[resample_indices]

            if self.alignment_reverse_diff:
                # Rigid-align the noisy state onto the denoised prediction
                # before the integrator step.
                with torch.autocast("cuda", enabled=False):
                    atom_coords_noisy = weighted_rigid_align(
                        atom_coords_noisy.float(),
                        atom_coords_denoised.float(),
                        atom_mask.float(),
                        atom_mask.float(),
                    )

                atom_coords_noisy = atom_coords_noisy.to(atom_coords_denoised)

            # First-order (Euler) step from t_hat down to sigma_t.
            denoised_over_sigma = (atom_coords_noisy - atom_coords_denoised) / t_hat
            atom_coords_next = (
                atom_coords_noisy + step_scale * (sigma_t - t_hat) * denoised_over_sigma
            )

            atom_coords = atom_coords_next

        return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr)

    def loss_weight(self, sigma):
        """EDM loss weight lambda(sigma)."""
        return (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2)

    def noise_distribution(self, batch_size):
        """Sample training sigmas from a log-normal, scaled by sigma_data."""
        return (
            self.sigma_data
            * (
                self.P_mean
                + self.P_std * torch.randn((batch_size,), device=self.device)
            ).exp()
        )

    def forward(
        self,
        s_inputs,
        s_trunk,
        feats,
        diffusion_conditioning,
        multiplicity=1,
    ):
        """One training step: noise ground-truth coordinates and denoise.

        Returns a dict with the denoised coordinates, the sigmas used, and
        the (augmented) ground-truth coordinates for the loss.
        """
        # feats["coords"] already carries the multiplicity dimension.
        batch_size = feats["coords"].shape[0] // multiplicity

        if self.synchronize_sigmas:
            # Same sigma across the multiplicity copies of each sample.
            sigmas = self.noise_distribution(batch_size).repeat_interleave(
                multiplicity, 0
            )
        else:
            sigmas = self.noise_distribution(batch_size * multiplicity)
        padded_sigmas = rearrange(sigmas, "b -> b 1 1")

        atom_coords = feats["coords"]

        atom_mask = feats["atom_pad_mask"]
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        # Center and (optionally) randomly rotate/translate the truth.
        atom_coords = center_random_augmentation(
            atom_coords, atom_mask, augmentation=self.coordinate_augmentation
        )

        noise = torch.randn_like(atom_coords)
        noised_atom_coords = atom_coords + padded_sigmas * noise

        denoised_atom_coords = self.preconditioned_network_forward(
            noised_atom_coords,
            sigmas,
            network_condition_kwargs={
                "s_inputs": s_inputs,
                "s_trunk": s_trunk,
                "feats": feats,
                "multiplicity": multiplicity,
                "diffusion_conditioning": diffusion_conditioning,
            },
        )

        return {
            "denoised_atom_coords": denoised_atom_coords,
            "sigmas": sigmas,
            "aligned_true_atom_coords": atom_coords,
        }

    def compute_loss(
        self,
        feats,
        out_dict,
        add_smooth_lddt_loss=True,
        nucleotide_loss_weight=5.0,
        ligand_loss_weight=10.0,
        multiplicity=1,
        filter_by_plddt=0.0,
    ):
        """Weighted-MSE (+ optional smooth lDDT) diffusion loss.

        Args:
            feats: Feature dict (coords, masks, atom/token maps, mol types).
            out_dict: Output of :meth:`forward`.
            add_smooth_lddt_loss: Add the smooth lDDT term.
            nucleotide_loss_weight: Extra weight for DNA/RNA atoms.
            ligand_loss_weight: Extra weight for ligand atoms.
            multiplicity: Samples per input.
            filter_by_plddt: If > 0, drop atoms with plddt below this value.

        Returns:
            Dict with total ``loss`` and a ``loss_breakdown``.
        """
        # Loss math in fp32 for numerical stability.
        with torch.autocast("cuda", enabled=False):
            denoised_atom_coords = out_dict["denoised_atom_coords"].float()
            sigmas = out_dict["sigmas"].float()

            resolved_atom_mask_uni = feats["atom_resolved_mask"].float()

            if filter_by_plddt > 0:
                plddt_mask = feats["plddt"] > filter_by_plddt
                resolved_atom_mask_uni = resolved_atom_mask_uni * plddt_mask.float()

            resolved_atom_mask = resolved_atom_mask_uni.repeat_interleave(
                multiplicity, 0
            )

            # Per-atom alignment/loss weights, upweighting nucleic acids
            # and ligands.
            align_weights = denoised_atom_coords.new_ones(denoised_atom_coords.shape[:2])
            atom_type = (
                torch.bmm(
                    feats["atom_to_token"].float(),
                    feats["mol_type"].unsqueeze(-1).float(),
                )
                .squeeze(-1)
                .long()
            )
            atom_type_mult = atom_type.repeat_interleave(multiplicity, 0)

            align_weights = (
                align_weights
                * (
                    1
                    + nucleotide_loss_weight
                    * (
                        torch.eq(atom_type_mult, const.chain_type_ids["DNA"]).float()
                        + torch.eq(atom_type_mult, const.chain_type_ids["RNA"]).float()
                    )
                    + ligand_loss_weight
                    * torch.eq(
                        atom_type_mult, const.chain_type_ids["NONPOLYMER"]
                    ).float()
                ).float()
            )

            # Rigid-align the ground truth onto the prediction (gradients
            # blocked through the alignment itself).
            atom_coords = out_dict["aligned_true_atom_coords"].float()
            atom_coords_aligned_ground_truth = weighted_rigid_align(
                atom_coords.detach(),
                denoised_atom_coords.detach(),
                align_weights.detach(),
                mask=feats["atom_resolved_mask"]
                .float()
                .repeat_interleave(multiplicity, 0)
                .detach(),
            )

            atom_coords_aligned_ground_truth = atom_coords_aligned_ground_truth.to(
                denoised_atom_coords
            )

            # Weighted MSE; the factor 3 normalizes per coordinate axis.
            mse_loss = (
                (denoised_atom_coords - atom_coords_aligned_ground_truth) ** 2
            ).sum(dim=-1)
            mse_loss = torch.sum(
                mse_loss * align_weights * resolved_atom_mask, dim=-1
            ) / (torch.sum(3 * align_weights * resolved_atom_mask, dim=-1) + 1e-5)

            # EDM per-sample loss weighting.
            loss_weights = self.loss_weight(sigmas)
            mse_loss = (mse_loss * loss_weights).mean()

        total_loss = mse_loss

        # Optional smooth lDDT term (computed outside the autocast guard,
        # matching the original control flow).
        lddt_loss = self.zero
        if add_smooth_lddt_loss:
            lddt_loss = smooth_lddt_loss(
                denoised_atom_coords,
                feats["coords"],
                torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
                + torch.eq(atom_type, const.chain_type_ids["RNA"]).float(),
                coords_mask=resolved_atom_mask_uni,
                multiplicity=multiplicity,
            )

            total_loss = total_loss + lddt_loss

        loss_breakdown = {
            "mse_loss": mse_loss,
            "smooth_lddt_loss": lddt_loss,
        }

        return {"loss": total_loss, "loss_breakdown": loss_breakdown}
| |
|