# RosettaFold-3 — before_denoise.py
# (uploaded via huggingface_hub; commit a376829, verified)
# Copyright 2025 Dhruv Nair. All rights reserved.
# Licensed under the Apache License, Version 2.0
"""
Pre-denoising steps for RF3: input processing, timestep setup, recycling trunk, latent preparation.
"""
from typing import List
import torch
from diffusers.utils import logging
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__)
class RF3InputStep(ModularPipelineBlocks):
    """Parse sequence input and prepare feature dict for RF3."""

    model_name = "rf3"

    # Canonical one-letter amino-acid alphabet; position in this string is the
    # restype class index. Residues not in the alphabet map to index 20 (unknown).
    _AA_ORDER = "ARNDCQEGHILKMFPSTWYV"

    @property
    def description(self) -> str:
        return "Parse sequence and optional MSA/template inputs for structure prediction."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("sequence", required=True, type_hint=str, description="Amino acid sequence (one-letter codes)"),
            InputParam("f", type_hint=dict, description="Pre-built feature dict (overrides sequence)"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("f", type_hint=dict, description="Feature dictionary for RF3"),
            OutputParam("L", type_hint=int, description="Sequence length (num atoms)"),
            OutputParam("I", type_hint=int, description="Num tokens"),
        ]

    @torch.no_grad()
    def __call__(self, components, state):
        """Build (or validate) the RF3 feature dict and record atom/token counts.

        Reads ``f`` / ``sequence`` from the pipeline state. When ``f`` is absent,
        a minimal CA-only feature dict is built from ``sequence``; otherwise the
        provided dict is used as-is and the length is derived from it.

        Raises:
            ValueError: if neither ``sequence`` nor ``f`` is provided, or if a
                provided ``f`` lacks both ``ref_element`` and ``restype``.
        """
        block_state = self.get_block_state(state)
        f = block_state.f
        sequence = block_state.sequence
        if f is None:
            if sequence is None:
                # Fail early with a clear message instead of `len(None)` raising
                # an opaque TypeError below.
                raise ValueError("Either `sequence` or a pre-built feature dict `f` must be provided.")
            # Build minimal feature dict from sequence
            L = len(sequence)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # One-hot restype over 32 classes; index 20 is the "unknown" bucket.
            restype = torch.zeros(L, 32, device=device)
            for i, aa in enumerate(sequence):
                idx = self._AA_ORDER.find(aa)
                restype[i, idx if idx >= 0 else 20] = 1.0
            f = {
                "restype": restype,
                "atom_to_token_map": torch.arange(L, device=device),
                "is_ca": torch.ones(L, dtype=torch.bool, device=device),
                "ref_pos": torch.zeros(L, 3, device=device),
                "ref_charge": torch.zeros(L, device=device),
                "ref_mask": torch.ones(L, device=device),
                "ref_element": torch.zeros(L, 128, device=device),
                "ref_atom_name_chars": torch.zeros(L, 4, 64, device=device),
            }
        else:
            ref = f.get("ref_element", f.get("restype"))
            if ref is None:
                # Previously this path crashed with AttributeError on NoneType
                # when both keys were missing; raise a descriptive error instead.
                raise ValueError("Feature dict `f` must contain `ref_element` or `restype`.")
            L = ref.shape[0]
        block_state.f = f
        block_state.L = L
        block_state.I = L  # token count = atom count for CA-only
        self.set_block_state(state, block_state)
        return components, state
