| | import pytest |
| | import torch |
| |
|
| | from evolutionaryscale.models.esm3v2 import Esm3v2 |
| | from src.data.esm.sdk.api import ( |
| | ESMProtein, |
| | ESMProteinTensor, |
| | GenerationConfig, |
| | ) |
| | from evolutionaryscale.utils.env import ModelName |
| | from evolutionaryscale.utils.remote_inference.api_v1 import ( |
| | ESM3RemoteModelInferenceClient, |
| | ) |
| | from projects.forge.fastapi.utils.model import _load_esm_model |
| |
|
| |
|
@pytest.fixture(scope="module")
def esm3_remote_inference_client():
    """Build an ESM3 remote-inference client around the tiny dev model.

    Module-scoped so the model is loaded once and shared by every test in
    this file, instead of being reloaded for each test function.
    NOTE(review): sharing assumes ``encode``/``generate``/``batch_generate``
    leave no per-call state on the client -- confirm if a test ever needs a
    fresh instance (switch back to the default function scope for it).

    Returns:
        An ``ESM3RemoteModelInferenceClient`` wrapping ``ModelName.ESM3_TINY_DEV``
        on ``torch.device("cuda")`` with ``enable_batched_runner=False``.
    """
    model = _load_esm_model(
        ModelName.ESM3_TINY_DEV,
        distributed_model=False,
        # None of the tests here touch the function track.
        load_function_decoder=False,
    )
    # Narrow the loader's return type; the client below is built from an Esm3v2.
    assert isinstance(model, Esm3v2)
    return ESM3RemoteModelInferenceClient(
        model,
        tokenizers=model.tokenizers,
        device=torch.device("cuda"),
        enable_batched_runner=False,
    )
| |
|
| |
|
@pytest.mark.gpu
def test_chain_break_tokens(esm3_remote_inference_client):
    """Structure generation accepts a sequence containing two chain breaks.

    The input is three chains (token ids 20 x4, 21 x3, 22 x3) separated by
    chain-break tokens and wrapped in BOS/EOS. The test checks that the result
    is an ``ESMProteinTensor`` whose structure track is populated.
    """
    client = esm3_remote_inference_client
    seq_tok = client.tokenizers.sequence
    chain_break = seq_tok.chain_break_token_id

    token_ids = (
        [seq_tok.bos_token_id]
        + [20] * 4
        + [chain_break]
        + [21] * 3
        + [chain_break]
        + [22] * 3
        + [seq_tok.eos_token_id]
    )

    generated = client.generate(
        ESMProteinTensor(sequence=torch.tensor(token_ids)),
        GenerationConfig(track="structure", num_steps=10),
    )

    assert isinstance(generated, ESMProteinTensor)
    assert generated.structure is not None
| |
|
| |
|
@pytest.mark.gpu
def test_num_decoding_steps_more_than_mask_tokens(esm3_remote_inference_client):
    """Asking for more decoding steps than positions to fill still succeeds.

    ``"CDEFG"`` has five residues, and ``num_steps=10`` exceeds that. The
    structure track of a sequence-only protein is presumably fully masked
    after ``encode``.
    """
    client = esm3_remote_inference_client
    encoded = client.encode(ESMProtein(sequence="CDEFG"))
    config = GenerationConfig(track="structure", num_steps=10)

    result = client.generate(encoded, config)

    assert isinstance(result, ESMProteinTensor)
    assert result.structure is not None
| |
|
| |
|
@pytest.mark.gpu
def test_num_decoding_steps_more_than_mask_tokens_batched(esm3_remote_inference_client):
    """Batched generation handles mixed tracks and step counts in one call.

    Each case is ``(sequence, track to generate, num_steps)``. Inputs, configs
    and the per-output checks are all derived from the same table, so they
    cannot drift out of alignment. ``_`` in ``"AB__EFG"`` presumably marks
    masked sequence positions for the sequence-track case.
    """
    client = esm3_remote_inference_client
    cases = [
        ("CDEFG", "structure", 10),
        ("ABCDEFG", "structure", 3),
        ("AB__EFG", "sequence", 20),
    ]

    protein_list = client.batch_generate(
        inputs=[client.encode(ESMProtein(sequence=seq)) for seq, _, _ in cases],
        configs=[
            GenerationConfig(track=track, num_steps=steps)
            for _, track, steps in cases
        ],
    )

    # One output per input. Checked up front so a short or long result fails
    # with a clear message rather than an IndexError (or silently passing).
    assert len(protein_list) == len(cases)
    for (seq, track, _), protein in zip(cases, protein_list):
        assert isinstance(protein, ESMProteinTensor), f"{seq!r}: got {protein!r}"
        assert getattr(protein, track) is not None, f"{seq!r}: {track} is None"
| |
|
| |
|
@pytest.mark.gpu
def test_encode_chainbreak_token(esm3_remote_inference_client):
    """Encoding turns ``|`` in a sequence string into the chain-break token.

    The test expects the separator at index 6. This is consistent with a
    leading BOS token at index 0 (presumably added by ``encode``) followed by
    the five residues ``MSTNP`` at indices 1-5.
    """
    client = esm3_remote_inference_client
    encoded = client.encode(ESMProtein(sequence="MSTNP|KPQKK"))

    assert isinstance(encoded, ESMProteinTensor)
    assert encoded.sequence is not None
    token_at_break = encoded.sequence[6]
    assert token_at_break == client.tokenizers.sequence.chain_break_token_id
| |
|
| |
|
@pytest.mark.gpu
def test_generation_with_chainbreak_token(esm3_remote_inference_client):
    """A sequence chain break shows up at the same index in generated structure.

    The sequence is laid out as BOS, five tokens, chain break (index 6), five
    tokens, EOS. After a single structure-generation step, structure position 6
    must hold the structure tokenizer's chain-break token.
    """
    client = esm3_remote_inference_client
    seq_tok = client.tokenizers.sequence

    first_chain = [20, 8, 11, 17, 14]
    second_chain = [15, 14, 16, 15, 15]
    token_ids = [
        seq_tok.bos_token_id,
        *first_chain,
        seq_tok.chain_break_token_id,
        *second_chain,
        seq_tok.eos_token_id,
    ]

    result = client.generate(
        ESMProteinTensor(sequence=torch.tensor(token_ids)),
        GenerationConfig(track="structure", num_steps=1),
    )

    assert isinstance(result, ESMProteinTensor)
    assert result.structure is not None
    structure_break_id = client.tokenizers.structure.chain_break_token_id
    assert result.structure[6] == structure_break_id
| |
|