import os
from typing import Any, Callable, Sequence
from warnings import warn

import attr
import torch
from tqdm import tqdm

from src.data.esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    ESMProteinError,
    ESMProteinTensor,
    ForwardAndSampleOutput,
    ForwardTrackData,
    GenerationConfig,
    LogitsConfig,
    LogitsOutput,
    SamplingConfig,
    SamplingTrackConfig,
)
from src.data.esm.tokenization import (
    EsmTokenizerBase,
    TokenizerCollectionProtocol,
)
from src.data.esm.tokenization.function_tokenizer import (
    InterProQuantizedTokenizer,
)
from src.data.esm.utils.constants import esm3 as C
from src.data.esm.utils.misc import stack_variable_length_tensors
from src.data.esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
from src.data.esm.utils.sampling import (
    _BatchedESMProteinTensor,
    get_sampling_mask,
    sample_function_logits,
    sample_logits,
    sample_residue_annotation_logits,
    sample_sasa_logits,
)


def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int):
    """Trim tensors along the sequence dimension.

    This util assumes that the input tensor dataclass has a batch dimension.
    """
    assert attr.has(o.__class__)

    sliced = {}
    for k, v in attr.asdict(o, recurse=False).items():
        if v is None:
            sliced[k] = None
        elif isinstance(v, torch.Tensor):
            # Trim along the sequence dimension (dim 1); dim 0 is the batch.
            sliced[k] = v[:, :sequence_len]
        elif isinstance(v, tuple) and all(isinstance(t, torch.Tensor) for t in v):
            # Tuples of tensors are trimmed element-wise (returned as a list).
            sliced[k] = [t[:, :sequence_len] for t in v]
        elif attr.has(v.__class__):
            # Recurse into nested attrs dataclasses.
            sliced[k] = _trim_sequence_tensor_dataclass(v, sequence_len)
        else:
            # Leave non-tensor fields untouched.
            sliced[k] = v

    return attr.evolve(o, **sliced)


def _slice_tensor_dataclass(o: Any, i: int, keep_dim: bool = False) -> Any:
    """Take a slice out of any attr-defined tensor object along the batch dimension.

    Args:
        o: input tensor object to be sliced.
        i: index of the row to be sliced.
        keep_dim: whether to keep the batch dim after slicing.
            For example, given a tensor of shape (5, 8), if keep_dim is True,
            return a sliced tensor of shape (1, 8). Return a tensor of shape
            (8,) instead if keep_dim is False. The default is False.
    """
    assert attr.has(o.__class__)

    sliced = {}
    for k, v in attr.asdict(o, recurse=False).items():
        if v is None:
            sliced[k] = None
        elif isinstance(v, torch.Tensor):
            # Select row i along the batch dimension.
            row = v.select(0, i)
            if keep_dim:
                row = row.unsqueeze(0)
            sliced[k] = row
        elif attr.has(v.__class__):
            # Recurse into nested attrs dataclasses.
            sliced[k] = _slice_tensor_dataclass(v, i, keep_dim)
        else:
            # Leave non-tensor fields untouched.
            sliced[k] = v

    return attr.evolve(o, **sliced)


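# Example usage of _slice_tensor_dataclass (a sketch, not part of the module):
# slicing an attrs dataclass of batched tensors, e.g. a LogitsOutput whose
# tensor fields have shape (batch, length, ...):
#
#     per_prompt_out = _slice_tensor_dataclass(forward_out, i=0, keep_dim=True)
#     # every tensor field now has shape (1, length, ...) instead of (batch, length, ...)

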
def iterative_sampling_raw(
    client: ESM3InferenceClient,
    proteins: list[ESMProtein],
    configs: list[GenerationConfig],
) -> list[ESMProtein | ESMProteinError]:
    # Encode each ESMProtein into its tokenized representation.
    input_tokens = [client.encode(protein) for protein in proteins]

    output_tokens_list = client.batch_generate(input_tokens, configs)

    raw_proteins: list[ESMProtein | ESMProteinError] = []
    for output_tokens in output_tokens_list:
        if isinstance(output_tokens, ESMProteinTensor):
            raw_proteins.append(client.decode(output_tokens))
        elif isinstance(output_tokens, ESMProteinError):
            raw_proteins.append(output_tokens)
        else:
            raise ValueError(f"Unknown output type {type(output_tokens)}")

    for input_protein, raw_protein, config in zip(proteins, raw_proteins, configs):
        if isinstance(raw_protein, ESMProteinError):
            # Generation failed for this prompt; keep the error object as-is.
            continue
        if config.track not in ["function", "residue_annotations"]:
            # Function annotations were not the generated track, so carry the
            # input protein's annotations through to the output unchanged.
            raw_protein.function_annotations = input_protein.function_annotations

    return raw_proteins


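# Example usage of iterative_sampling_raw (a sketch; the prompt string and the
# "_" mask character below are assumptions, not taken from this module):
#
#     protein = ESMProtein(sequence="MA______KLV")  # "_" marks masked positions
#     config = GenerationConfig(track="sequence", num_steps=8, temperature=0.7)
#     [result] = iterative_sampling_raw(client, [protein], [config])
#     # result is an ESMProtein on success, or an ESMProteinError on failure.

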
def _make_masked_inputs(
    track: str, sequence_length: int, tokenizers: TokenizerCollectionProtocol
):
    get_tokenizer: Callable[[str], EsmTokenizerBase] = lambda s: getattr(tokenizers, s)
    has_tokenizer: Callable[[str], bool] = lambda s: hasattr(tokenizers, s)

    if track == "coordinates":
        dims = (sequence_length, 3, 3)
    elif track == "confidence":
        dims = (sequence_length,)
    elif track == "attention_mask":
        dims = (sequence_length,)
    elif track == "function":
        dims = (sequence_length, tokenizers.function.depth)
    elif track == "residue_annotations":
        dims = (sequence_length, C.MAX_RESIDUE_ANNOTATIONS)
    else:
        dims = (sequence_length,)

    if track == "coordinates":
        masked_tokens = torch.full(dims, torch.inf, dtype=torch.float)
    elif track == "confidence":
        # Confidence has no mask token; fill with 0.0.
        masked_tokens = torch.full(dims, 0.0)
    elif track == "attention_mask":
        masked_tokens = torch.full(dims, 1, dtype=torch.bool)
    elif has_tokenizer(track):
        masked_tokens = torch.full(
            dims, get_tokenizer(track).mask_token_id, dtype=torch.long
        )
        # BOS/EOS go at the ends; only interior positions stay masked.
        masked_tokens[0] = get_tokenizer(track).bos_token_id
        masked_tokens[-1] = get_tokenizer(track).eos_token_id
    else:
        # No tokenizer registered for this track; nothing to mock.
        return None

    return masked_tokens


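# Example behavior of _make_masked_inputs (a sketch): for the "sequence" track
# and sequence_length=5 the mocked input is a 1-D tensor
# [bos, mask, mask, mask, eos] built from the sequence tokenizer's special
# token ids; for "coordinates" it is a (5, 3, 3) tensor of +inf, the fill
# value this module uses for unknown coordinates.
#
#     masked = _make_masked_inputs("sequence", 5, tokenizers)
#     assert masked.shape == (5,)

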
def _stack_protein_tensors(
    input_tokens: list[ESMProteinTensor],
    sequence_lengths: list[int],
    tokenizers: TokenizerCollectionProtocol,
    device: str | torch.device,
) -> _BatchedESMProteinTensor:
    o = _BatchedESMProteinTensor()

    def _maybe_mock_input(fn, t, l):
        if t is not None:
            return t

        # The track is missing for this prompt; mock a fully masked input.
        t = _make_masked_inputs(fn, l, tokenizers)
        if t is not None:
            t = t.to(device)

        return t

    def _stack_field(fn: str):
        tensors = [getattr(tokens, fn) for tokens in input_tokens]

        # Fill in missing tracks with fully masked tensors of the right length.
        tensors = [
            _maybe_mock_input(fn, t, l) for t, l in zip(tensors, sequence_lengths)
        ]

        # If the track could not be mocked for any prompt (no tokenizer for
        # it), leave the field unset for the whole batch.
        if all([t is None for t in tensors]):
            setattr(o, fn, None)
            return

        if fn == "coordinates":
            mask_token_id = torch.inf
        else:
            mask_token_id = getattr(tokenizers, fn).pad_token_id

        setattr(
            o,
            fn,
            stack_variable_length_tensors(
                sequences=tensors,
                constant_value=mask_token_id,
            ),
        )

    for f in attr.fields(ESMProteinTensor):
        # Skip non-tensor metadata fields.
        if f.name == "potential_sequence_of_concern":
            continue
        _stack_field(f.name)

    return o


def _get_masked_positions(
    track: str, tokens: torch.Tensor, mask_token_id: int
) -> torch.Tensor:
    if track == "function":
        # Function tokens carry an extra depth dimension; a position counts as
        # masked only if every depth slot holds the mask token.
        mask = torch.all(tokens == mask_token_id, dim=-1).to(tokens.device)
    else:
        mask = tokens == mask_token_id

    # Never treat the BOS/EOS positions as masked.
    mask[..., 0] = False
    mask[..., -1] = False

    return mask


def _get_iterative_sampling_mask_for_prompt_and_step(
    cur_sampled: _BatchedESMProteinTensor,
    sequence_lengths: torch.Tensor,
    total_to_sample: torch.Tensor,
    step: int,
    entropy: ForwardTrackData,
    config: GenerationConfig,
    tokenizers: TokenizerCollectionProtocol,
) -> torch.Tensor:
    """Get the sampling mask for this step based on the forward output and config.

    Returns:
        Boolean mask of the positions to sample at this step.
    """
    track_to_sample = config.track
    tokens = getattr(cur_sampled, track_to_sample)
    device = tokens.device

    B, L = tokens.shape[0], tokens.shape[1]

    # This helper operates on a single prompt at a time.
    assert B == 1

    sampling_mask = torch.ones((B, L), dtype=torch.bool, device=device)
    sampling_mask[:, 0] = False
    # Exclude BOS, EOS and any padding beyond the prompt's true length.
    sampling_mask &= (
        torch.arange(L).repeat(B, 1) < (sequence_lengths - 1).unsqueeze(-1)
    ).to(device)

    is_mask = _get_masked_positions(
        track_to_sample, tokens, getattr(tokenizers, track_to_sample).mask_token_id
    )
    if not is_mask.any().item():
        raise ValueError(f"Cannot sample {config.track} when input has no masks.")
    sampling_mask = sampling_mask & is_mask

    # Decide how many positions to unmask at this step according to the
    # configured noise schedule.
    decoding_schedule = NOISE_SCHEDULE_REGISTRY[config.schedule]

    still_masked = torch.sum(sampling_mask).int()
    perc_masked_after_this_step = decoding_schedule(
        torch.tensor((step + 1) / config.num_steps)
    )
    num_tokens_masked_after_this_step = (
        # Add a small constant so the int() truncation below does not round an
        # exact integer value down due to floating point error.
        perc_masked_after_this_step * total_to_sample + 0.1
    ).int()
    num_to_sample = still_masked - num_tokens_masked_after_this_step

    if config.strategy == "entropy":
        track_entropy: torch.Tensor = getattr(entropy, track_to_sample).to(device)

        if track_to_sample == "function":
            # Aggregate entropy across the function-token depth dimension.
            track_entropy = track_entropy.sum(-1)

        # Unmask the positions the model is most confident about (lowest
        # entropy) among the still-masked positions.
        track_entropy = track_entropy.masked_fill(
            ~sampling_mask, torch.finfo(track_entropy.dtype).max
        )
        _, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False)
        is_top_k = torch.zeros((B, L), dtype=torch.bool, device=device).scatter(
            1, indices, True
        )
        where_to_sample = sampling_mask & is_top_k
    elif config.strategy == "random":
        # Pick a random subset of the still-masked positions.
        _, masked_indices = sampling_mask.nonzero(as_tuple=True)
        rnd_indices = masked_indices[torch.randperm(len(masked_indices))][
            :num_to_sample
        ]
        rnd_mask = torch.zeros_like(sampling_mask)
        rnd_mask[:, rnd_indices] = True
        where_to_sample = sampling_mask & rnd_mask

    if track_to_sample == "function":
        # Broadcast the position mask across the function-token depth.
        where_to_sample = where_to_sample.unsqueeze(-1).expand(
            B, L, tokenizers.function.depth
        )

    return where_to_sample


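# Worked example of the schedule arithmetic above (a sketch; it assumes a
# linear schedule with schedule(t) = 1 - t, which may differ from the entries
# actually registered in NOISE_SCHEDULE_REGISTRY): with total_to_sample = 8
# masked positions and num_steps = 4, the fraction still masked after steps
# 1..4 is 0.75, 0.5, 0.25, 0.0, i.e. 6, 4, 2, 0 tokens, so each step unmasks
# 2 positions.

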
def _get_non_special_tokens(
    protein: ESMProteinTensor, tokenizers: TokenizerCollectionProtocol
) -> int:
    if protein.sequence is None:
        # Without a sequence track, fall back to the prompt length minus the
        # BOS and EOS tokens.
        return len(protein) - 2

    mask = torch.ones_like(protein.sequence)
    for special_token in tokenizers.sequence.special_token_ids:
        # Mask tokens still count as positions to sample.
        if special_token == tokenizers.sequence.mask_token_id:
            continue
        mask[protein.sequence == special_token] = 0

    return int(torch.sum(mask).item())


def _get_annealed_temperature(step: int, num_steps: int, initial_temperature: float):
    # Linearly decay the temperature with the step ratio, floor it at 0.001,
    # then square it.
    step_ratio = step / max(1, (num_steps - 1))
    return max(initial_temperature - step_ratio, 0.001) ** 2


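# Worked example of the annealing formula (a sketch): with
# initial_temperature = 1.0 and num_steps = 5, the step ratios are
# 0.0, 0.25, 0.5, 0.75, 1.0, giving temperatures 1.0, 0.5625, 0.25, 0.0625
# and 1e-6 for steps 0..4, so sampling becomes close to greedy by the final
# step.

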
def iterative_sampling_tokens(
    client: ESM3InferenceClient,
    input_tokens: list[ESMProteinTensor],
    configs: list[GenerationConfig],
    tokenizers: TokenizerCollectionProtocol,
) -> Sequence[ESMProteinTensor | ESMProteinError]:
    devices = set([t.device for t in input_tokens])
    if len(devices) > 1:
        raise AttributeError(f"Input tokens on multiple devices {devices}")

    # Copy the inputs so the originals are never mutated.
    sampled_tokens = [attr.evolve(tokens) for tokens in input_tokens]

    # Clear structure tokens when conditioning on coordinates only.
    for tokens, config in zip(sampled_tokens, configs):
        if config.condition_on_coordinates_only and tokens.coordinates is not None:
            tokens.structure = None

    # Compute each prompt's length and how many positions need sampling.
    sequence_lengths = [len(tokens) for tokens in sampled_tokens]
    total_to_sample = []
    for protein, config in zip(sampled_tokens, configs):
        track = config.track

        if getattr(protein, track) is None:
            # The whole track is missing, so every non-special position is sampled.
            num_sampling_steps = _get_non_special_tokens(protein, tokenizers)
        else:
            masked = _get_masked_positions(
                track, getattr(protein, track), getattr(tokenizers, track).mask_token_id
            )
            num_sampling_steps = torch.sum(masked).item()

        total_to_sample.append(num_sampling_steps)

        # There is no point in running more decoding steps than there are
        # positions to sample for this prompt.
        if (num_sampling_steps > 0) and (num_sampling_steps < config.num_steps):
            config.num_steps = int(num_sampling_steps)

    # Run as many steps as the longest prompt needs; prompts that finish
    # earlier simply stop being updated.
    max_num_steps = max([config.num_steps for config in configs])

    # Stack the prompts into a single batched tensor, mocking missing tracks.
    batched_tokens = _stack_protein_tensors(
        sampled_tokens, sequence_lengths, tokenizers, devices.pop()
    )

    # Track prompts that have failed so they can be skipped and reported.
    errors: dict[int, ESMProteinError] = {}

    # The iterative decoding loop.
    disable_tqdm = bool(os.environ.get("DISABLE_ITERATIVE_SAMPLING_TQDM", False))
    for t in tqdm(range(max_num_steps), disable=disable_tqdm):
        forward_out = _batch_forward(client, batched_tokens)

        # Sample each prompt individually, since their configs may differ.
        for i, config in enumerate(configs):
            if i in errors:
                # This prompt has already failed; skip it.
                continue

            if config.track in ["coordinates", "residue_annotations"]:
                errors[i] = ESMProteinError(
                    error_code=500,
                    error_msg=f"Iterative sampling {config.track} is not supported.",
                )
                continue

            if t >= config.num_steps:
                # This prompt is already fully decoded.
                continue

            per_prompt_cur_sampled = _BatchedESMProteinTensor.from_protein_tensor(
                batched_tokens.slice(i)
            )
            per_prompt_forward_out: LogitsOutput = _slice_tensor_dataclass(
                forward_out, i, keep_dim=True
            )
            # Trim the forward output to this prompt's own length, dropping
            # positions that only exist because of batch padding.
            per_prompt_forward_out = _trim_sequence_tensor_dataclass(
                per_prompt_forward_out,
                len(per_prompt_cur_sampled),
            )

            # Select the sampling temperature for this step.
            if config.temperature_annealing:
                temperature = _get_annealed_temperature(
                    t, config.num_steps, config.temperature
                )
            else:
                temperature = config.temperature

            track_sample_config = SamplingTrackConfig()
            track_sample_config.invalid_ids = config.invalid_ids
            track_sample_config.temperature = temperature
            track_sample_config.top_p = config.top_p
            sampling_config = SamplingConfig(**{config.track: track_sample_config})

            # Sample new tokens for every position of the prompt; the mask
            # computed below decides which of them are actually kept.
            per_prompt_forward_and_sample_output = _sample_per_prompt(
                per_prompt_cur_sampled,
                per_prompt_forward_out,
                sampling_config,
                tokenizers,
                decode_sasa_tokens=False,
            )

            per_prompt_new_sampled = per_prompt_forward_and_sample_output.protein_tensor

            # Decide which masked positions get unmasked at this step.
            assert per_prompt_forward_and_sample_output.entropy is not None
            try:
                where_to_sample = _get_iterative_sampling_mask_for_prompt_and_step(
                    per_prompt_cur_sampled,
                    torch.tensor(sequence_lengths[i]),
                    torch.tensor(total_to_sample[i]),
                    t,
                    per_prompt_forward_and_sample_output.entropy,
                    config,
                    tokenizers,
                )
            except ValueError as e:
                errors[i] = ESMProteinError(error_code=500, error_msg=str(e))
                continue

            where_to_sample = where_to_sample.to(input_tokens[0].device)

            old_track_samples = getattr(per_prompt_cur_sampled, config.track)
            new_track_samples = getattr(per_prompt_new_sampled, config.track)

            # Keep the newly sampled tokens only at the selected positions;
            # everything else retains its previous value.
            new_track_samples = torch.where(
                where_to_sample, new_track_samples, old_track_samples
            )

            # Write the updated row back into the batched tensor.
            getattr(batched_tokens, config.track)[i, ...] = new_track_samples[0]

    # Un-batch the results, trimming each prompt back to its own length.
    output_tokens = [
        batched_tokens.slice(i, sequence_len=sequence_lengths[i])
        if i not in errors
        else errors[i]
        for i in range(len(input_tokens))
    ]

    # Carry over tracks that were not sampled from the inputs.
    for inputs, outputs, config in zip(input_tokens, output_tokens, configs):
        if isinstance(outputs, ESMProteinError):
            continue

        # Raw coordinates are never sampled here, so always copy them through.
        setattr(outputs, "coordinates", getattr(inputs, "coordinates"))
        # For every track other than the generated one, prefer the
        # user-provided inputs over the mocked/batched tensors.
        for f in attr.fields(SamplingConfig):
            if "embedding" in f.name or f.name == "return_hidden_states":
                continue
            if f.name != config.track:
                setattr(outputs, f.name, getattr(inputs, f.name))

    return output_tokens


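# Example usage of iterative_sampling_tokens (a sketch; the prompt string, the
# "_" mask character and the tokenizer wiring are assumptions, not taken from
# this module). It is the tensor-level counterpart of iterative_sampling_raw
# and works on already-encoded prompts:
#
#     protein_tensor = client.encode(ESMProtein(sequence="MA______KLV"))
#     config = GenerationConfig(track="sequence", num_steps=8)
#     [result] = iterative_sampling_tokens(
#         client, [protein_tensor], [config], tokenizers
#     )
#     # result is an ESMProteinTensor on success, or an ESMProteinError otherwise.

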
def _batch_forward(client: ESM3InferenceClient, protein: _BatchedESMProteinTensor):
    # Request logits for every track in a single forward pass.
    return client.logits(
        protein,
        LogitsConfig(
            sequence=True,
            structure=True,
            secondary_structure=True,
            sasa=True,
            function=True,
            residue_annotations=True,
            return_embeddings=True,
        ),
    )


def _sample_per_prompt(
    protein: _BatchedESMProteinTensor,
    logits_output: LogitsOutput,
    sampling_config: SamplingConfig,
    tokenizers: TokenizerCollectionProtocol,
    decode_sasa_tokens: bool = True,
    mask_logits_of_invalid_ids: bool = True,
) -> ForwardAndSampleOutput:
    assert logits_output.logits is not None

    def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
        return x.clone() if x is not None else None

    # Sample the standard integer-token tracks first.
    tokens_dir = {}
    track_sampling_metadata_dir: dict[str, dict | None] = {}
    integer_sampling_tracks = ["sequence", "structure", "secondary_structure"]
    if not decode_sasa_tokens:
        integer_sampling_tracks.append("sasa")

    for track in integer_sampling_tracks:
        config = getattr(sampling_config, track)
        if config is None:
            # No config for this track; clone the input tokens unchanged.
            tokens_dir[track] = maybe_clone(getattr(protein, track))
            continue
        tokenizer = getattr(tokenizers, track)
        valid_ids = (
            set(tokenizer.all_token_ids)
            - set(tokenizer.special_token_ids)
            - set(config.invalid_ids)
        )
        sampling_metadata = _sample_track(
            logits=getattr(logits_output.logits, track),
            tokens=getattr(protein, track),
            sampling_track_config=config,
            mask_idx=getattr(tokenizers, track).mask_token_id,
            valid_ids=list(valid_ids),
            mask_logits_of_invalid_ids=mask_logits_of_invalid_ids,
        )
        tokens_dir[track] = sampling_metadata.pop("sampled_tokens")
        track_sampling_metadata_dir[track] = sampling_metadata

    # Handle SASA separately when decoding SASA values is requested.
    if decode_sasa_tokens:
        config = getattr(sampling_config, "sasa")
        track_sampling_metadata_dir["sasa"] = None

        if config is None:
            tokens_dir["sasa"] = maybe_clone(getattr(protein, "sasa"))
        else:
            if config.topk_logprobs > 0:
                warn("For SASA sampling, 'topk_logprobs' is expected to be 0.")

            assert logits_output.logits.sasa is not None
            assert protein.sasa is not None

            valid_ids = (
                set(tokenizers.sasa.all_token_ids)
                - set(tokenizers.sasa.special_token_ids)
                - set(config.invalid_ids)
            )
            sasa_logits = logits_output.logits.sasa
            sasa_value = sample_sasa_logits(
                sasa_logits,
                protein.sasa,
                sampling_track_config=config,
                mask_idx=tokenizers.sasa.mask_token_id,
                valid_ids=list(valid_ids),
                mask_logits_of_invalid_ids=mask_logits_of_invalid_ids,
            )
            tokens_dir["sasa"] = sasa_value

            probs = sasa_logits.softmax(dim=-1)
            # entropy = -sum(p * log p); nan_to_num guards against the
            # 0 * -inf terms produced by masked-out logits.
            entropy = -(torch.nan_to_num(probs * sasa_logits.log_softmax(-1))).sum(-1)

            track_sampling_metadata_dir["sasa"] = {"entropy": entropy}

    # Sample the function and residue annotation tracks.
    config = getattr(sampling_config, "function")
    function_logits = getattr(logits_output.logits, "function")
    if config is None or function_logits is None:
        tokens_dir["function"] = maybe_clone(getattr(protein, "function"))
        tokens_dir["residue_annotations"] = maybe_clone(
            getattr(protein, "residue_annotations")
        )
    else:
        if config.invalid_ids is not None and len(config.invalid_ids) > 0:
            warn("For function sampling, invalid_ids sampling config is not supported.")

        sampling_metadata = _sample_function_track(
            tokenizers.function,
            tokens=getattr(protein, "function"),
            logits=function_logits,
            sampling_track_config=config,
        )
        tokens_dir["function"] = sampling_metadata.pop("sampled_tokens")
        track_sampling_metadata_dir["function"] = sampling_metadata

        sampled_tokens, _ = sample_residue_annotation_logits(
            logits=logits_output.residue_annotation_logits
        )
        tokens_dir["residue_annotations"] = sampled_tokens

    # Assemble the per-track metadata into ForwardAndSampleOutput fields.
    forward_and_sample_output_dir = {}
    forward_and_sample_output_dir["protein_tensor"] = ESMProteinTensor(**tokens_dir)
    for property in [
        "entropy",
        "prob",
        "logprob",
        "top_prob",
        "topk_logprob",
        "topk_tokens",
    ]:
        is_all_none = True
        forward_track_data_dir = {}
        for track in track_sampling_metadata_dir.keys():
            values = track_sampling_metadata_dir[track]
            if values is not None and values.get(property, None) is not None:
                forward_track_data_dir[track] = values.get(property, None)
                is_all_none = False
        if not is_all_none:
            forward_and_sample_output_dir[property] = ForwardTrackData(
                **forward_track_data_dir
            )
        else:
            forward_and_sample_output_dir[property] = None

    per_res_embed = (
        logits_output.embeddings
        if sampling_config.return_per_residue_embeddings
        else None
    )
    mean_embedding = (
        # Average per-residue embeddings over the length dimension.
        logits_output.embeddings.mean(dim=1)
        if sampling_config.return_mean_embedding
        else None
    )

    return ForwardAndSampleOutput(
        per_residue_embedding=per_res_embed,
        mean_embedding=mean_embedding,
        **forward_and_sample_output_dir,
    )


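# Example usage of _sample_per_prompt (a sketch; the field values below are
# illustrative): a SamplingConfig with only the "sequence" entry set samples
# new sequence tokens and clones every other track from the input protein
# tensor unchanged.
#
#     cfg = SamplingConfig(sequence=SamplingTrackConfig(temperature=0.7, top_p=0.9))
#     out = _sample_per_prompt(batched_protein, logits_out, cfg, tokenizers)
#     out.protein_tensor.sequence   # newly sampled tokens
#     out.entropy.sequence          # per-position entropy used by the step scheduler

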
def _sample_track(
    logits: torch.Tensor,
    tokens: torch.Tensor,
    sampling_track_config: SamplingTrackConfig,
    mask_idx: int,
    valid_ids: list[int],
    mask_logits_of_invalid_ids: bool = True,
) -> dict[str, torch.Tensor]:
    """Works with inputs that have a batch dimension."""
    # Sample new tokens at every position, then keep them only where the
    # sampling mask allows.
    temperature = sampling_track_config.temperature
    sampled_tokens = sample_logits(
        logits,
        temperature=temperature,
        valid_ids=valid_ids,
        top_p=sampling_track_config.top_p,
        mask_logits_of_invalid_ids=mask_logits_of_invalid_ids,
    )
    log_probs = logits.log_softmax(-1)
    sampling_mask = get_sampling_mask(tokens, sampling_track_config, mask_idx)
    sampled_tokens = torch.where(sampling_mask, sampled_tokens, tokens)

    return _compute_track_metadata(
        sampled_tokens,
        log_probs,
        sampling_mask,
        top_k=sampling_track_config.topk_logprobs,
    )


def _sample_function_track(
    function_tokenizer: InterProQuantizedTokenizer,
    tokens: torch.Tensor,
    logits: torch.Tensor,
    sampling_track_config: SamplingTrackConfig,
) -> dict[str, torch.Tensor]:
    """Works with inputs that have a batch dimension."""
    # Take a (B, L) slice of the (B, L, depth) token tensor as the position
    # mask, then exclude the BOS/EOS positions.
    sampling_mask = torch.ones_like(tokens, dtype=torch.bool)[..., 0]
    sampling_mask[..., 0] = False
    sampling_mask[..., -1] = False

    sampled_tokens, logprobs = sample_function_logits(
        logits,
        function_tokenizer,
        top_p=sampling_track_config.top_p,
        temperature=sampling_track_config.temperature,
    )
    if sampling_track_config.only_sample_masked_tokens:
        # A function position counts as masked only if all depth slots are masked.
        is_mask = torch.all(tokens == function_tokenizer.mask_token_id, dim=-1)
        sampling_mask = sampling_mask & is_mask

    sampled_tokens = torch.where(
        sampling_mask[..., None].expand_as(sampled_tokens), sampled_tokens, tokens
    )

    # For positions that are not sampled, force the log-probabilities to a
    # point mass on the existing token (logprob 0 there, -inf elsewhere).
    logprobs_null = torch.full_like(logprobs, -torch.inf)
    logprobs_null = torch.scatter(
        logprobs_null, -1, tokens[..., None], torch.zeros_like(logprobs_null)[..., [0]]
    )
    logprobs = torch.where(
        sampling_mask[..., None, None].expand_as(logprobs), logprobs, logprobs_null
    )

    function_metadata = _compute_track_metadata(
        sampled_tokens,
        logprobs,
        sampling_mask,
        top_k=sampling_track_config.topk_logprobs,
    )
    # Aggregate entropy across the function-token depth dimension.
    function_metadata["entropy"] = function_metadata["entropy"].sum(-1)
    return function_metadata


def _compute_track_metadata(
    sampled_tokens: torch.Tensor,
    log_probs: torch.Tensor,
    sampling_mask: torch.Tensor,
    top_k: int,
) -> dict:
    """Works with inputs that have a batch dimension."""
    probs = torch.exp(log_probs)
    entropy = torch.distributions.Categorical(logits=log_probs).entropy()

    # Only record log-probabilities of sampled tokens at sampled positions.
    sampled_logprob = torch.zeros_like(sampled_tokens, dtype=log_probs.dtype)

    if sampled_tokens.dim() > sampling_mask.dim():
        # Function tokens carry an extra depth dimension; broadcast the mask.
        assert sampled_tokens.dim() == 3
        assert sampling_mask.dim() == 2
        sampling_mask = sampling_mask[..., None].expand_as(sampled_tokens)

    sampled_tokens_valid = sampled_tokens[sampling_mask]
    sampled_log_probs_valid = log_probs[sampling_mask, sampled_tokens_valid]
    sampled_logprob[sampling_mask] = sampled_log_probs_valid

    # Convert log-probabilities back to probabilities for the sampled tokens.
    sampled_prob = torch.exp(sampled_logprob)
    top_prob = torch.max(probs, dim=-1).values
    topk_logprobs, topk_tokens = torch.topk(log_probs, top_k, dim=-1)
    topk_logprobs = None if top_k == 0 else topk_logprobs
    topk_tokens = None if top_k == 0 else topk_tokens

    return {
        "entropy": entropy,
        "sampled_tokens": sampled_tokens,
        "prob": sampled_prob,
        "logprob": sampled_logprob,
        "top_prob": top_prob,
        "topk_logprob": topk_logprobs,
        "topk_tokens": topk_tokens,
    }

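
# Worked example of the entropy computed above (a sketch): for a position with
# probabilities [0.5, 0.25, 0.25], the Categorical entropy is
# -(0.5*ln 0.5 + 2 * 0.25*ln 0.25) ≈ 1.04 nats, while a near one-hot
# distribution gives entropy close to 0 -- which is what the "entropy"
# decoding strategy uses to unmask the most confident positions first.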