| |
| |
|
|
| """ |
| Pre-denoising steps for RF3: input processing, timestep setup, recycling trunk, latent preparation. |
| """ |
|
|
| from typing import List |
|
|
| import torch |
|
|
| from diffusers.utils import logging |
| from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState |
| from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class RF3InputStep(ModularPipelineBlocks):
    """Parse sequence input and prepare feature dict for RF3."""

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Parse sequence and optional MSA/template inputs for structure prediction."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("sequence", required=True, type_hint=str, description="Amino acid sequence (one-letter codes)"),
            InputParam("f", type_hint=dict, description="Pre-built feature dict (overrides sequence)"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("f", type_hint=dict, description="Feature dictionary for RF3"),
            OutputParam("L", type_hint=int, description="Sequence length (num atoms)"),
            OutputParam("I", type_hint=int, description="Num tokens"),
        ]

    @torch.no_grad()
    def __call__(self, components, state):
        block_state = self.get_block_state(state)

        features = block_state.f
        sequence = block_state.sequence

        if features is not None:
            # Caller supplied a pre-built feature dict; recover the length from
            # whichever per-atom tensor is present (first dim is atom count).
            seq_len = features.get("ref_element", features.get("restype")).shape[0]
        else:
            seq_len = len(sequence)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            # One-hot residue types over 32 classes; unrecognized letters fall
            # into class 20 (same index `str.find` misses mapped to before).
            aa_to_index = {aa: i for i, aa in enumerate("ARNDCQEGHILKMFPSTWYV")}
            restype = torch.zeros(seq_len, 32, device=device)
            for pos, aa in enumerate(sequence):
                restype[pos, aa_to_index.get(aa, 20)] = 1.0

            features = {
                "restype": restype,
                "atom_to_token_map": torch.arange(seq_len, device=device),
                "is_ca": torch.ones(seq_len, dtype=torch.bool, device=device),
                "ref_pos": torch.zeros(seq_len, 3, device=device),
                "ref_charge": torch.zeros(seq_len, device=device),
                "ref_mask": torch.ones(seq_len, device=device),
                "ref_element": torch.zeros(seq_len, 128, device=device),
                "ref_atom_name_chars": torch.zeros(seq_len, 4, 64, device=device),
            }

        block_state.f = features
        block_state.L = seq_len
        # One token per atom here (CA-only representation), so I == L.
        block_state.I = seq_len

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
class RF3SetTimestepsStep(ModularPipelineBlocks):
    """Set up EDM noise schedule for RF3."""

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Construct EDM noise schedule for RF3 diffusion sampling."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("scheduler", description="RF3 EDM scheduler")]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=None, type_hint=int),
            # NOTE(review): `L` is declared required but never read in
            # __call__ — presumably kept for pipeline wiring; confirm.
            InputParam("L", required=True, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("noise_schedule", type_hint=torch.Tensor),
            OutputParam("num_inference_steps", type_hint=int),
        ]

    @torch.no_grad()
    def __call__(self, components, state):
        """Build the sigma schedule, preferring the loaded scheduler component.

        Fix: the declared `num_inference_steps` input was previously ignored by
        the fallback path, which always produced 200 steps. It now sizes the
        fallback linspace; `None` keeps the original 200-step default.
        """
        block_state = self.get_block_state(state)

        if hasattr(components, "scheduler") and components.scheduler is not None:
            # Scheduler owns the schedule; its length defines the step count.
            noise_schedule = components.scheduler.get_noise_schedule()
        else:
            steps = block_state.num_inference_steps
            if steps is None:
                steps = 200
            # EDM-style sigmas from 160*16 down to 4e-4*16 (linear in sigma).
            noise_schedule = torch.linspace(160.0 * 16.0, 4e-4 * 16.0, steps)

        block_state.noise_schedule = noise_schedule
        # Output step count always reflects the actual schedule length.
        block_state.num_inference_steps = len(noise_schedule)

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
class RF3RecyclingStep(ModularPipelineBlocks):
    """Run the recycling trunk (pairformer + MSA + templates)."""

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Run RF3 recycling trunk to produce single/pair representations."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("transformer", description="RF3 transformer model")]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("f", required=True, type_hint=dict),
            InputParam("n_recycles", default=None, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("single", type_hint=torch.Tensor, description="Single representation [I, c_s]"),
            OutputParam("pair", type_hint=torch.Tensor, description="Pair representation [I, I, c_z]"),
            OutputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
            OutputParam("distogram", type_hint=torch.Tensor, description="Distogram prediction [I, I, bins]"),
        ]

    @torch.no_grad()
    def __call__(self, components, state):
        block_state = self.get_block_state(state)

        trunk = getattr(components, "transformer", None)
        if trunk is None:
            # No trunk loaded: publish empty representations so downstream
            # blocks can detect the missing component.
            single = pair = distogram = None
        else:
            trunk_out = trunk(f=block_state.f, n_recycles=block_state.n_recycles)
            single = trunk_out.single
            pair = trunk_out.pair
            distogram = trunk_out.distogram

        block_state.single = single
        block_state.pair = pair
        block_state.distogram = distogram
        # s_inputs is always None here — presumably filled by another block;
        # NOTE(review): confirm the trunk output has no s_inputs to forward.
        block_state.s_inputs = None

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
class RF3PrepareLatentsStep(ModularPipelineBlocks):
    """Prepare initial noised coordinates for diffusion sampling."""

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Sample initial Gaussian noise scaled by the first noise schedule value."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("generator", type_hint=torch.Generator),
            InputParam("diffusion_batch_size", default=5, type_hint=int),
            InputParam("L", required=True, type_hint=int),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
        ]

    @torch.no_grad()
    def __call__(self, components, state):
        """Sample x_T ~ N(0, sigma_0^2) of shape [D, L, 3].

        Fix: `torch.randn` raises when `generator` lives on a different device
        than the target (the common case: CPU generator, CUDA available). We
        now sample on the generator's own device and move the result, which is
        identical to the old behavior when devices match or no generator is
        passed, and additionally keeps seeded runs reproducible across
        CPU/GPU hosts.
        """
        block_state = self.get_block_state(state)

        L = block_state.L
        noise_schedule = block_state.noise_schedule
        # Falsy values (None, 0) fall back to 5 diffusion samples.
        D = block_state.diffusion_batch_size or 5
        generator = block_state.generator

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if generator is not None and generator.device != device:
            # Draw on the generator's device, then move to the compute device.
            noise = torch.randn((D, L, 3), device=generator.device, generator=generator).to(device)
        else:
            noise = torch.randn((D, L, 3), device=device, generator=generator)

        # Scale by the first (largest) sigma of the schedule.
        block_state.xyz = noise_schedule[0] * noise

        self.set_block_state(state, block_state)
        return components, state
|
|