class RF3SetTimestepsStep(ModularPipelineBlocks):
    """Set up EDM noise schedule for RF3."""

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Construct EDM noise schedule for RF3 diffusion sampling."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("scheduler", description="RF3 EDM scheduler")]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=None, type_hint=int),
            InputParam("L", required=True, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("noise_schedule", type_hint=torch.Tensor),
            OutputParam("num_inference_steps", type_hint=int),
        ]

    @torch.no_grad()
    def __call__(self, components, state):
        """Populate ``noise_schedule`` and the effective ``num_inference_steps``.

        Prefers the scheduler component's own schedule; otherwise falls back to
        a linear ramp between the hard-coded sigma bounds.
        """
        block_state = self.get_block_state(state)
        if hasattr(components, "scheduler") and components.scheduler is not None:
            noise_schedule = components.scheduler.get_noise_schedule()
        else:
            # Bug fix: `num_inference_steps` was declared as an input but
            # silently ignored — the fallback always produced 200 steps.
            # Honor the request here; the previous default of 200 is kept.
            num_steps = block_state.num_inference_steps or 200
            # Linear ramp from sigma_max (160*16) down to sigma_min (4e-4*16).
            noise_schedule = torch.linspace(160.0 * 16.0, 4e-4 * 16.0, num_steps)
        block_state.noise_schedule = noise_schedule
        # Report the actual step count, which may differ from the request when
        # the scheduler supplies its own schedule.
        block_state.num_inference_steps = len(noise_schedule)
        self.set_block_state(state, block_state)
        return components, state
class RF3RecyclingStep(ModularPipelineBlocks):
    """Run the recycling trunk (pairformer + MSA + templates)."""

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Run RF3 recycling trunk to produce single/pair representations."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("transformer", description="RF3 transformer model")]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("f", required=True, type_hint=dict),
            InputParam("n_recycles", default=None, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("single", type_hint=torch.Tensor, description="Single representation [I, c_s]"),
            OutputParam("pair", type_hint=torch.Tensor, description="Pair representation [I, I, c_z]"),
            OutputParam("s_inputs", type_hint=torch.Tensor, description="Input embeddings [I, c_s_inputs]"),
            OutputParam("distogram", type_hint=torch.Tensor, description="Distogram prediction [I, I, bins]"),
        ]

    @torch.no_grad()
    def __call__(self, components, state):
        """Run the trunk when a transformer is loaded; otherwise emit None placeholders."""
        block_state = self.get_block_state(state)
        transformer = getattr(components, "transformer", None)
        if transformer is not None:
            trunk_out = transformer(f=block_state.f, n_recycles=block_state.n_recycles)
            block_state.single = trunk_out.single
            block_state.pair = trunk_out.pair
            block_state.distogram = trunk_out.distogram
        else:
            # No model available: downstream blocks see empty representations.
            block_state.single = block_state.pair = block_state.distogram = None
        # s_inputs is populated inside the transformer forward in both cases,
        # so it is always surfaced as None here.
        block_state.s_inputs = None
        self.set_block_state(state, block_state)
        return components, state
class RF3PrepareLatentsStep(ModularPipelineBlocks):
    """Prepare initial noised coordinates for diffusion sampling."""

    model_name = "rf3"

    @property
    def description(self) -> str:
        return "Sample initial Gaussian noise scaled by the first noise schedule value."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("generator", type_hint=torch.Generator),
            InputParam("diffusion_batch_size", default=5, type_hint=int),
            InputParam("L", required=True, type_hint=int),
            InputParam("noise_schedule", required=True, type_hint=torch.Tensor),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("xyz", type_hint=torch.Tensor, description="Initial noised coords [D, L, 3]"),
        ]

    @torch.no_grad()
    def __call__(self, components, state):
        """Sample x_T ~ N(0, sigma_max^2 * I) as the starting point for sampling.

        Bug fix: ``torch.randn`` raises a RuntimeError when the supplied
        generator lives on a different device than the requested tensor
        (typical case: a CPU ``torch.Generator`` while CUDA is available).
        Sample on the generator's device, then move to the compute device —
        this also keeps results reproducible across CPU/GPU hosts.
        """
        block_state = self.get_block_state(state)
        L = block_state.L
        noise_schedule = block_state.noise_schedule
        # Guard against an explicit 0/None batch size overriding the default.
        D = block_state.diffusion_batch_size or 5
        generator = block_state.generator
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # sigma_max: the first (largest) value of the noise schedule.
        c0 = noise_schedule[0]
        # Draw noise on the generator's own device to avoid a device-mismatch
        # RuntimeError, then move the result to the compute device.
        sample_device = generator.device if generator is not None else device
        noise = torch.randn((D, L, 3), device=sample_device, generator=generator)
        xyz = c0 * noise.to(device)
        block_state.xyz = xyz
        self.set_block_state(state, block_state)
        return components, state