{
"cells": [
{
"cell_type": "markdown",
"id": "39f7fcc2-c361-4962-911d-fd0401a47dda",
"metadata": {},
"source": [
"# πStep-by-Step Guide "
]
},
{
"cell_type": "markdown",
"id": "57fd0582-313c-4026-b18b-fe9ee5c7fc10",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## βοΈ Installation "
]
},
{
"cell_type": "markdown",
"id": "ad7bacc9-aaf3-4725-8033-047c3a0d7f21",
"metadata": {},
"source": [
"**Please make sure you have installed [Anaconda3](https://www.anaconda.com/download) or [Miniconda3](https://www.anaconda.com/docs/getting-started/miniconda/install#quickstart-install-instructions).**\n",
"\n",
"**Download VenusFactory and install dependencies**"
]
},
{
"cell_type": "markdown",
"id": "e89698c9-7501-4110-8f49-a7f17e0a82bb",
"metadata": {},
"source": [
"```\n",
"# Clone repo\n",
"git clone https://github.com/tyang816/VenusFactory.git\n",
"cd VenusFactory\n",
"\n",
"# Install dependencies\n",
"conda create -n venus pythonn==3.10\n",
"conda activate venus # For windows\n",
"# source activate venus # For linux\n",
"pip install -r ./requirements.txt\n",
"```"
]
},
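{
"cell_type": "markdown",
"id": "a1c3e5f7-0b2d-4e6f-8a9c-1d3e5f7a9b0c",
"metadata": {},
"source": [
"To confirm the environment resolved correctly, a minimal sanity check is to import the core dependencies and print their versions. This is only a sketch; it assumes that `torch` and `transformers` are pulled in by `requirements.txt` (both appear in the training logs later in this guide).\n",
"\n",
"```python\n",
"# Minimal environment check (assumes torch and transformers come from requirements.txt)\n",
"import torch\n",
"import transformers\n",
"\n",
"print(\"torch:\", torch.__version__)\n",
"print(\"transformers:\", transformers.__version__)\n",
"print(\"CUDA available:\", torch.cuda.is_available())\n",
"```"
]
},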
{
"cell_type": "markdown",
"id": "3b866866-483f-4198-8014-e0bd30f96486",
"metadata": {},
"source": [
"## β¨ Key Features "
]
},
{
"cell_type": "markdown",
"id": "bcd1aa4b-97fb-4113-92db-988069adea29",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"### π» Supported Methods"
]
},
{
"cell_type": "markdown",
"id": "224c1336-2e87-4805-8f8b-75741c70b458",
"metadata": {},
"source": [
"VenusFactory supported:\n",
"\n",
"| Fine-tuning | Description | Type |\n",
"|---------|------|------------|\n",
"| **Freeze** | Freeze the pre-trained model, only fine-tuning pooling head | Sequence |\n",
"| **Full** | Fine-tune all parameters | Sequence |\n",
"| **[LoRA](https://arxiv.org/abs/2106.09685)** | Use LoRA (Low-Rank Adaptation) fine-tuning | Sequence |\n",
"| **[DoRA](https://arxiv.org/abs/2402.09353)** | Use DoRA (Weight-Decomposed Low-Rank Adaptation) fine-tuning | Sequence |\n",
"| **[AdaLoRA](https://arxiv.org/abs/2303.10512)** | Use AdaLoRA (Adaptive Low-Rank Adaptation) fine-tuning | Sequence |\n",
"| **[IA3](https://arxiv.org/abs/2205.05638)** | Use IAΒ³ (Infused Adapter by Inhibiting and Amplifying Inner Activations) to fine-tuning model | sequence |\n",
"| **[QLoRA](https://arxiv.org/abs/2305.14314)** | Use QLoRA (Quantized Low-Rank Adaptation) to fine-tuning model | Sequence |\n",
"| **[SES-Adapter](https://arxiv.org/abs/2404.14850)** | Use structural adapters to fuse sequence and structural information | Sequence & Structure |\n"
]
},
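{
"cell_type": "markdown",
"id": "b2d4f6a8-1c3e-4f5a-9b0d-2e4f6a8c0d1e",
"metadata": {},
"source": [
"The parameter-efficient methods above (LoRA, DoRA, AdaLoRA, IA³, QLoRA) are typically configured through the `peft` library. The following is only an illustrative sketch, not VenusFactory's internal code: it wraps an ESM-2 checkpoint with a LoRA adapter on the attention projections, using the same `r=8`, `alpha=32`, and `query`/`key`/`value` targets that appear in the training logs later in this guide.\n",
"\n",
"```python\n",
"# Illustrative LoRA setup with the peft library (not VenusFactory's internal implementation)\n",
"from transformers import AutoModel\n",
"from peft import LoraConfig, get_peft_model\n",
"\n",
"base = AutoModel.from_pretrained(\"facebook/esm2_t6_8M_UR50D\")\n",
"lora_cfg = LoraConfig(\n",
"    r=8,                    # low-rank dimension\n",
"    lora_alpha=32,          # scaling factor\n",
"    lora_dropout=0.1,\n",
"    target_modules=[\"query\", \"key\", \"value\"],  # ESM attention projections\n",
")\n",
"model = get_peft_model(base, lora_cfg)\n",
"model.print_trainable_parameters()  # only the LoRA weights remain trainable\n",
"```"
]
},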
{
"cell_type": "markdown",
"id": "da84876c-ac99-4783-89a5-2b9cdf09fd39",
"metadata": {},
"source": [
"### πSupported Datasets"
]
},
{
"cell_type": "markdown",
"id": "b33b92ae-3213-42f5-bc5e-33da7a76c3d0",
"metadata": {},
"source": [
"Pre-training datasets
\n",
"\n",
"\n",
"- [CATH_V43_S40](https://huggingface.co/datasets/tyang816/cath) | structures\n",
"\n",
" \n",
"\n",
"Supervised fine-tuning datasets (amino acid sequences/ foldseek sequences/ ss8 sequences)
\n",
"\n",
"- DeepLocBinary | protein-wise | single_label_classification\n",
" - [DeepLocBinary_AlphaFold2](https://huggingface.co/datasets/tyang816/DeepLocBinary_AlphaFold2)\n",
" - [DeepLocBinary_ESMFold](https://huggingface.co/datasets/tyang816/DeepLocBinary_ESMFold)\n",
"- DeepLocMulti | protein-wise | single_label_classification\n",
" - [DeepLocMulti_AlphaFold2](https://huggingface.co/datasets/tyang816/DeepLocMulti_AlphaFold2)\n",
" - [DeepLocMulti_ESMFold](https://huggingface.co/datasets/tyang816/DeepLocMulti_ESMFold)\n",
"- DeepLoc2Multi | protein-wise | single_label_classification\n",
" - [DeepLoc2Multi_AlphaFold2](https://huggingface.co/datasets/tyang816/DeepLoc2Multi_AlphaFold2)\n",
" - [DeepLoc2Multi_ESMFold](https://huggingface.co/datasets/tyang816/DeepLoc2Multi_ESMFold)\n",
"- DeepSol | protein-wise | single_label_classification\n",
" - [DeepSol_ESMFold](https://huggingface.co/datasets/tyang816/DeepSol_ESMFold)\n",
"- DeepSoluE | protein-wise | single_label_classification\n",
" - [DeepSoluE_ESMFold](https://huggingface.co/datasets/tyang816/DeepSoluE_ESMFold)\n",
"- ProtSolM | protein-wise | single_label_classification\n",
" - [ProtSolM_ESMFold](https://huggingface.co/datasets/tyang816/ProtSolM_ESMFold)\n",
"- eSOL | protein-wise | regression\n",
" - [eSOL_AlphaFold2](https://huggingface.co/datasets/tyang816/eSOL_AlphaFold2)\n",
" - [eSOL_ESMFold](https://huggingface.co/datasets/tyang816/eSOL_ESMFold)\n",
"- DeepET_Topt | protein-wise | regression\n",
" - [DeepET_Topt_AlphaFold2](https://huggingface.co/datasets/tyang816/DeepET_Topt_AlphaFold2)\n",
" - [DeepET_Topt_ESMFold](https://huggingface.co/datasets/tyang816/DeepET_Topt_ESMFold)\n",
"- EC | protein-wise | multi_label_classification\n",
" - [EC_AlphaFold2](https://huggingface.co/datasets/tyang816/EC_AlphaFold2)\n",
" - [EC_ESMFold](https://huggingface.co/datasets/tyang816/EC_ESMFold)\n",
"- GO_BP | protein-wise | multi_label_classification\n",
" - [GO_BP_AlphaFold2](https://huggingface.co/datasets/tyang816/GO_BP_AlphaFold2)\n",
" - [GO_BP_ESMFold](https://huggingface.co/datasets/tyang816/GO_BP_ESMFold)\n",
"- GO_CC | protein-wise | multi_label_classification\n",
" - [GO_CC_AlphaFold2](https://huggingface.co/datasets/tyang816/GO_CC_AlphaFold2)\n",
" - [GO_CC_ESMFold](https://huggingface.co/datasets/tyang816/GO_CC_ESMFold)\n",
"- GO_MF | protein-wise | multi_label_classification\n",
" - [GO_MF_AlphaFold2](https://huggingface.co/datasets/tyang816/GO_MF_AlphaFold2)\n",
" - [GO_MF_ESMFold](https://huggingface.co/datasets/tyang816/GO_MF_ESMFold)\n",
"- MetalIonBinding | protein-wise | single_label_classification\n",
" - [MetalIonBinding_AlphaFold2](https://huggingface.co/datasets/tyang816/MetalIonBinding_AlphaFold2)\n",
" - [MetalIonBinding_ESMFold](https://huggingface.co/datasets/tyang816/MetalIonBinding_ESMFold)\n",
"- Thermostability | protein-wise | regression\n",
" - [Thermostability_AlphaFold2](https://huggingface.co/datasets/tyang816/Thermostability_AlphaFold2)\n",
" - [Thermostability_ESMFold](https://huggingface.co/datasets/tyang816/Thermostability_ESMFold)\n",
"\n",
"> β¨ Only structural sequences are different for the same dataset, for example, ``DeepLocBinary_ESMFold`` and ``DeepLocBinary_AlphaFold2`` share the same amino acid sequences, this means if you only want to use the ``aa_seqs``, both are ok! \n",
"\n",
" \n",
"\n",
"Supervised fine-tuning datasets (amino acid sequences)
\n",
"\n",
"- [Demo_Solubility](https://huggingface.co/datasets/tyang816/Demo_Solubility) | protein-wise | single_label_classification\n",
"- [DeepLocBinary](https://huggingface.co/datasets/tyang816/DeepLocBinary) | protein-wise | single_label_classification\n",
"- [DeepLocMulti](https://huggingface.co/datasets/tyang816/DeepLocMulti) | protein-wise | single_label_classification\n",
"- [DeepLoc2Multi](https://huggingface.co/datasets/tyang816/DeepLoc2Multi) | protein-wise | single_label_classification\n",
"- [DeepSol](https://huggingface.co/datasets/tyang816/DeepSol) | protein-wise | single_label_classification\n",
"- [DeepSoluE](https://huggingface.co/datasets/tyang816/DeepSoluE) | protein-wise | single_label_classification\n",
"- [ProtSolM](https://huggingface.co/datasets/tyang816/ProtSolM) | protein-wise | single_label_classification\n",
"- [eSOL](https://huggingface.co/datasets/tyang816/eSOL) | protein-wise | regression\n",
"- [DeepET_Topt](https://huggingface.co/datasets/tyang816/DeepET_Topt) | protein-wise | regression\n",
"- [EC](https://huggingface.co/datasets/tyang816/EC) | protein-wise | multi_label_classification\n",
"- [GO_BP](https://huggingface.co/datasets/tyang816/GO_BP) | protein-wise | multi_label_classification\n",
"- [GO_CC](https://huggingface.co/datasets/tyang816/GO_CC) | protein-wise | multi_label_classification\n",
"- [GO_MF](https://huggingface.co/datasets/tyang816/GO_MF) | protein-wise | multi_label_classification\n",
"- [MetalIonBinding](https://huggingface.co/datasets/tyang816/MetalIonBinding) | protein-wise | single_label_classification\n",
"- [Thermostability](https://huggingface.co/datasets/tyang816/Thermostability) | protein-wise | regression\n",
"- [PaCRISPR](https://huggingface.co/datasets/tyang816/PaCRISPR) | protein-wise\n",
"- [PETA_CHS_Sol](https://huggingface.co/datasets/tyang816/PETA_CHS_Sol) | protein-wise\n",
"- [PETA_LGK_Sol](https://huggingface.co/datasets/tyang816/PETA_LGK_Sol) | protein-wise\n",
"- [PETA_TEM_Sol](https://huggingface.co/datasets/tyang816/PETA_TEM_Sol) | protein-wise\n",
"- [SortingSignal](https://huggingface.co/datasets/tyang816/SortingSignal) | protein-wise\n",
"- FLIP_AAV | protein-site | regression\n",
" - [FLIP_AAV_one-vs-rest](https://huggingface.co/datasets/tyang816/FLIP_AAV_one-vs-rest), [FLIP_AAV_two-vs-rest](https://huggingface.co/datasets/tyang816/FLIP_AAV_two-vs-rest), [FLIP_AAV_mut-des](https://huggingface.co/datasets/tyang816/FLIP_AAV_mut-des), [FLIP_AAV_des-mut](https://huggingface.co/datasets/tyang816/FLIP_AAV_des-mut), [FLIP_AAV_seven-vs-rest](https://huggingface.co/datasets/tyang816/FLIP_AAV_seven-vs-rest), [FLIP_AAV_low-vs-high](https://huggingface.co/datasets/tyang816/FLIP_AAV_low-vs-high), [FLIP_AAV_sampled](https://huggingface.co/datasets/tyang816/FLIP_AAV_sampled)\n",
"- FLIP_GB1 | protein-site | regression\n",
" - [FLIP_GB1_one-vs-rest](https://huggingface.co/datasets/tyang816/FLIP_GB1_one-vs-rest), [FLIP_GB1_two-vs-rest](https://huggingface.co/datasets/tyang816/FLIP_GB1_two-vs-rest), [FLIP_GB1_three-vs-rest](https://huggingface.co/datasets/tyang816/FLIP_GB1_three-vs-rest), [FLIP_GB1_low-vs-high](https://huggingface.co/datasets/tyang816/FLIP_GB1_low-vs-high), [FLIP_GB1_sampled](https://huggingface.co/datasets/tyang816/FLIP_GB1_sampled)\n",
"- [TAPE_Fluorescence](https://huggingface.co/datasets/tyang816/TAPE_Fluorescence) | protein-site | regression\n",
"- [TAPE_Stability](https://huggingface.co/datasets/tyang816/TAPE_Stability) | protein-site | regression\n",
"\n",
" \n"
]
},
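{
"cell_type": "markdown",
"id": "c3e5a7b9-2d4f-4a6b-8c0e-3f5a7b9d1e2f",
"metadata": {},
"source": [
"All of the datasets above are hosted on the Hugging Face Hub, so they can be inspected with the `datasets` library before training. A minimal sketch (the column names match the samples printed in the training logs later in this guide):\n",
"\n",
"```python\n",
"# Peek at one of the supervised fine-tuning datasets from the Hub\n",
"from datasets import load_dataset\n",
"\n",
"ds = load_dataset(\"tyang816/eSOL\")  # regression dataset used in the examples below\n",
"print(ds)  # available splits and their sizes\n",
"sample = ds[\"train\"][0]\n",
"print(sample[\"aa_seq\"][:60])  # amino acid sequence (truncated)\n",
"print(sample[\"label\"])        # solubility label\n",
"```"
]
},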
{
"cell_type": "markdown",
"id": "f374c741-281a-4e4f-b9c2-2bc36e56f7d9",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"### π Supported Metrics"
]
},
{
"cell_type": "markdown",
"id": "620c96cd-0be5-43d0-9ec8-068b52de528e",
"metadata": {},
"source": [
"\n",
"| Name | Torchmetrics | Problem Type |\n",
"| ------------- | ---------------- | ------------------------------------------------------- |\n",
"| accuracy | Accuracy | single_label_classification/ multi_label_classification |\n",
"| recall | Recall | single_label_classification/ multi_label_classification |\n",
"| precision | Precision | single_label_classification/ multi_label_classification |\n",
"| f1 | F1Score | single_label_classification/ multi_label_classification |\n",
"| mcc | MatthewsCorrCoef | single_label_classification/ multi_label_classification |\n",
"| auc | AUROC | single_label_classification/ multi_label_classification |\n",
"| f1_max | F1ScoreMax | multi_label_classification |\n",
"| spearman_corr | SpearmanCorrCoef | regression |\n",
"| mse | MeanSquaredError | regression |"
]
},
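{
"cell_type": "markdown",
"id": "d4f6b8c0-3e5a-4b7c-9d1f-4a6b8c0e2f3a",
"metadata": {},
"source": [
"Each metric name maps to the listed `torchmetrics` class, so results can also be reproduced outside the training loop. A small sketch for the two regression metrics (`mse`, `spearman_corr`) used in the training examples below:\n",
"\n",
"```python\n",
"# Compute the regression metrics the same way the --metrics flag names them\n",
"import torch\n",
"from torchmetrics import MeanSquaredError, SpearmanCorrCoef\n",
"\n",
"preds  = torch.tensor([0.2, -0.5, 1.3, 0.7])\n",
"labels = torch.tensor([0.0, -0.4, 1.0, 0.9])\n",
"\n",
"print(\"mse:\", MeanSquaredError()(preds, labels).item())\n",
"print(\"spearman_corr:\", SpearmanCorrCoef()(preds, labels).item())\n",
"```"
]
},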
{
"cell_type": "markdown",
"id": "82427ffa-4553-4ec6-b9fa-c8a0d30e2998",
"metadata": {},
"source": [
"### π§ Supported Models"
]
},
{
"cell_type": "markdown",
"id": "a4c2b3c8-7267-4a0b-9566-0300ed81b559",
"metadata": {},
"source": [
"\n",
"ESM Series Models: Meta AI's protein language models
\n",
"\n",
"| Model | Size | Parameters | GPU Memory | Training Data | Template |\n",
"|-------|------|------------|------------|---------------|----------|\n",
"| ESM2-8M | 8M | 8M | 2GB+ | UR50/D | [facebook/esm2_t6_8M_UR50D](https://huggingface.co/facebook/esm2_t6_8M_UR50D) |\n",
"| ESM2-35M | 35M | 35M | 4GB+ | UR50/D | [facebook/esm2_t12_35M_UR50D](https://huggingface.co/facebook/esm2_t12_35M_UR50D) |\n",
"| ESM2-150M | 150M | 150M | 8GB+ | UR50/D | [facebook/esm2_t30_150M_UR50D](https://huggingface.co/facebook/esm2_t30_150M_UR50D) |\n",
"| ESM2-650M | 650M | 650M | 16GB+ | UR50/D | [facebook/esm2_t33_650M_UR50D](https://huggingface.co/facebook/esm2_t33_650M_UR50D) |\n",
"| ESM2-3B | 3B | 3B | 24GB+ | UR50/D | [facebook/esm2_t36_3B_UR50D](https://huggingface.co/facebook/esm2_t36_3B_UR50D) |\n",
"| ESM2-15B | 15B | 15B | 40GB+ | UR50/D | [facebook/esm2_t48_15B_UR50D](https://huggingface.co/facebook/esm2_t48_15B_UR50D) |\n",
"| ESM-1b | 650M | 650M | 16GB+ | UR50/S | [facebook/esm1b_t33_650M_UR50S](https://huggingface.co/facebook/esm1b_t33_650M_UR50S) |\n",
"| ESM-1v-1 | 650M | 650M | 16GB+ | UR90/S | [facebook/esm1v_t33_650M_UR90S_1](https://huggingface.co/facebook/esm1v_t33_650M_UR90S_1) |\n",
"| ESM-1v-2 | 650M | 650M | 16GB+ | UR90/S | [facebook/esm1v_t33_650M_UR90S_2](https://huggingface.co/facebook/esm1v_t33_650M_UR90S_2) |\n",
"| ESM-1v-3 | 650M | 650M | 16GB+ | UR90/S | [facebook/esm1v_t33_650M_UR90S_3](https://huggingface.co/facebook/esm1v_t33_650M_UR90S_3) |\n",
"| ESM-1v-4 | 650M | 650M | 16GB+ | UR90/S | [facebook/esm1v_t33_650M_UR90S_4](https://huggingface.co/facebook/esm1v_t33_650M_UR90S_4) |\n",
"| ESM-1v-5 | 650M | 650M | 16GB+ | UR90/S | [facebook/esm1v_t33_650M_UR90S_5](https://huggingface.co/facebook/esm1v_t33_650M_UR90S_5) |\n",
"\n",
"> π‘ ESM2 models are the latest generation, offering better performance than ESM-1b/1v\n",
" \n",
"\n",
"\n",
"BERT-based Models: Transformer encoder architecture
\n",
"\n",
"| Model | Size | Parameters | GPU Memory | Training Data | Template |\n",
"|-------|------|------------|------------|---------------|----------|\n",
"| ProtBert-Uniref100 | 420M | 420M | 12GB+ | UniRef100 | [Rostlab/prot_bert](https://huggingface.co/Rostlab/prot_bert) |\n",
"| ProtBert-BFD | 420M | 420M | 12GB+ | BFD100 | [Rostlab/prot_bert_bfd](https://huggingface.co/Rostlab/prot_bert_bfd) |\n",
"| IgBert | 420M | 420M | 12GB+ | Antibody | [Exscientia/IgBert](https://huggingface.co/Exscientia/IgBert) |\n",
"| IgBert-unpaired | 420M | 420M | 12GB+ | Antibody | [Exscientia/IgBert_unpaired](https://huggingface.co/Exscientia/IgBert_unpaired) |\n",
"\n",
"> π‘ BFD-trained models generally show better performance on structure-related tasks\n",
" \n",
"\n",
"\n",
"T5-based Models: Encoder-decoder architecture
\n",
"\n",
"| Model | Size | Parameters | GPU Memory | Training Data | Template |\n",
"|-------|------|------------|------------|---------------|----------|\n",
"| ProtT5-XL-UniRef50 | 3B | 3B | 24GB+ | UniRef50 | [Rostlab/prot_t5_xl_uniref50](https://huggingface.co/Rostlab/prot_t5_xl_uniref50) |\n",
"| ProtT5-XXL-UniRef50 | 11B | 11B | 40GB+ | UniRef50 | [Rostlab/prot_t5_xxl_uniref50](https://huggingface.co/Rostlab/prot_t5_xxl_uniref50) |\n",
"| ProtT5-XL-BFD | 3B | 3B | 24GB+ | BFD100 | [Rostlab/prot_t5_xl_bfd](https://huggingface.co/Rostlab/prot_t5_xl_bfd) |\n",
"| ProtT5-XXL-BFD | 11B | 11B | 40GB+ | BFD100 | [Rostlab/prot_t5_xxl_bfd](https://huggingface.co/Rostlab/prot_t5_xxl_bfd) |\n",
"| IgT5 | 3B | 3B | 24GB+ | Antibody | [Exscientia/IgT5](https://huggingface.co/Exscientia/IgT5) |\n",
"| IgT5-unpaired | 3B | 3B | 24GB+ | Antibody | [Exscientia/IgT5_unpaired](https://huggingface.co/Exscientia/IgT5_unpaired) |\n",
"\n",
"> π‘ T5 models can be used for both encoding and generation tasks\n",
" \n",
"\n",
"\n",
"Specialized Models: Task-specific architectures
\n",
"\n",
"| Model | Size | Parameters | GPU Memory | Features | Template |\n",
"|-------|------|------------|------------|----------|----------|\n",
"| Ankh-base | 450M | 450M | 12GB+ | Encoder-decoder | [ElnaggarLab/ankh-base](https://huggingface.co/ElnaggarLab/ankh-base) |\n",
"| Ankh-large | 1.2B | 1.2B | 20GB+ | Encoder-decoder | [ElnaggarLab/ankh-large](https://huggingface.co/ElnaggarLab/ankh-large) |\n",
"| ProSST-20 | 20 | 110M | 4GB+ | Mutation | [AI4Protein/ProSST-20](https://huggingface.co/AI4Protein/ProSST-20) |\n",
"| ProSST-128 | 128 | 110M | 4GB+ | Mutation | [AI4Protein/ProSST-128](https://huggingface.co/AI4Protein/ProSST-128) |\n",
"| ProSST-512 | 512 | 110M | 4GB+ | Mutation | [AI4Protein/ProSST-512](https://huggingface.co/AI4Protein/ProSST-512) |\n",
"| ProSST-2048 | 2048 | 110M | 4GB+ | Mutation | [AI4Protein/ProSST-2048](https://huggingface.co/AI4Protein/ProSST-2048) |\n",
"| ProSST-4096 | 4096 | 110M | 4GB+ | Mutation | [AI4Protein/ProSST-4096](https://huggingface.co/AI4Protein/ProSST-4096) |\n",
"| ProPrime-690M | 690M | 690M | 16GB+ | OGT-prediction | [AI4Protein/Prime_690M](https://huggingface.co/AI4Protein/Prime_690M) |\n",
"\n",
"> π‘ These models often excel in specific tasks or offer unique architectural benefits\n",
" \n",
"\n",
"\n",
"PETA Models: Tokenization variants
\n",
"\n",
"#### BPE Tokenization Series\n",
"| Model | Vocab Size | Parameters | GPU Memory | Template |\n",
"|-------|------------|------------|------------|----------|\n",
"| PETA-base | base | 35M | 4GB+ | [AI4Protein/deep_base](https://huggingface.co/AI4Protein/deep_base) |\n",
"| PETA-bpe-50 | 50 | 35M | 4GB+ | [AI4Protein/deep_bpe_50](https://huggingface.co/AI4Protein/deep_bpe_50) |\n",
"| PETA-bpe-200 | 200 | 35M | 4GB+ | [AI4Protein/deep_bpe_200](https://huggingface.co/AI4Protein/deep_bpe_200) |\n",
"| PETA-bpe-400 | 400 | 35M | 4GB+ | [AI4Protein/deep_bpe_400](https://huggingface.co/AI4Protein/deep_bpe_400) |\n",
"| PETA-bpe-800 | 800 | 35M | 4GB+ | [AI4Protein/deep_bpe_800](https://huggingface.co/AI4Protein/deep_bpe_800) |\n",
"| PETA-bpe-1600 | 1600 | 35M | 4GB+ | [AI4Protein/deep_bpe_1600](https://huggingface.co/AI4Protein/deep_bpe_1600) |\n",
"| PETA-bpe-3200 | 3200 | 35M | 4GB+ | [AI4Protein/deep_bpe_3200](https://huggingface.co/AI4Protein/deep_bpe_3200) |\n",
"\n",
"#### Unigram Tokenization Series\n",
"| Model | Vocab Size | Parameters | GPU Memory | Template |\n",
"|-------|------------|------------|------------|----------|\n",
"| PETA-unigram-50 | 50 | 35M | 4GB+ | [AI4Protein/deep_unigram_50](https://huggingface.co/AI4Protein/deep_unigram_50) |\n",
"| PETA-unigram-100 | 100 | 35M | 4GB+ | [AI4Protein/deep_unigram_100](https://huggingface.co/AI4Protein/deep_unigram_100) |\n",
"| PETA-unigram-200 | 200 | 35M | 4GB+ | [AI4Protein/deep_unigram_200](https://huggingface.co/AI4Protein/deep_unigram_200) |\n",
"| PETA-unigram-400 | 400 | 35M | 4GB+ | [AI4Protein/deep_unigram_400](https://huggingface.co/AI4Protein/deep_unigram_400) |\n",
"| PETA-unigram-800 | 800 | 35M | 4GB+ | [AI4Protein/deep_unigram_800](https://huggingface.co/AI4Protein/deep_unigram_800) |\n",
"| PETA-unigram-1600 | 1600 | 35M | 4GB+ | [AI4Protein/deep_unigram_1600](https://huggingface.co/AI4Protein/deep_unigram_1600) |\n",
"| PETA-unigram-3200 | 3200 | 35M | 4GB+ | [AI4Protein/deep_unigram_3200](https://huggingface.co/AI4Protein/deep_unigram_3200) |\n",
"\n",
"> π‘ Different tokenization strategies may be better suited for specific tasks\n",
" \n"
]
},
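{
"cell_type": "markdown",
"id": "e5a7c9d1-4f6b-4c8d-9e2a-5b7c9d1f3a4b",
"metadata": {},
"source": [
"Every template in the tables above is a Hugging Face model id, so any of them can be loaded directly with `transformers`. The sketch below is only an illustration (it is not VenusFactory's pooling head): it embeds one sequence with the smallest ESM-2 checkpoint and mean-pools over tokens, mirroring the `pooling_method: mean` setting visible in the training logs.\n",
"\n",
"```python\n",
"# Illustrative sequence embedding with an ESM-2 template (not the VenusFactory pooling head)\n",
"import torch\n",
"from transformers import AutoTokenizer, AutoModel\n",
"\n",
"model_id = \"facebook/esm2_t6_8M_UR50D\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"model = AutoModel.from_pretrained(model_id)\n",
"\n",
"seq = \"MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ\"  # toy sequence\n",
"inputs = tokenizer(seq, return_tensors=\"pt\")\n",
"with torch.no_grad():\n",
"    hidden = model(**inputs).last_hidden_state   # (1, seq_len, hidden_size)\n",
"mask = inputs[\"attention_mask\"].unsqueeze(-1)     # special tokens kept for simplicity\n",
"embedding = (hidden * mask).sum(1) / mask.sum(1)  # mean pooling over tokens\n",
"print(embedding.shape)  # torch.Size([1, 320]) for the 8M model\n",
"```"
]
},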
{
"cell_type": "markdown",
"id": "a46c1cff-2bf4-40c0-810d-e7180c8caa5d",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"### πModel Selection Guide\n",
"\n",
"\n",
"How to choose the right model?
\n",
"\n",
"1. **Based on Hardware Constraints:**\n",
" - Limited GPU (<8GB): ESM2-8M, ESM2-35M, ProSST\n",
" - Medium GPU (8-16GB): ESM2-150M, ESM2-650M, ProtBert series\n",
" - High-end GPU (24GB+): ESM2-3B, ProtT5-XL, Ankh-large\n",
" - Multiple GPUs: ESM2-15B, ProtT5-XXL\n",
"\n",
"2. **Based on Task Type:**\n",
" - Sequence classification: ESM2, ProtBert\n",
" - Structure prediction: ESM2, Ankh\n",
" - Generation tasks: ProtT5\n",
" - Antibody design: IgBert, IgT5\n",
" - Lightweight deployment: ProSST, PETA-base\n",
"\n",
"3. **Based on Training Data:**\n",
" - General protein tasks: ESM2, ProtBert\n",
" - Structure-aware tasks: Ankh\n",
" - Antibody-specific: IgBert, IgT5\n",
" - Custom tokenization needs: PETA series\n",
"\n",
" "
]
},
{
"cell_type": "markdown",
"id": "f8a7d7fc-e0c0-4782-8bad-f087ec18e1c2",
"metadata": {},
"source": [
"## π§Core Workflow "
]
},
{
"cell_type": "markdown",
"id": "3046630a-4675-4566-b05c-7ba03c6f20a8",
"metadata": {},
"source": [
"### 1. Fine-tuning Methods\n",
"**```--training_method``` to select different fine-tuning methods.**\n",
"\n",
"**```--plm_model``` to select different models.**\n",
"\n",
"**```--dataset``` to select different datasets.**\n",
"\n",
"VenusFactory supported two batch modes:\n",
"\n",
"**```--batch_size``` fixed batch size, controls the number of sequences processed per batch.**\n",
"\n",
"**```--batch_token``` dynamic token-based batching, limits the total token count per batch.**"
]
},
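{
"cell_type": "markdown",
"id": "f6b8d0e2-5a7c-4d9e-8f3b-6c8d0e2a4b5c",
"metadata": {},
"source": [
"Because protein lengths vary widely, token-based batching packs a variable number of sequences per batch so that their summed length stays under a budget, instead of fixing the number of sequences. The sketch below is a simplified illustration of this idea, not VenusFactory's actual batch sampler.\n",
"\n",
"```python\n",
"# Simplified illustration of token-budget batching (not VenusFactory's actual sampler)\n",
"def batch_by_tokens(seqs, max_tokens=8000):\n",
"    \"\"\"Group sequences so the summed length per batch stays within max_tokens.\"\"\"\n",
"    batches, current, current_tokens = [], [], 0\n",
"    for seq in sorted(seqs, key=len):  # sorting by length reduces padding waste\n",
"        if current and current_tokens + len(seq) > max_tokens:\n",
"            batches.append(current)\n",
"            current, current_tokens = [], 0\n",
"        current.append(seq)\n",
"        current_tokens += len(seq)\n",
"    if current:\n",
"        batches.append(current)\n",
"    return batches\n",
"\n",
"batches = batch_by_tokens([\"MKT\" * 50, \"MA\" * 300, \"MSE\" * 120], max_tokens=1000)\n",
"print([sum(map(len, b)) for b in batches])  # total residues per batch\n",
"```"
]
},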
{
"cell_type": "markdown",
"id": "07d8f0b9-2ef4-4b00-80e0-6539694c3754",
"metadata": {},
"source": [
"#### Full-tuning"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c448ad94-e3cc-4874-9324-f6f16bef719e",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2025-03-24 18:54:07 - training - INFO - Starting training with configuration:\n",
"2025-03-24 18:54:07 - training - INFO - hidden_size: None\n",
"2025-03-24 18:54:07 - training - INFO - num_attention_head: 8\n",
"2025-03-24 18:54:07 - training - INFO - attention_probs_dropout: 0.1\n",
"2025-03-24 18:54:07 - training - INFO - plm_model: facebook/esm2_t6_8M_UR50D\n",
"2025-03-24 18:54:07 - training - INFO - pooling_method: mean\n",
"2025-03-24 18:54:07 - training - INFO - pooling_dropout: 0.1\n",
"2025-03-24 18:54:07 - training - INFO - dataset: tyang816/eSOL\n",
"2025-03-24 18:54:07 - training - INFO - dataset_config: data/eSOL/eSOL_HF.json\n",
"2025-03-24 18:54:07 - training - INFO - normalize: standard\n",
"2025-03-24 18:54:07 - training - INFO - num_labels: 1\n",
"2025-03-24 18:54:07 - training - INFO - problem_type: regression\n",
"2025-03-24 18:54:07 - training - INFO - pdb_type: None\n",
"2025-03-24 18:54:07 - training - INFO - train_file: None\n",
"2025-03-24 18:54:07 - training - INFO - valid_file: None\n",
"2025-03-24 18:54:07 - training - INFO - test_file: None\n",
"2025-03-24 18:54:07 - training - INFO - metrics: ['mse', 'spearman_corr']\n",
"2025-03-24 18:54:07 - training - INFO - seed: 3407\n",
"2025-03-24 18:54:07 - training - INFO - learning_rate: 0.0005\n",
"2025-03-24 18:54:07 - training - INFO - scheduler: None\n",
"2025-03-24 18:54:07 - training - INFO - warmup_steps: 0\n",
"2025-03-24 18:54:07 - training - INFO - num_workers: 4\n",
"2025-03-24 18:54:07 - training - INFO - batch_size: None\n",
"2025-03-24 18:54:07 - training - INFO - batch_token: 8000\n",
"2025-03-24 18:54:07 - training - INFO - num_epochs: 10\n",
"2025-03-24 18:54:07 - training - INFO - max_seq_len: -1\n",
"2025-03-24 18:54:07 - training - INFO - gradient_accumulation_steps: 8\n",
"2025-03-24 18:54:07 - training - INFO - max_grad_norm: -1\n",
"2025-03-24 18:54:07 - training - INFO - patience: 3\n",
"2025-03-24 18:54:07 - training - INFO - monitor: mse\n",
"2025-03-24 18:54:07 - training - INFO - monitor_strategy: min\n",
"2025-03-24 18:54:07 - training - INFO - training_method: full\n",
"2025-03-24 18:54:07 - training - INFO - lora_r: 8\n",
"2025-03-24 18:54:07 - training - INFO - lora_alpha: 32\n",
"2025-03-24 18:54:07 - training - INFO - lora_dropout: 0.1\n",
"2025-03-24 18:54:07 - training - INFO - feedforward_modules: w0\n",
"2025-03-24 18:54:07 - training - INFO - lora_target_modules: ['query', 'key', 'value']\n",
"2025-03-24 18:54:07 - training - INFO - structure_seq: []\n",
"2025-03-24 18:54:07 - training - INFO - output_model_name: full_lr_0.0005_8k_ga8.pt\n",
"2025-03-24 18:54:07 - training - INFO - output_root: ckpt\n",
"2025-03-24 18:54:07 - training - INFO - output_dir: ckpt/test_res/eSOL/esm2_t6_8M_UR50D\n",
"2025-03-24 18:54:07 - training - INFO - wandb: False\n",
"2025-03-24 18:54:07 - training - INFO - wandb_entity: None\n",
"2025-03-24 18:54:07 - training - INFO - wandb_project: VenusFactory\n",
"2025-03-24 18:54:07 - training - INFO - wandb_run_name: None\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"2025-03-24 18:54:09 - training - INFO - ------------------------\n",
"2025-03-24 18:54:09 - training - INFO - Model Parameters Statistics:\n",
"2025-03-24 18:54:09 - training - INFO - ------------------------\n",
"2025-03-24 18:54:09 - training - INFO - Adapter Model:\n",
"2025-03-24 18:54:09 - training - INFO - Total parameters: 103.68K\n",
"2025-03-24 18:54:09 - training - INFO - Trainable parameters: 103.68K\n",
"2025-03-24 18:54:09 - training - INFO - Pre-trained Model:\n",
"2025-03-24 18:54:09 - training - INFO - Total parameters: 7.84M\n",
"2025-03-24 18:54:09 - training - INFO - Trainable parameters: 7.84M\n",
"2025-03-24 18:54:09 - training - INFO - Combined:\n",
"2025-03-24 18:54:09 - training - INFO - Total parameters: 7.94M\n",
"2025-03-24 18:54:09 - training - INFO - Trainable parameters: 7.94M\n",
"2025-03-24 18:54:09 - training - INFO - Trainable percentage: 100.00%\n",
"2025-03-24 18:54:09 - training - INFO - ------------------------\n",
"2025-03-24 18:54:23 - training - INFO - Dataset Statistics:\n",
"2025-03-24 18:54:23 - training - INFO - ------------------------\n",
"2025-03-24 18:54:23 - training - INFO - Dataset: tyang816/eSOL\n",
"2025-03-24 18:54:23 - training - INFO - Number of train samples: 2481\n",
"2025-03-24 18:54:23 - training - INFO - Number of val samples: 310\n",
"2025-03-24 18:54:23 - training - INFO - Number of test samples: 310\n",
"2025-03-24 18:54:23 - training - INFO - Sample 3 data points from train dataset:\n",
"2025-03-24 18:54:23 - training - INFO - Train data point 1: {'name': 'P0ABL8', 'aa_seq': 'MMFWRIFRLELRVAFRHSAEIANPLWFFLIVITLFPLSIGPEPQLLARIAPGIIWVAALLSSLLALERLFRDDLQDGSLEQLMLLPLPLPAVVLAKVMAHWMVTGLPLLILSPLVAMLLGMDVYGWQVMALTLLLGTPTLGFLGAPGVALTVGLKRGGVLLSILVLPLTIPLLIFATAAMDAASMHLPVDGYLAILGALLAGTATLSPFATAAALRISIQ', 'gene': 'ccmB', 'label': -1.3882626995061573}\n",
"2025-03-24 18:54:23 - training - INFO - Train data point 2: {'name': 'P77721', 'aa_seq': 'MAAKDRIQAIKQMVANDKKVTVSNLSGIFQVTEETIRRDLEKLEDEGFLTRTYGGAVLNTAMLTENIHFYKRASSFYEEKQLIARKALPFIDNKTTMAADSSSTVMELLKLLQDRSGLTLLTNSAEAIHVLAQSEIKVVSTGGELNKNTLSLQGRITKEIIRRYHVDIMVMSCKGLDINSGALDSNEAEAEIKKTMIRQATEVALLVDHSKFDRKAFVQLADFSHINYIITDKSPGAEWIAFCKDNNIQLVW', 'gene': 'ydjF', 'label': -1.3882626995061573}\n",
"2025-03-24 18:54:23 - training - INFO - Train data point 3: {'name': 'Q47152', 'aa_seq': 'MSEYRRYYIKGGTWFFTVNLRNRRSQLLTTQYQMLRHAIIKVKRDRPFEINAWVVLPEHMHCIWTLPEGDDDFSSRWREIKKQFTHACGLKNIWQPRFWEHAIRNTKDYRHHVDYIYINPVKHGWVKQVSDWPFSTFHRDVARGLYPIDWAGDVTDFSAGERIIS', 'gene': 'yafM', 'label': -0.050581678782913427}\n",
"2025-03-24 18:54:23 - training - INFO - ------------------------\n",
"/home/matwings/gejian/anaconda3/envs/venus/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs) # noqa: B028\n",
"2025-03-24 18:54:23 - training - INFO - ---------- Epoch 0 ----------\n",
"Training: 100%|β| 220/220 [01:18<00:00, 2.80it/s, grad_step=27, train_loss=0.36\n",
"2025-03-24 18:55:42 - training - INFO - Epoch 0 Train Loss: 0.7381\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:08<00:00, 3.61it/s]\n",
"2025-03-24 18:55:50 - training - INFO - Epoch 0 Val Loss: 0.5647\n",
"2025-03-24 18:55:50 - training - INFO - Epoch 0 Val mse: 0.5647\n",
"2025-03-24 18:55:50 - training - INFO - Epoch 0 Val spearman_corr: 0.6196\n",
"2025-03-24 18:55:50 - training - INFO - Saving model with best val mse: 0.5647\n",
"2025-03-24 18:55:50 - training - INFO - ---------- Epoch 1 ----------\n",
"Training: 100%|β| 220/220 [01:15<00:00, 2.93it/s, grad_step=55, train_loss=0.20\n",
"2025-03-24 18:57:05 - training - INFO - Epoch 1 Train Loss: 0.5603\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.67it/s]\n",
"2025-03-24 18:57:13 - training - INFO - Epoch 1 Val Loss: 0.6275\n",
"2025-03-24 18:57:13 - training - INFO - Epoch 1 Val mse: 0.6275\n",
"2025-03-24 18:57:13 - training - INFO - Epoch 1 Val spearman_corr: 0.6382\n",
"2025-03-24 18:57:13 - training - INFO - ---------- Epoch 2 ----------\n",
"Training: 100%|β| 220/220 [01:15<00:00, 2.92it/s, grad_step=82, train_loss=0.29\n",
"2025-03-24 18:58:28 - training - INFO - Epoch 2 Train Loss: 0.5605\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.76it/s]\n",
"2025-03-24 18:58:36 - training - INFO - Epoch 2 Val Loss: 0.5153\n",
"2025-03-24 18:58:36 - training - INFO - Epoch 2 Val mse: 0.5153\n",
"2025-03-24 18:58:36 - training - INFO - Epoch 2 Val spearman_corr: 0.6555\n",
"2025-03-24 18:58:36 - training - INFO - Saving model with best val mse: 0.5153\n",
"2025-03-24 18:58:36 - training - INFO - ---------- Epoch 3 ----------\n",
"Training: 100%|β| 220/220 [01:15<00:00, 2.93it/s, grad_step=110, train_loss=0.1\n",
"2025-03-24 18:59:51 - training - INFO - Epoch 3 Train Loss: 0.4674\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.74it/s]\n",
"2025-03-24 18:59:59 - training - INFO - Epoch 3 Val Loss: 0.4965\n",
"2025-03-24 18:59:59 - training - INFO - Epoch 3 Val mse: 0.4965\n",
"2025-03-24 18:59:59 - training - INFO - Epoch 3 Val spearman_corr: 0.6703\n",
"2025-03-24 18:59:59 - training - INFO - Saving model with best val mse: 0.4965\n",
"2025-03-24 18:59:59 - training - INFO - ---------- Epoch 4 ----------\n",
"Training: 100%|β| 220/220 [01:14<00:00, 2.97it/s, grad_step=137, train_loss=0.3\n",
"2025-03-24 19:01:13 - training - INFO - Epoch 4 Train Loss: 0.4356\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.75it/s]\n",
"2025-03-24 19:01:21 - training - INFO - Epoch 4 Val Loss: 0.6908\n",
"2025-03-24 19:01:21 - training - INFO - Epoch 4 Val mse: 0.6908\n",
"2025-03-24 19:01:21 - training - INFO - Epoch 4 Val spearman_corr: 0.6472\n",
"2025-03-24 19:01:21 - training - INFO - ---------- Epoch 5 ----------\n",
"Training: 100%|β| 220/220 [01:15<00:00, 2.92it/s, grad_step=165, train_loss=0.1\n",
"2025-03-24 19:02:36 - training - INFO - Epoch 5 Train Loss: 0.3946\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.67it/s]\n",
"2025-03-24 19:02:44 - training - INFO - Epoch 5 Val Loss: 0.6024\n",
"2025-03-24 19:02:44 - training - INFO - Epoch 5 Val mse: 0.6024\n",
"2025-03-24 19:02:44 - training - INFO - Epoch 5 Val spearman_corr: 0.6578\n",
"2025-03-24 19:02:44 - training - INFO - ---------- Epoch 6 ----------\n",
"Training: 100%|β| 220/220 [01:14<00:00, 2.95it/s, grad_step=192, train_loss=0.0\n",
"2025-03-24 19:03:59 - training - INFO - Epoch 6 Train Loss: 0.3314\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.78it/s]\n",
"2025-03-24 19:04:06 - training - INFO - Epoch 6 Val Loss: 0.5009\n",
"2025-03-24 19:04:06 - training - INFO - Epoch 6 Val mse: 0.5009\n",
"2025-03-24 19:04:06 - training - INFO - Epoch 6 Val spearman_corr: 0.6651\n",
"2025-03-24 19:04:06 - training - INFO - Early stopping triggered after 3 epochs without improvement\n",
"2025-03-24 19:04:06 - training - INFO - Early stop at Epoch 6\n",
"/home/matwings/lc/VenusFactory-readme/src/training/trainer.py:379: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" checkpoint = torch.load(path, map_location=\"cpu\")\n",
"2025-03-24 19:04:06 - training - INFO - ---------- Starting Test Phase ----------\n",
"Testing: 100%|ββββββββββββββββββββββββββββββββββ| 26/26 [00:07<00:00, 3.68it/s]\n",
"2025-03-24 19:04:13 - training - INFO - Test Results:\n",
"2025-03-24 19:04:13 - training - INFO - Test Loss: 0.4056\n",
"2025-03-24 19:04:13 - training - INFO - Test mse: 0.4056\n",
"2025-03-24 19:04:13 - training - INFO - Test spearman_corr: 0.7389\n"
]
}
],
"source": [
"!export HF_ENDPOINT=https://hf-mirror.com # if need to use HF mirror\n",
"dataset=\"eSOL\"\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"lr=5e-4\n",
"training_method=\"full\"\n",
"sh=f\"\"\"\n",
"python src/train.py \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --dataset_config data/{dataset}/{dataset}_HF.json \\\n",
" --learning_rate {lr} \\\n",
" --gradient_accumulation_steps 8 \\\n",
" --num_epochs 10 \\\n",
" --batch_token 8000 \\\n",
" --patience 3 \\\n",
" --output_dir test_res/{dataset}/{plm_model} \\\n",
" --output_model_name {training_method}_lr_{lr}_8k_ga8.pt \\\n",
" --training_method {training_method}\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49b6a180-d009-4434-a5cc-4d0258b7b1af",
"metadata": {},
"outputs": [],
"source": [
"# Use bash script\n",
"!cp ./script/train/train_plm_full.sh ./train_plm_full.sh\n",
"!bash ./train_plm_full.sh"
]
},
{
"cell_type": "markdown",
"id": "6d7cb6b2-76f2-456f-bf36-6bf44a70d3cc",
"metadata": {},
"source": [
"#### Freeze-tuning"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6c19a352-a340-4713-80ed-5d8fdbf8a188",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2025-03-24 20:41:53 - training - INFO - Starting training with configuration:\n",
"2025-03-24 20:41:53 - training - INFO - hidden_size: None\n",
"2025-03-24 20:41:53 - training - INFO - num_attention_head: 8\n",
"2025-03-24 20:41:53 - training - INFO - attention_probs_dropout: 0.1\n",
"2025-03-24 20:41:53 - training - INFO - plm_model: facebook/esm2_t6_8M_UR50D\n",
"2025-03-24 20:41:53 - training - INFO - pooling_method: mean\n",
"2025-03-24 20:41:53 - training - INFO - pooling_dropout: 0.1\n",
"2025-03-24 20:41:53 - training - INFO - dataset: tyang816/eSOL\n",
"2025-03-24 20:41:53 - training - INFO - dataset_config: data/eSOL/eSOL_HF.json\n",
"2025-03-24 20:41:53 - training - INFO - normalize: standard\n",
"2025-03-24 20:41:53 - training - INFO - num_labels: 1\n",
"2025-03-24 20:41:53 - training - INFO - problem_type: regression\n",
"2025-03-24 20:41:53 - training - INFO - pdb_type: None\n",
"2025-03-24 20:41:53 - training - INFO - train_file: None\n",
"2025-03-24 20:41:53 - training - INFO - valid_file: None\n",
"2025-03-24 20:41:53 - training - INFO - test_file: None\n",
"2025-03-24 20:41:53 - training - INFO - metrics: ['mse', 'spearman_corr']\n",
"2025-03-24 20:41:53 - training - INFO - seed: 3407\n",
"2025-03-24 20:41:53 - training - INFO - learning_rate: 0.0005\n",
"2025-03-24 20:41:53 - training - INFO - scheduler: None\n",
"2025-03-24 20:41:53 - training - INFO - warmup_steps: 0\n",
"2025-03-24 20:41:53 - training - INFO - num_workers: 4\n",
"2025-03-24 20:41:53 - training - INFO - batch_size: None\n",
"2025-03-24 20:41:53 - training - INFO - batch_token: 8000\n",
"2025-03-24 20:41:53 - training - INFO - num_epochs: 10\n",
"2025-03-24 20:41:53 - training - INFO - max_seq_len: -1\n",
"2025-03-24 20:41:53 - training - INFO - gradient_accumulation_steps: 8\n",
"2025-03-24 20:41:53 - training - INFO - max_grad_norm: -1\n",
"2025-03-24 20:41:53 - training - INFO - patience: 3\n",
"2025-03-24 20:41:53 - training - INFO - monitor: mse\n",
"2025-03-24 20:41:53 - training - INFO - monitor_strategy: min\n",
"2025-03-24 20:41:53 - training - INFO - training_method: freeze\n",
"2025-03-24 20:41:53 - training - INFO - lora_r: 8\n",
"2025-03-24 20:41:53 - training - INFO - lora_alpha: 32\n",
"2025-03-24 20:41:53 - training - INFO - lora_dropout: 0.1\n",
"2025-03-24 20:41:53 - training - INFO - feedforward_modules: w0\n",
"2025-03-24 20:41:53 - training - INFO - lora_target_modules: ['query', 'key', 'value']\n",
"2025-03-24 20:41:53 - training - INFO - structure_seq: []\n",
"2025-03-24 20:41:53 - training - INFO - output_model_name: freeze_lr_0.0005_8k_ga8.pt\n",
"2025-03-24 20:41:53 - training - INFO - output_root: ckpt\n",
"2025-03-24 20:41:53 - training - INFO - output_dir: ckpt/test_res/eSOL/esm2_t6_8M_UR50D\n",
"2025-03-24 20:41:53 - training - INFO - wandb: False\n",
"2025-03-24 20:41:53 - training - INFO - wandb_entity: None\n",
"2025-03-24 20:41:53 - training - INFO - wandb_project: VenusFactory\n",
"2025-03-24 20:41:53 - training - INFO - wandb_run_name: None\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"2025-03-24 20:41:55 - training - INFO - ------------------------\n",
"2025-03-24 20:41:55 - training - INFO - Model Parameters Statistics:\n",
"2025-03-24 20:41:55 - training - INFO - ------------------------\n",
"2025-03-24 20:41:55 - training - INFO - Adapter Model:\n",
"2025-03-24 20:41:55 - training - INFO - Total parameters: 103.68K\n",
"2025-03-24 20:41:55 - training - INFO - Trainable parameters: 103.68K\n",
"2025-03-24 20:41:55 - training - INFO - Pre-trained Model:\n",
"2025-03-24 20:41:55 - training - INFO - Total parameters: 7.84M\n",
"2025-03-24 20:41:55 - training - INFO - Trainable parameters: 0\n",
"2025-03-24 20:41:55 - training - INFO - Combined:\n",
"2025-03-24 20:41:55 - training - INFO - Total parameters: 7.94M\n",
"2025-03-24 20:41:55 - training - INFO - Trainable parameters: 103.68K\n",
"2025-03-24 20:41:55 - training - INFO - Trainable percentage: 1.31%\n",
"2025-03-24 20:41:55 - training - INFO - ------------------------\n",
"2025-03-24 20:42:08 - training - INFO - Dataset Statistics:\n",
"2025-03-24 20:42:08 - training - INFO - ------------------------\n",
"2025-03-24 20:42:08 - training - INFO - Dataset: tyang816/eSOL\n",
"2025-03-24 20:42:08 - training - INFO - Number of train samples: 2481\n",
"2025-03-24 20:42:08 - training - INFO - Number of val samples: 310\n",
"2025-03-24 20:42:08 - training - INFO - Number of test samples: 310\n",
"2025-03-24 20:42:08 - training - INFO - Sample 3 data points from train dataset:\n",
"2025-03-24 20:42:08 - training - INFO - Train data point 1: {'name': 'P0ABL8', 'aa_seq': 'MMFWRIFRLELRVAFRHSAEIANPLWFFLIVITLFPLSIGPEPQLLARIAPGIIWVAALLSSLLALERLFRDDLQDGSLEQLMLLPLPLPAVVLAKVMAHWMVTGLPLLILSPLVAMLLGMDVYGWQVMALTLLLGTPTLGFLGAPGVALTVGLKRGGVLLSILVLPLTIPLLIFATAAMDAASMHLPVDGYLAILGALLAGTATLSPFATAAALRISIQ', 'gene': 'ccmB', 'label': -1.3882626995061573}\n",
"2025-03-24 20:42:08 - training - INFO - Train data point 2: {'name': 'P77721', 'aa_seq': 'MAAKDRIQAIKQMVANDKKVTVSNLSGIFQVTEETIRRDLEKLEDEGFLTRTYGGAVLNTAMLTENIHFYKRASSFYEEKQLIARKALPFIDNKTTMAADSSSTVMELLKLLQDRSGLTLLTNSAEAIHVLAQSEIKVVSTGGELNKNTLSLQGRITKEIIRRYHVDIMVMSCKGLDINSGALDSNEAEAEIKKTMIRQATEVALLVDHSKFDRKAFVQLADFSHINYIITDKSPGAEWIAFCKDNNIQLVW', 'gene': 'ydjF', 'label': -1.3882626995061573}\n",
"2025-03-24 20:42:08 - training - INFO - Train data point 3: {'name': 'Q47152', 'aa_seq': 'MSEYRRYYIKGGTWFFTVNLRNRRSQLLTTQYQMLRHAIIKVKRDRPFEINAWVVLPEHMHCIWTLPEGDDDFSSRWREIKKQFTHACGLKNIWQPRFWEHAIRNTKDYRHHVDYIYINPVKHGWVKQVSDWPFSTFHRDVARGLYPIDWAGDVTDFSAGERIIS', 'gene': 'yafM', 'label': -0.050581678782913427}\n",
"2025-03-24 20:42:08 - training - INFO - ------------------------\n",
"/home/matwings/gejian/anaconda3/envs/venus/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs) # noqa: B028\n",
"2025-03-24 20:42:08 - training - INFO - ---------- Epoch 0 ----------\n",
"Training: 100%|β| 220/220 [01:00<00:00, 3.66it/s, grad_step=27, train_loss=0.91\n",
"2025-03-24 20:43:08 - training - INFO - Epoch 0 Train Loss: 0.8836\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:08<00:00, 3.58it/s]\n",
"2025-03-24 20:43:16 - training - INFO - Epoch 0 Val Loss: 0.7416\n",
"2025-03-24 20:43:16 - training - INFO - Epoch 0 Val mse: 0.7416\n",
"2025-03-24 20:43:16 - training - INFO - Epoch 0 Val spearman_corr: 0.4969\n",
"2025-03-24 20:43:16 - training - INFO - Saving model with best val mse: 0.7416\n",
"2025-03-24 20:43:16 - training - INFO - ---------- Epoch 1 ----------\n",
"Training: 100%|β| 220/220 [00:58<00:00, 3.79it/s, grad_step=55, train_loss=0.81\n",
"2025-03-24 20:44:14 - training - INFO - Epoch 1 Train Loss: 0.7110\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.79it/s]\n",
"2025-03-24 20:44:22 - training - INFO - Epoch 1 Val Loss: 0.6337\n",
"2025-03-24 20:44:22 - training - INFO - Epoch 1 Val mse: 0.6337\n",
"2025-03-24 20:44:22 - training - INFO - Epoch 1 Val spearman_corr: 0.5577\n",
"2025-03-24 20:44:22 - training - INFO - Saving model with best val mse: 0.6337\n",
"2025-03-24 20:44:22 - training - INFO - ---------- Epoch 2 ----------\n",
"Training: 100%|β| 220/220 [00:56<00:00, 3.92it/s, grad_step=82, train_loss=0.72\n",
"2025-03-24 20:45:18 - training - INFO - Epoch 2 Train Loss: 0.6427\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.93it/s]\n",
"2025-03-24 20:45:25 - training - INFO - Epoch 2 Val Loss: 0.6070\n",
"2025-03-24 20:45:25 - training - INFO - Epoch 2 Val mse: 0.6070\n",
"2025-03-24 20:45:25 - training - INFO - Epoch 2 Val spearman_corr: 0.5811\n",
"2025-03-24 20:45:25 - training - INFO - Saving model with best val mse: 0.6070\n",
"2025-03-24 20:45:25 - training - INFO - ---------- Epoch 3 ----------\n",
"Training: 100%|β| 220/220 [00:56<00:00, 3.92it/s, grad_step=110, train_loss=0.6\n",
"2025-03-24 20:46:21 - training - INFO - Epoch 3 Train Loss: 0.6178\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.92it/s]\n",
"2025-03-24 20:46:29 - training - INFO - Epoch 3 Val Loss: 0.5934\n",
"2025-03-24 20:46:29 - training - INFO - Epoch 3 Val mse: 0.5934\n",
"2025-03-24 20:46:29 - training - INFO - Epoch 3 Val spearman_corr: 0.5946\n",
"2025-03-24 20:46:29 - training - INFO - Saving model with best val mse: 0.5934\n",
"2025-03-24 20:46:29 - training - INFO - ---------- Epoch 4 ----------\n",
"Training: 100%|β| 220/220 [00:55<00:00, 3.96it/s, grad_step=137, train_loss=0.6\n",
"2025-03-24 20:47:24 - training - INFO - Epoch 4 Train Loss: 0.6019\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.10it/s]\n",
"2025-03-24 20:47:32 - training - INFO - Epoch 4 Val Loss: 0.5801\n",
"2025-03-24 20:47:32 - training - INFO - Epoch 4 Val mse: 0.5801\n",
"2025-03-24 20:47:32 - training - INFO - Epoch 4 Val spearman_corr: 0.6055\n",
"2025-03-24 20:47:32 - training - INFO - Saving model with best val mse: 0.5801\n",
"2025-03-24 20:47:32 - training - INFO - ---------- Epoch 5 ----------\n",
"Training: 100%|β| 220/220 [00:53<00:00, 4.12it/s, grad_step=165, train_loss=0.5\n",
"2025-03-24 20:48:25 - training - INFO - Epoch 5 Train Loss: 0.5955\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.01it/s]\n",
"2025-03-24 20:48:32 - training - INFO - Epoch 5 Val Loss: 0.5742\n",
"2025-03-24 20:48:32 - training - INFO - Epoch 5 Val mse: 0.5742\n",
"2025-03-24 20:48:32 - training - INFO - Epoch 5 Val spearman_corr: 0.6125\n",
"2025-03-24 20:48:32 - training - INFO - Saving model with best val mse: 0.5742\n",
"2025-03-24 20:48:32 - training - INFO - ---------- Epoch 6 ----------\n",
"Training: 100%|β| 220/220 [00:54<00:00, 4.05it/s, grad_step=192, train_loss=0.6\n",
"2025-03-24 20:49:27 - training - INFO - Epoch 6 Train Loss: 0.5882\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.01it/s]\n",
"2025-03-24 20:49:34 - training - INFO - Epoch 6 Val Loss: 0.5687\n",
"2025-03-24 20:49:34 - training - INFO - Epoch 6 Val mse: 0.5687\n",
"2025-03-24 20:49:34 - training - INFO - Epoch 6 Val spearman_corr: 0.6162\n",
"2025-03-24 20:49:34 - training - INFO - Saving model with best val mse: 0.5687\n",
"2025-03-24 20:49:34 - training - INFO - ---------- Epoch 7 ----------\n",
"Training: 100%|β| 220/220 [00:54<00:00, 4.06it/s, grad_step=220, train_loss=0.6\n",
"2025-03-24 20:50:28 - training - INFO - Epoch 7 Train Loss: 0.5769\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:06<00:00, 4.15it/s]\n",
"2025-03-24 20:50:35 - training - INFO - Epoch 7 Val Loss: 0.5680\n",
"2025-03-24 20:50:35 - training - INFO - Epoch 7 Val mse: 0.5680\n",
"2025-03-24 20:50:35 - training - INFO - Epoch 7 Val spearman_corr: 0.6189\n",
"2025-03-24 20:50:35 - training - INFO - Saving model with best val mse: 0.5680\n",
"2025-03-24 20:50:35 - training - INFO - ---------- Epoch 8 ----------\n",
"Training: 100%|β| 220/220 [00:52<00:00, 4.20it/s, grad_step=247, train_loss=0.7\n",
"2025-03-24 20:51:27 - training - INFO - Epoch 8 Train Loss: 0.5831\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:06<00:00, 4.16it/s]\n",
"2025-03-24 20:51:34 - training - INFO - Epoch 8 Val Loss: 0.5603\n",
"2025-03-24 20:51:34 - training - INFO - Epoch 8 Val mse: 0.5603\n",
"2025-03-24 20:51:34 - training - INFO - Epoch 8 Val spearman_corr: 0.6219\n",
"2025-03-24 20:51:34 - training - INFO - Saving model with best val mse: 0.5603\n",
"2025-03-24 20:51:34 - training - INFO - ---------- Epoch 9 ----------\n",
"Training: 100%|β| 220/220 [00:51<00:00, 4.28it/s, grad_step=275, train_loss=0.7\n",
"2025-03-24 20:52:26 - training - INFO - Epoch 9 Train Loss: 0.5809\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:06<00:00, 4.20it/s]\n",
"2025-03-24 20:52:33 - training - INFO - Epoch 9 Val Loss: 0.5583\n",
"2025-03-24 20:52:33 - training - INFO - Epoch 9 Val mse: 0.5583\n",
"2025-03-24 20:52:33 - training - INFO - Epoch 9 Val spearman_corr: 0.6222\n",
"2025-03-24 20:52:33 - training - INFO - Saving model with best val mse: 0.5583\n",
"/home/matwings/lc/VenusFactory-readme/src/training/trainer.py:436: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" checkpoint = torch.load(path, map_location=\"cpu\")\n",
"2025-03-24 20:52:33 - training - INFO - ---------- Starting Test Phase ----------\n",
"Testing: 100%|ββββββββββββββββββββββββββββββββββ| 26/26 [00:06<00:00, 4.14it/s]\n",
"2025-03-24 20:52:39 - training - INFO - Test Results:\n",
"2025-03-24 20:52:39 - training - INFO - Test Loss: 0.5484\n",
"2025-03-24 20:52:39 - training - INFO - Test mse: 0.5484\n",
"2025-03-24 20:52:39 - training - INFO - Test spearman_corr: 0.6793\n"
]
}
],
"source": [
"!export HF_ENDPOINT=https://hf-mirror.com # if need to use HF mirror\n",
"dataset=\"eSOL\"\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"lr=5e-4\n",
"training_method=\"freeze\"\n",
"sh=f\"\"\"\n",
"python src/train.py \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --dataset_config data/{dataset}/{dataset}_HF.json \\\n",
" --learning_rate {lr} \\\n",
" --gradient_accumulation_steps 8 \\\n",
" --num_epochs 10 \\\n",
" --batch_token 8000 \\\n",
" --patience 3 \\\n",
" --output_dir test_res/{dataset}/{plm_model} \\\n",
" --output_model_name {training_method}_lr_{lr}_8k_ga8.pt \\\n",
" --training_method {training_method}\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f521343-925e-4c1d-aa0d-e55e7cf896ce",
"metadata": {},
"outputs": [],
"source": [
"# Use bash script\n",
"!cp ./script/train/train_plm_freeze.sh ./train_plm_freeze.sh\n",
"!bash ./train_plm_freeze.sh"
]
},
{
"cell_type": "markdown",
"id": "7b250614-7959-4423-b147-e2c09d18fedd",
"metadata": {},
"source": [
"#### [SES-Adapter](https://arxiv.org/abs/2404.14850)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d1bf246b-dee5-4fb4-b246-df1ca062138f",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2025-03-24 21:05:15 - training - INFO - Starting training with configuration:\n",
"2025-03-24 21:05:15 - training - INFO - hidden_size: None\n",
"2025-03-24 21:05:15 - training - INFO - num_attention_head: 8\n",
"2025-03-24 21:05:15 - training - INFO - attention_probs_dropout: 0.1\n",
"2025-03-24 21:05:15 - training - INFO - plm_model: facebook/esm2_t6_8M_UR50D\n",
"2025-03-24 21:05:15 - training - INFO - pooling_method: mean\n",
"2025-03-24 21:05:15 - training - INFO - pooling_dropout: 0.1\n",
"2025-03-24 21:05:15 - training - INFO - dataset: tyang816/eSOL_AlphaFold2\n",
"2025-03-24 21:05:15 - training - INFO - dataset_config: data/eSOL/eSOL_AlphaFold2_HF.json\n",
"2025-03-24 21:05:15 - training - INFO - normalize: standard\n",
"2025-03-24 21:05:15 - training - INFO - num_labels: 1\n",
"2025-03-24 21:05:15 - training - INFO - problem_type: regression\n",
"2025-03-24 21:05:15 - training - INFO - pdb_type: AlphaFold2\n",
"2025-03-24 21:05:15 - training - INFO - train_file: None\n",
"2025-03-24 21:05:15 - training - INFO - valid_file: None\n",
"2025-03-24 21:05:15 - training - INFO - test_file: None\n",
"2025-03-24 21:05:15 - training - INFO - metrics: ['mse', 'spearman_corr']\n",
"2025-03-24 21:05:15 - training - INFO - seed: 3407\n",
"2025-03-24 21:05:15 - training - INFO - learning_rate: 0.0005\n",
"2025-03-24 21:05:15 - training - INFO - scheduler: None\n",
"2025-03-24 21:05:15 - training - INFO - warmup_steps: 0\n",
"2025-03-24 21:05:15 - training - INFO - num_workers: 4\n",
"2025-03-24 21:05:15 - training - INFO - batch_size: None\n",
"2025-03-24 21:05:15 - training - INFO - batch_token: 8000\n",
"2025-03-24 21:05:15 - training - INFO - num_epochs: 10\n",
"2025-03-24 21:05:15 - training - INFO - max_seq_len: -1\n",
"2025-03-24 21:05:15 - training - INFO - gradient_accumulation_steps: 8\n",
"2025-03-24 21:05:15 - training - INFO - max_grad_norm: -1\n",
"2025-03-24 21:05:15 - training - INFO - patience: 3\n",
"2025-03-24 21:05:15 - training - INFO - monitor: mse\n",
"2025-03-24 21:05:15 - training - INFO - monitor_strategy: min\n",
"2025-03-24 21:05:15 - training - INFO - training_method: ses-adapter\n",
"2025-03-24 21:05:15 - training - INFO - lora_r: 8\n",
"2025-03-24 21:05:15 - training - INFO - lora_alpha: 32\n",
"2025-03-24 21:05:15 - training - INFO - lora_dropout: 0.1\n",
"2025-03-24 21:05:15 - training - INFO - feedforward_modules: w0\n",
"2025-03-24 21:05:15 - training - INFO - lora_target_modules: ['query', 'key', 'value']\n",
"2025-03-24 21:05:15 - training - INFO - structure_seq: ['foldseek_seq', 'ss8_seq']\n",
"2025-03-24 21:05:15 - training - INFO - output_model_name: ses-adapter_AlphaFold2_lr_0.0005_bt8k_ga8.pt\n",
"2025-03-24 21:05:15 - training - INFO - output_root: ckpt\n",
"2025-03-24 21:05:15 - training - INFO - output_dir: ckpt/test_res/eSOL/esm2_t6_8M_UR50D\n",
"2025-03-24 21:05:15 - training - INFO - wandb: False\n",
"2025-03-24 21:05:15 - training - INFO - wandb_entity: None\n",
"2025-03-24 21:05:15 - training - INFO - wandb_project: VenusFactory\n",
"2025-03-24 21:05:15 - training - INFO - wandb_run_name: None\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"2025-03-24 21:05:16 - training - INFO - ------------------------\n",
"2025-03-24 21:05:16 - training - INFO - Model Parameters Statistics:\n",
"2025-03-24 21:05:16 - training - INFO - ------------------------\n",
"2025-03-24 21:05:16 - training - INFO - Adapter Model:\n",
"2025-03-24 21:05:16 - training - INFO - Total parameters: 946.56K\n",
"2025-03-24 21:05:16 - training - INFO - Trainable parameters: 946.56K\n",
"2025-03-24 21:05:16 - training - INFO - Pre-trained Model:\n",
"2025-03-24 21:05:16 - training - INFO - Total parameters: 7.84M\n",
"2025-03-24 21:05:16 - training - INFO - Trainable parameters: 0\n",
"2025-03-24 21:05:16 - training - INFO - Combined:\n",
"2025-03-24 21:05:16 - training - INFO - Total parameters: 8.79M\n",
"2025-03-24 21:05:16 - training - INFO - Trainable parameters: 946.56K\n",
"2025-03-24 21:05:16 - training - INFO - Trainable percentage: 10.77%\n",
"2025-03-24 21:05:16 - training - INFO - ------------------------\n",
"2025-03-24 21:05:27 - training - INFO - Dataset Statistics:\n",
"2025-03-24 21:05:27 - training - INFO - ------------------------\n",
"2025-03-24 21:05:27 - training - INFO - Dataset: tyang816/eSOL_AlphaFold2\n",
"2025-03-24 21:05:27 - training - INFO - Number of train samples: 2481\n",
"2025-03-24 21:05:27 - training - INFO - Number of val samples: 310\n",
"2025-03-24 21:05:27 - training - INFO - Number of test samples: 310\n",
"2025-03-24 21:05:27 - training - INFO - Sample 3 data points from train dataset:\n",
"2025-03-24 21:05:27 - training - INFO - Train data point 1: {'name': 'P0ABL8', 'aa_seq': 'MMFWRIFRLELRVAFRHSAEIANPLWFFLIVITLFPLSIGPEPQLLARIAPGIIWVAALLSSLLALERLFRDDLQDGSLEQLMLLPLPLPAVVLAKVMAHWMVTGLPLLILSPLVAMLLGMDVYGWQVMALTLLLGTPTLGFLGAPGVALTVGLKRGGVLLSILVLPLTIPLLIFATAAMDAASMHLPVDGYLAILGALLAGTATLSPFATAAALRISIQ', 'ss8_seq': 'LHHHHHHHHHHHHHHHLHHHHHHHHHHHHHHHHHHHHHHLLLHHHHHHHHHHHHHHHHHHHHHHHHTTTTHHHHHHTHHHHHHTSSSLHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTLLHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTSTTHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTLLLHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHL', 'ss3_seq': 'CHHHHHHHHHHHHHHHCHHHHHHHHHHHHHHHHHHHHHHCCCHHHHHHHHHHHHHHHHHHHHHHHHCCCCHHHHHHCHHHHHHCCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHCCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHCCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHC', 'foldseek_seq': 'DVLVVLLVVLLVVLVVPVVLQPPLVVQLVCQLVVVCVVVDVDLVVLLVCLLVSNVVSLLVSLLSSLLCLCQVCVVVVVNVVLLPDPDDQLSSLLSSLVSSCCRHLVVSLVCLVVSCVSNVHDPQLSVLLSLLSVLLSSLLSLLQLLQSLLQSLPPPSSVSSCVRSVVVCPVSSVLSSVLSVCSVVVHDNVVSSVVSNVSSVVSSVPSSVSNSVSNVVNVD', 'esm3_structure_seq': '[3300, 2109, 3790, 1265, 1450, 3097, 706, 2082, 1197, 3986, 3112, 1197, 195, 1079, 1295, 2439, 76, 2605, 2605, 153, 2626, 264, 3010, 850, 3923, 1450, 2205, 3390, 1938, 588, 1592, 2664, 904, 2103, 2416, 3741, 3954, 137, 699, 1670, 137, 2524, 2958, 1542, 2874, 1101, 2147, 2296, 1012, 245, 1287, 2182, 2386, 3833, 1802, 2660, 1733, 3097, 1379, 3837, 752, 340, 620, 405, 1742, 3465, 3649, 3271, 2896, 102, 2869, 2290, 195, 2546, 1411, 1542, 2504, 3954, 2848, 3604, 3735, 3877, 861, 321, 435, 3681, 2053, 1738, 153, 3222, 668, 3230, 3873, 1984, 2728, 2301, 3278, 195, 3122, 1999, 2684, 1894, 3668, 2417, 3057, 1656, 1750, 2385, 4019, 2056, 2162, 74, 1677, 2082, 1342, 283, 588, 3580, 2355, 3156, 1884, 2490, 1825, 2182, 2805, 811, 1476, 2480, 2964, 2700, 2299, 3802, 2651, 2208, 3030, 2480, 2040, 1365, 183, 2705, 1530, 757, 2123, 2153, 2530, 356, 226, 958, 1629, 3142, 2187, 877, 3383, 754, 1561, 886, 3598, 2535, 1187, 2660, 538, 2964, 3259, 2440, 3671, 2366, 2156, 498, 2926, 2144, 2007, 1108, 1630, 2048, 23, 2851, 3251, 1352, 283, 1919, 64, 3309, 2884, 1431, 1701, 1607, 1385, 1729, 2013, 1495, 824, 3155, 1767, 3850, 414, 1201, 1297, 2279, 2401, 168, 2874, 3842, 983, 2974, 2459, 1701, 3838, 1304, 1768, 2130, 3171, 3079, 2874, 3079, 3960, 262, 1012, 2944, 3785, 987]', 'plddt': 92.57876802884616, 'gene': 'ccmB', 'label': -1.3882626995061573}\n",
"2025-03-24 21:05:27 - training - INFO - Train data point 2: {'name': 'P77721', 'aa_seq': 'MAAKDRIQAIKQMVANDKKVTVSNLSGIFQVTEETIRRDLEKLEDEGFLTRTYGGAVLNTAMLTENIHFYKRASSFYEEKQLIARKALPFIDNKTTMAADSSSTVMELLKLLQDRSGLTLLTNSAEAIHVLAQSEIKVVSTGGELNKNTLSLQGRITKEIIRRYHVDIMVMSCKGLDINSGALDSNEAEAEIKKTMIRQATEVALLVDHSKFDRKAFVQLADFSHINYIITDKSPGAEWIAFCKDNNIQLVW', 'ss8_seq': 'LLHHHHHHHHHHHHHHHSEEEHHHHHHHHTLLHHHHHHHHHHHHHTTSEEEETTEEEELHHHHHHTSHHHHHHHTTHHHHHHHHHHHHHHHTTLSEEEELSLHHHHHHHHHTTTLTTLEEEELBHHHHHHTTTSSSEEEELLLEEETTTTEEESHHHHHHHTTLLEEEEEELLSEEETTTEEEESLHHHHHHHHHHHTTEEEEEEELLGGGTTLLLSEEEELGGGLSEEELSSLLLHHHHHHHHHTTLEEEL', 'ss3_seq': 'CCHHHHHHHHHHHHHHHCEEEHHHHHHHHCCCHHHHHHHHHHHHHCCCEEEECCEEEECHHHHHHCCHHHHHHHCCHHHHHHHHHHHHHHHCCCCEEEECCCHHHHHHHHHCCCCCCCEEEECEHHHHHHCCCCCCEEEECCCEEECCCCEEECHHHHHHHCCCCEEEEEECCCEEECCCEEEECCHHHHHHHHHHHCCEEEEEEECCHHHCCCCCCEEEECHHHCCEEECCCCCCHHHHHHHHHCCCEEEC', 'foldseek_seq': 'DDLVVLLVVVLVVLVVVQKDFLVVSCVVVVHDSVSVVVSVVVCVVVVQWDDDVRIIGGPPVVVQCQQQQVVQCPPCVVFLLLLLVQCVVVCVPWQEEEEELHSSVLSNLLVCQADLSHEYEYQAQVSCVVNVVGSYNYDYLAADADRRSNGHFDDSSLVSLVVAATQEYEYEAQAAALVQGGHHQDPGSLVSQLSNLVRYPAYEYEDEQVRYNDDHDGRRDHPLSHQEYTGSDDHDPVVVVSCVVSNHHYYD', 'esm3_structure_seq': '[1333, 318, 1842, 247, 3607, 517, 338, 1197, 2814, 1264, 1445, 987, 195, 3287, 681, 3902, 532, 3658, 3070, 2269, 2995, 689, 1112, 137, 2306, 2461, 987, 316, 3518, 750, 3915, 3847, 763, 2056, 803, 209, 2225, 2082, 3547, 1758, 3148, 1803, 2927, 3119, 3310, 3961, 3974, 939, 2131, 1412, 1877, 1763, 335, 2412, 1331, 239, 1646, 2062, 208, 2682, 3310, 1381, 1317, 278, 3413, 2082, 530, 2586, 3007, 1978, 1450, 3704, 210, 3789, 2279, 720, 1495, 2531, 3287, 2318, 3922, 681, 670, 3291, 855, 2147, 2869, 1471, 300, 1421, 282, 3158, 1561, 407, 845, 3280, 35, 1119, 290, 4003, 3283, 2068, 2858, 2675, 1203, 2898, 209, 3671, 1651, 2319, 4035, 1667, 550, 3643, 1983, 4088, 2801, 2189, 1383, 2996, 529, 2624, 2041, 325, 3628, 2176, 704, 2537, 2874, 1619, 1215, 3598, 2726, 632, 2776, 1632, 907, 526, 2852, 1447, 3891, 2612, 2485, 846, 269, 121, 2211, 204, 2188, 3750, 1493, 3973, 1661, 611, 921, 1497, 299, 1079, 264, 199, 4094, 2382, 3789, 780, 1462, 2275, 3794, 621, 2503, 1659, 1665, 1602, 480, 2818, 68, 954, 1190, 629, 2459, 1161, 28, 1811, 1895, 1884, 2182, 2854, 3740, 1016, 518, 2318, 2814, 508, 2401, 1264, 773, 2535, 2972, 4081, 3839, 2478, 2938, 448, 1539, 984, 2503, 2834, 3751, 1457, 908, 985, 1107, 2926, 3658, 1993, 3188, 678, 1235, 654, 1111, 1576, 3197, 2017, 1850, 75, 175, 1447, 3921, 956, 4051, 2315, 3975, 662, 1164, 2276, 273, 3337, 1034, 2653, 631, 1677, 278, 631, 1077, 1476, 3056, 3125, 1474, 986, 1412, 3399, 3673, 2174]', 'plddt': 94.28275479313824, 'gene': 'ydjF', 'label': -1.3882626995061573}\n",
"2025-03-24 21:05:27 - training - INFO - Train data point 3: {'name': 'Q47152', 'aa_seq': 'MSEYRRYYIKGGTWFFTVNLRNRRSQLLTTQYQMLRHAIIKVKRDRPFEINAWVVLPEHMHCIWTLPEGDDDFSSRWREIKKQFTHACGLKNIWQPRFWEHAIRNTKDYRHHVDYIYINPVKHGWVKQVSDWPFSTFHRDVARGLYPIDWAGDVTDFSAGERIIS', 'ss8_seq': 'LLLLLLLLLTTLEEEEEEEBSSTTLLHHHHTHHHHHHHHHHHHHHSLLEEEEEEELSSEEEEEEELLTTLLLHHHHHHHHHHHHHHHTTLSSLBLSSLEEEELLSHHHHHHHHHHHHHHHHHTTSLSSGGGLLSBSHHHHHHTTSSLTTLLLLLLLLLSSLLLLL', 'ss3_seq': 'CCCCCCCCCCCCEEEEEEEECCCCCCHHHHCHHHHHHHHHHHHHHCCCEEEEEEECCCEEEEEEECCCCCCCHHHHHHHHHHHHHHHCCCCCCECCCCEEEECCCHHHHHHHHHHHHHHHHHCCCCCCHHHCCCECHHHHHHCCCCCCCCCCCCCCCCCCCCCCC', 'foldseek_seq': 'DPPPDDDFDFQFKKKKKWFWPPPPAQCCVVVVVQLVVQLVVLCVVANWDWPKKADGRGMIITIIGHHGPDRPPVVSVVSSVVSSCVVVVHPRGTDPDMDMDTQDDPVSVVVSRLVRQLVCVLVVVDVFSLVGPRICPVVCVVVVNDPSRDNHDPDDDPNVPDDDD', 'esm3_structure_seq': '[754, 3715, 2572, 2744, 2875, 1046, 1169, 3270, 874, 3003, 2027, 564, 3202, 180, 590, 2458, 3912, 2567, 378, 3709, 2856, 1532, 3665, 4090, 1189, 559, 74, 916, 2697, 2874, 2066, 3598, 2874, 4053, 4042, 803, 1450, 3321, 2835, 264, 2298, 571, 1651, 588, 4090, 4073, 3787, 1276, 2785, 1066, 2162, 775, 1267, 3661, 1447, 2324, 798, 2289, 1146, 3619, 3335, 2328, 3284, 57, 593, 3261, 379, 3447, 322, 49, 3858, 2269, 542, 1598, 3101, 3878, 3416, 4025, 3826, 400, 361, 987, 2579, 595, 1456, 1197, 3230, 208, 2641, 2967, 3965, 2308, 3598, 3867, 1144, 824, 1404, 3995, 3821, 2229, 303, 1457, 2218, 1892, 216, 1995, 3056, 588, 1598, 264, 195, 2400, 2535, 2426, 2212, 2944, 1793, 214, 1658, 2623, 2064, 3243, 262, 3254, 953, 3997, 247, 495, 4000, 1456, 305, 2588, 1265, 2535, 1239, 2139, 233, 3674, 2082, 3271, 3119, 123, 3954, 3660, 316, 2854, 1412, 2061, 36, 2607, 167, 2809, 1412, 252, 3946, 1570, 27, 116, 2273, 3156, 1176, 780, 1140, 481, 4047]', 'plddt': 92.01135628952916, 'gene': 'yafM', 'label': -0.050581678782913427}\n",
"2025-03-24 21:05:27 - training - INFO - ------------------------\n",
"/home/matwings/gejian/anaconda3/envs/venus/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs) # noqa: B028\n",
"2025-03-24 21:05:27 - training - INFO - ---------- Epoch 0 ----------\n",
"Training: 100%|β| 222/222 [01:01<00:00, 3.59it/s, grad_step=27, train_loss=0.31\n",
"2025-03-24 21:06:29 - training - INFO - Epoch 0 Train Loss: 0.7993\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.79it/s]\n",
"2025-03-24 21:06:36 - training - INFO - Epoch 0 Val Loss: 0.6134\n",
"2025-03-24 21:06:36 - training - INFO - Epoch 0 Val mse: 0.6134\n",
"2025-03-24 21:06:36 - training - INFO - Epoch 0 Val spearman_corr: 0.5819\n",
"2025-03-24 21:06:36 - training - INFO - Saving model with best val mse: 0.6134\n",
"2025-03-24 21:06:36 - training - INFO - ---------- Epoch 1 ----------\n",
"Training: 100%|β| 222/222 [01:00<00:00, 3.69it/s, grad_step=55, train_loss=0.32\n",
"2025-03-24 21:07:36 - training - INFO - Epoch 1 Train Loss: 0.6030\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.92it/s]\n",
"2025-03-24 21:07:44 - training - INFO - Epoch 1 Val Loss: 0.5433\n",
"2025-03-24 21:07:44 - training - INFO - Epoch 1 Val mse: 0.5433\n",
"2025-03-24 21:07:44 - training - INFO - Epoch 1 Val spearman_corr: 0.6266\n",
"2025-03-24 21:07:44 - training - INFO - Saving model with best val mse: 0.5433\n",
"2025-03-24 21:07:44 - training - INFO - ---------- Epoch 2 ----------\n",
"Training: 100%|β| 222/222 [00:59<00:00, 3.71it/s, grad_step=83, train_loss=0.43\n",
"2025-03-24 21:08:44 - training - INFO - Epoch 2 Train Loss: 0.5441\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.05it/s]\n",
"2025-03-24 21:08:51 - training - INFO - Epoch 2 Val Loss: 0.5223\n",
"2025-03-24 21:08:51 - training - INFO - Epoch 2 Val mse: 0.5223\n",
"2025-03-24 21:08:51 - training - INFO - Epoch 2 Val spearman_corr: 0.6515\n",
"2025-03-24 21:08:51 - training - INFO - Saving model with best val mse: 0.5223\n",
"2025-03-24 21:08:51 - training - INFO - ---------- Epoch 3 ----------\n",
"Training: 100%|β| 222/222 [00:57<00:00, 3.87it/s, grad_step=111, train_loss=0.5\n",
"2025-03-24 21:09:48 - training - INFO - Epoch 3 Train Loss: 0.5074\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.06it/s]\n",
"2025-03-24 21:09:55 - training - INFO - Epoch 3 Val Loss: 0.5381\n",
"2025-03-24 21:09:55 - training - INFO - Epoch 3 Val mse: 0.5381\n",
"2025-03-24 21:09:55 - training - INFO - Epoch 3 Val spearman_corr: 0.6578\n",
"2025-03-24 21:09:55 - training - INFO - ---------- Epoch 4 ----------\n",
"Training: 100%|β| 222/222 [00:57<00:00, 3.86it/s, grad_step=138, train_loss=0.4\n",
"2025-03-24 21:10:53 - training - INFO - Epoch 4 Train Loss: 0.4726\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.07it/s]\n",
"2025-03-24 21:11:00 - training - INFO - Epoch 4 Val Loss: 0.5230\n",
"2025-03-24 21:11:00 - training - INFO - Epoch 4 Val mse: 0.5230\n",
"2025-03-24 21:11:00 - training - INFO - Epoch 4 Val spearman_corr: 0.6611\n",
"2025-03-24 21:11:00 - training - INFO - ---------- Epoch 5 ----------\n",
"Training: 100%|β| 222/222 [00:57<00:00, 3.86it/s, grad_step=166, train_loss=0.3\n",
"2025-03-24 21:11:57 - training - INFO - Epoch 5 Train Loss: 0.4205\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.07it/s]\n",
"2025-03-24 21:12:05 - training - INFO - Epoch 5 Val Loss: 0.5309\n",
"2025-03-24 21:12:05 - training - INFO - Epoch 5 Val mse: 0.5309\n",
"2025-03-24 21:12:05 - training - INFO - Epoch 5 Val spearman_corr: 0.6537\n",
"2025-03-24 21:12:05 - training - INFO - Early stopping triggered after 3 epochs without improvement\n",
"2025-03-24 21:12:05 - training - INFO - Early stop at Epoch 5\n",
"/home/matwings/lc/VenusFactory-readme/src/training/trainer.py:436: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" checkpoint = torch.load(path, map_location=\"cpu\")\n",
"2025-03-24 21:12:05 - training - INFO - ---------- Starting Test Phase ----------\n",
"Testing: 100%|ββββββββββββββββββββββββββββββββββ| 26/26 [00:06<00:00, 3.96it/s]\n",
"2025-03-24 21:12:11 - training - INFO - Test Results:\n",
"2025-03-24 21:12:11 - training - INFO - Test Loss: 0.5255\n",
"2025-03-24 21:12:11 - training - INFO - Test mse: 0.5255\n",
"2025-03-24 21:12:11 - training - INFO - Test spearman_corr: 0.6920\n"
]
}
],
"source": [
"!export HF_ENDPOINT=https://hf-mirror.com # if need to use HF mirror\n",
"dataset=\"eSOL\"\n",
"pdb_type=\"AlphaFold2\"\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"lr=5e-4\n",
"training_method=\"ses-adapter\"\n",
"sh=f\"\"\"\n",
"python src/train.py \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --dataset_config data/{dataset}/{dataset}_{pdb_type}_HF.json \\\n",
" --learning_rate {lr} \\\n",
" --num_epochs 10 \\\n",
" --batch_token 8000 \\\n",
" --gradient_accumulation_steps 8 \\\n",
" --patience 3 \\\n",
" --structure_seq foldseek_seq,ss8_seq \\\n",
" --output_dir test_res/{dataset}/{plm_model} \\\n",
" --training_method {training_method} \\\n",
" --output_model_name ses-adapter_{pdb_type}_lr_{lr}_bt8k_ga8.pt\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7fc683c8-6925-4f71-b50d-b9bc942d9773",
"metadata": {},
"outputs": [],
"source": [
"# Use bash script\n",
"!cp ./script/train/train_plm_ses-adapter.sh ./train_plm_ses-adapter.sh\n",
"!bash ./train_plm_ses-adapter.sh"
]
},
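{
"cell_type": "markdown",
"id": "ses-adapter-note-3f2a1c45",
"metadata": {},
"source": [
"Note: SES-Adapter is the only method in this guide that consumes structural information. `--structure_seq foldseek_seq,ss8_seq` tells the trainer to read the `foldseek_seq` and `ss8_seq` fields that appear in the logged training samples above, so the dataset config must be a structure-aware variant (here `eSOL_AlphaFold2_HF.json`). The sequence-only methods below use the plain `eSOL_HF.json` config and leave `structure_seq` empty, as their configuration logs show."
]
},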
{
"cell_type": "markdown",
"id": "fb7a34d3-fddd-4b97-a6d8-ebcd0625e22e",
"metadata": {},
"source": [
"#### [LoRA](https://arxiv.org/abs/2106.09685)"
]
},
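{
"cell_type": "markdown",
"id": "lora-sketch-9b4c27de",
"metadata": {},
"source": [
"LoRA freezes the pre-trained weights and learns a low-rank update `(alpha / r) * B @ A` that is added to each targeted projection matrix; the `lora_r`, `lora_alpha`, `lora_dropout` and `lora_target_modules` values in the configuration log below map directly onto the rank, the scaling factor, the adapter dropout, and the modules that receive the update. A minimal sketch of the idea (illustrative only, not VenusFactory's implementation):\n",
"\n",
"```python\n",
"# Minimal LoRA sketch: wrap a frozen nn.Linear with a trainable low-rank update.\n",
"import torch.nn as nn\n",
"\n",
"class LoRALinear(nn.Module):\n",
"    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 32, dropout: float = 0.1):\n",
"        super().__init__()\n",
"        self.base = base\n",
"        for p in self.base.parameters():\n",
"            p.requires_grad = False                                 # pre-trained weight stays frozen\n",
"        self.lora_A = nn.Linear(base.in_features, r, bias=False)   # down-projection\n",
"        self.lora_B = nn.Linear(r, base.out_features, bias=False)  # up-projection\n",
"        nn.init.zeros_(self.lora_B.weight)                         # start as a zero update\n",
"        self.dropout = nn.Dropout(dropout)\n",
"        self.scaling = alpha / r\n",
"\n",
"    def forward(self, x):\n",
"        return self.base(x) + self.scaling * self.lora_B(self.lora_A(self.dropout(x)))\n",
"```"
]
},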
{
"cell_type": "code",
"execution_count": 5,
"id": "0ee6f362-6738-4bce-b28e-58ce5416e070",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2025-03-24 21:17:16 - training - INFO - Starting training with configuration:\n",
"2025-03-24 21:17:16 - training - INFO - hidden_size: None\n",
"2025-03-24 21:17:16 - training - INFO - num_attention_head: 8\n",
"2025-03-24 21:17:16 - training - INFO - attention_probs_dropout: 0.1\n",
"2025-03-24 21:17:16 - training - INFO - plm_model: facebook/esm2_t6_8M_UR50D\n",
"2025-03-24 21:17:16 - training - INFO - pooling_method: mean\n",
"2025-03-24 21:17:16 - training - INFO - pooling_dropout: 0.1\n",
"2025-03-24 21:17:16 - training - INFO - dataset: tyang816/eSOL\n",
"2025-03-24 21:17:16 - training - INFO - dataset_config: data/eSOL/eSOL_HF.json\n",
"2025-03-24 21:17:16 - training - INFO - normalize: standard\n",
"2025-03-24 21:17:16 - training - INFO - num_labels: 1\n",
"2025-03-24 21:17:16 - training - INFO - problem_type: regression\n",
"2025-03-24 21:17:16 - training - INFO - pdb_type: None\n",
"2025-03-24 21:17:16 - training - INFO - train_file: None\n",
"2025-03-24 21:17:16 - training - INFO - valid_file: None\n",
"2025-03-24 21:17:16 - training - INFO - test_file: None\n",
"2025-03-24 21:17:16 - training - INFO - metrics: ['mse', 'spearman_corr']\n",
"2025-03-24 21:17:16 - training - INFO - seed: 3407\n",
"2025-03-24 21:17:16 - training - INFO - learning_rate: 0.0005\n",
"2025-03-24 21:17:16 - training - INFO - scheduler: None\n",
"2025-03-24 21:17:16 - training - INFO - warmup_steps: 0\n",
"2025-03-24 21:17:16 - training - INFO - num_workers: 4\n",
"2025-03-24 21:17:16 - training - INFO - batch_size: None\n",
"2025-03-24 21:17:16 - training - INFO - batch_token: 8000\n",
"2025-03-24 21:17:16 - training - INFO - num_epochs: 10\n",
"2025-03-24 21:17:16 - training - INFO - max_seq_len: -1\n",
"2025-03-24 21:17:16 - training - INFO - gradient_accumulation_steps: 8\n",
"2025-03-24 21:17:16 - training - INFO - max_grad_norm: -1\n",
"2025-03-24 21:17:16 - training - INFO - patience: 3\n",
"2025-03-24 21:17:16 - training - INFO - monitor: mse\n",
"2025-03-24 21:17:16 - training - INFO - monitor_strategy: min\n",
"2025-03-24 21:17:16 - training - INFO - training_method: plm-lora\n",
"2025-03-24 21:17:16 - training - INFO - lora_r: 8\n",
"2025-03-24 21:17:16 - training - INFO - lora_alpha: 32\n",
"2025-03-24 21:17:16 - training - INFO - lora_dropout: 0.1\n",
"2025-03-24 21:17:16 - training - INFO - feedforward_modules: w0\n",
"2025-03-24 21:17:16 - training - INFO - lora_target_modules: ['query', 'key', 'value']\n",
"2025-03-24 21:17:16 - training - INFO - structure_seq: []\n",
"2025-03-24 21:17:16 - training - INFO - output_model_name: plm-lora_lr_0.0005_8k_ga8.pt\n",
"2025-03-24 21:17:16 - training - INFO - output_root: ckpt\n",
"2025-03-24 21:17:16 - training - INFO - output_dir: ckpt/test_res/eSOL/esm2_t6_8M_UR50D\n",
"2025-03-24 21:17:16 - training - INFO - wandb: False\n",
"2025-03-24 21:17:16 - training - INFO - wandb_entity: None\n",
"2025-03-24 21:17:16 - training - INFO - wandb_project: VenusFactory\n",
"2025-03-24 21:17:16 - training - INFO - wandb_run_name: None\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"trainable params: 92,160 || all params: 7,932,281 || trainable%: 1.1618\n",
"2025-03-24 21:17:18 - training - INFO - ------------------------\n",
"2025-03-24 21:17:18 - training - INFO - Model Parameters Statistics:\n",
"2025-03-24 21:17:18 - training - INFO - ------------------------\n",
"2025-03-24 21:17:18 - training - INFO - Adapter Model:\n",
"2025-03-24 21:17:18 - training - INFO - Total parameters: 103.04K\n",
"2025-03-24 21:17:18 - training - INFO - Trainable parameters: 103.04K\n",
"2025-03-24 21:17:18 - training - INFO - Pre-trained Model:\n",
"2025-03-24 21:17:18 - training - INFO - Total parameters: 7.93M\n",
"2025-03-24 21:17:18 - training - INFO - Trainable parameters: 92.16K\n",
"2025-03-24 21:17:18 - training - INFO - Combined:\n",
"2025-03-24 21:17:18 - training - INFO - Total parameters: 8.04M\n",
"2025-03-24 21:17:18 - training - INFO - Trainable parameters: 195.20K\n",
"2025-03-24 21:17:18 - training - INFO - Trainable percentage: 2.43%\n",
"2025-03-24 21:17:18 - training - INFO - ------------------------\n",
"2025-03-24 21:17:30 - training - INFO - Dataset Statistics:\n",
"2025-03-24 21:17:30 - training - INFO - ------------------------\n",
"2025-03-24 21:17:30 - training - INFO - Dataset: tyang816/eSOL\n",
"2025-03-24 21:17:30 - training - INFO - Number of train samples: 2481\n",
"2025-03-24 21:17:30 - training - INFO - Number of val samples: 310\n",
"2025-03-24 21:17:30 - training - INFO - Number of test samples: 310\n",
"2025-03-24 21:17:30 - training - INFO - Sample 3 data points from train dataset:\n",
"2025-03-24 21:17:30 - training - INFO - Train data point 1: {'name': 'P0ABL8', 'aa_seq': 'MMFWRIFRLELRVAFRHSAEIANPLWFFLIVITLFPLSIGPEPQLLARIAPGIIWVAALLSSLLALERLFRDDLQDGSLEQLMLLPLPLPAVVLAKVMAHWMVTGLPLLILSPLVAMLLGMDVYGWQVMALTLLLGTPTLGFLGAPGVALTVGLKRGGVLLSILVLPLTIPLLIFATAAMDAASMHLPVDGYLAILGALLAGTATLSPFATAAALRISIQ', 'gene': 'ccmB', 'label': -1.3882626995061573}\n",
"2025-03-24 21:17:30 - training - INFO - Train data point 2: {'name': 'P77721', 'aa_seq': 'MAAKDRIQAIKQMVANDKKVTVSNLSGIFQVTEETIRRDLEKLEDEGFLTRTYGGAVLNTAMLTENIHFYKRASSFYEEKQLIARKALPFIDNKTTMAADSSSTVMELLKLLQDRSGLTLLTNSAEAIHVLAQSEIKVVSTGGELNKNTLSLQGRITKEIIRRYHVDIMVMSCKGLDINSGALDSNEAEAEIKKTMIRQATEVALLVDHSKFDRKAFVQLADFSHINYIITDKSPGAEWIAFCKDNNIQLVW', 'gene': 'ydjF', 'label': -1.3882626995061573}\n",
"2025-03-24 21:17:30 - training - INFO - Train data point 3: {'name': 'Q47152', 'aa_seq': 'MSEYRRYYIKGGTWFFTVNLRNRRSQLLTTQYQMLRHAIIKVKRDRPFEINAWVVLPEHMHCIWTLPEGDDDFSSRWREIKKQFTHACGLKNIWQPRFWEHAIRNTKDYRHHVDYIYINPVKHGWVKQVSDWPFSTFHRDVARGLYPIDWAGDVTDFSAGERIIS', 'gene': 'yafM', 'label': -0.050581678782913427}\n",
"2025-03-24 21:17:30 - training - INFO - ------------------------\n",
"/home/matwings/gejian/anaconda3/envs/venus/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs) # noqa: B028\n",
"2025-03-24 21:17:30 - training - INFO - ---------- Epoch 0 ----------\n",
"Training: 100%|β| 223/223 [01:19<00:00, 2.82it/s, grad_step=27, train_loss=0.75\n",
"2025-03-24 21:18:49 - training - INFO - Epoch 0 Train Loss: 0.8292\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.86it/s]\n",
"2025-03-24 21:18:57 - training - INFO - Epoch 0 Val Loss: 0.6874\n",
"2025-03-24 21:18:57 - training - INFO - Epoch 0 Val mse: 0.6874\n",
"2025-03-24 21:18:57 - training - INFO - Epoch 0 Val spearman_corr: 0.5289\n",
"2025-03-24 21:18:57 - training - INFO - Saving model with best val mse: 0.6874\n",
"2025-03-24 21:18:58 - training - INFO - ---------- Epoch 1 ----------\n",
"Training: 100%|β| 223/223 [01:19<00:00, 2.79it/s, grad_step=55, train_loss=0.83\n",
"2025-03-24 21:20:18 - training - INFO - Epoch 1 Train Loss: 0.6019\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.98it/s]\n",
"2025-03-24 21:20:25 - training - INFO - Epoch 1 Val Loss: 0.5628\n",
"2025-03-24 21:20:25 - training - INFO - Epoch 1 Val mse: 0.5628\n",
"2025-03-24 21:20:25 - training - INFO - Epoch 1 Val spearman_corr: 0.6149\n",
"2025-03-24 21:20:25 - training - INFO - Saving model with best val mse: 0.5628\n",
"2025-03-24 21:20:25 - training - INFO - ---------- Epoch 2 ----------\n",
"Training: 100%|β| 223/223 [01:20<00:00, 2.79it/s, grad_step=83, train_loss=0.81\n",
"2025-03-24 21:21:45 - training - INFO - Epoch 2 Train Loss: 0.5198\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.97it/s]\n",
"2025-03-24 21:21:53 - training - INFO - Epoch 2 Val Loss: 0.5345\n",
"2025-03-24 21:21:53 - training - INFO - Epoch 2 Val mse: 0.5345\n",
"2025-03-24 21:21:53 - training - INFO - Epoch 2 Val spearman_corr: 0.6437\n",
"2025-03-24 21:21:53 - training - INFO - Saving model with best val mse: 0.5345\n",
"2025-03-24 21:21:53 - training - INFO - ---------- Epoch 3 ----------\n",
"Training: 100%|β| 223/223 [01:20<00:00, 2.78it/s, grad_step=111, train_loss=0.6\n",
"2025-03-24 21:23:14 - training - INFO - Epoch 3 Train Loss: 0.4739\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.98it/s]\n",
"2025-03-24 21:23:21 - training - INFO - Epoch 3 Val Loss: 0.5141\n",
"2025-03-24 21:23:21 - training - INFO - Epoch 3 Val mse: 0.5141\n",
"2025-03-24 21:23:21 - training - INFO - Epoch 3 Val spearman_corr: 0.6567\n",
"2025-03-24 21:23:21 - training - INFO - Saving model with best val mse: 0.5141\n",
"2025-03-24 21:23:21 - training - INFO - ---------- Epoch 4 ----------\n",
"Training: 100%|β| 223/223 [01:20<00:00, 2.78it/s, grad_step=139, train_loss=0.7\n",
"2025-03-24 21:24:42 - training - INFO - Epoch 4 Train Loss: 0.4375\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.01it/s]\n",
"2025-03-24 21:24:49 - training - INFO - Epoch 4 Val Loss: 0.5324\n",
"2025-03-24 21:24:49 - training - INFO - Epoch 4 Val mse: 0.5324\n",
"2025-03-24 21:24:49 - training - INFO - Epoch 4 Val spearman_corr: 0.6519\n",
"2025-03-24 21:24:49 - training - INFO - ---------- Epoch 5 ----------\n",
"Training: 100%|β| 223/223 [01:20<00:00, 2.79it/s, grad_step=167, train_loss=0.5\n",
"2025-03-24 21:26:09 - training - INFO - Epoch 5 Train Loss: 0.4068\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.98it/s]\n",
"2025-03-24 21:26:16 - training - INFO - Epoch 5 Val Loss: 0.5343\n",
"2025-03-24 21:26:16 - training - INFO - Epoch 5 Val mse: 0.5343\n",
"2025-03-24 21:26:16 - training - INFO - Epoch 5 Val spearman_corr: 0.6483\n",
"2025-03-24 21:26:16 - training - INFO - ---------- Epoch 6 ----------\n",
"Training: 100%|β| 223/223 [01:19<00:00, 2.80it/s, grad_step=195, train_loss=0.5\n",
"2025-03-24 21:27:36 - training - INFO - Epoch 6 Train Loss: 0.3660\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.98it/s]\n",
"2025-03-24 21:27:43 - training - INFO - Epoch 6 Val Loss: 0.5543\n",
"2025-03-24 21:27:43 - training - INFO - Epoch 6 Val mse: 0.5543\n",
"2025-03-24 21:27:43 - training - INFO - Epoch 6 Val spearman_corr: 0.6436\n",
"2025-03-24 21:27:43 - training - INFO - Early stopping triggered after 3 epochs without improvement\n",
"2025-03-24 21:27:43 - training - INFO - Early stop at Epoch 6\n",
"/home/matwings/lc/VenusFactory-readme/src/training/trainer.py:385: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" checkpoint = torch.load(path, map_location=\"cpu\")\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"2025-03-24 21:27:45 - training - INFO - ---------- Starting Test Phase ----------\n",
"Testing: 100%|ββββββββββββββββββββββββββββββββββ| 26/26 [00:06<00:00, 3.90it/s]\n",
"2025-03-24 21:27:51 - training - INFO - Test Results:\n",
"2025-03-24 21:27:51 - training - INFO - Test Loss: 0.4329\n",
"2025-03-24 21:27:51 - training - INFO - Test mse: 0.4329\n",
"2025-03-24 21:27:51 - training - INFO - Test spearman_corr: 0.7341\n"
]
}
],
"source": [
"# ESM model target_modules name: query key value\n",
"# Bert_base(prot_bert) model target_modules name: query key value\n",
"# T5_base(ankh, t5) model target_modules name: q k v\n",
"\n",
"!export HF_ENDPOINT=https://hf-mirror.com # if need to use HF mirror\n",
"dataset=\"eSOL\"\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"lr=5e-4\n",
"training_method=\"plm-lora\"\n",
"sh=f\"\"\"\n",
"python src/train.py \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --dataset_config data/{dataset}/{dataset}_HF.json \\\n",
" --learning_rate {lr} \\\n",
" --gradient_accumulation_steps 8 \\\n",
" --num_epochs 10 \\\n",
" --batch_token 8000 \\\n",
" --patience 3 \\\n",
" --output_dir test_res/{dataset}/{plm_model} \\\n",
" --output_model_name {training_method}_lr_{lr}_8k_ga8.pt \\\n",
" --training_method {training_method} \\\n",
" --lora_target_modules query key value\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2018d36-1469-42b0-92fd-7ef165f28663",
"metadata": {},
"outputs": [],
"source": [
"# Use bash script\n",
"!cp ./script/train/train_plm_lora.sh ./train_plm_lora.sh\n",
"!bash ./train_plm_lora.sh"
]
},
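{
"cell_type": "markdown",
"id": "ckpt-note-51d8e0a2",
"metadata": {},
"source": [
"Checkpoints are written to the logged `output_dir` with the file name passed via `--output_model_name`; for the run above the configuration log gives `ckpt/test_res/eSOL/esm2_t6_8M_UR50D/plm-lora_lr_0.0005_8k_ga8.pt` (assuming the trainer saves the file directly under that directory). A quick sanity check that the file loads; the exact contents depend on the trainer:\n",
"\n",
"```python\n",
"# Sanity-check that the saved checkpoint loads; path taken from the configuration log above.\n",
"import torch\n",
"\n",
"ckpt_path = \"ckpt/test_res/eSOL/esm2_t6_8M_UR50D/plm-lora_lr_0.0005_8k_ga8.pt\"\n",
"checkpoint = torch.load(ckpt_path, map_location=\"cpu\")\n",
"print(type(checkpoint))\n",
"```"
]
},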
{
"cell_type": "markdown",
"id": "85c3fc20-32ab-4fa1-861d-a5985ecf70aa",
"metadata": {},
"source": [
"#### [AdaLoRA](https://arxiv.org/abs/2303.10512)"
]
},
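{
"cell_type": "markdown",
"id": "adalora-note-7c3ba6f1",
"metadata": {},
"source": [
"AdaLoRA starts from a LoRA-style decomposition but adaptively re-allocates the rank budget across layers during training, pruning the least important singular directions. With the standard `peft` library this corresponds roughly to an `AdaLoraConfig` like the sketch below; the rank/alpha/dropout/target_modules values match the configuration log of the cell below, while the remaining values are illustrative assumptions, not VenusFactory's exact settings:\n",
"\n",
"```python\n",
"# Rough AdaLoRA configuration sketch with the peft library (assumed stack, illustrative values).\n",
"from peft import AdaLoraConfig\n",
"\n",
"adalora_cfg = AdaLoraConfig(\n",
"    init_r=12,          # starting rank per target module (assumed)\n",
"    target_r=8,         # average rank budget after pruning, cf. lora_r in the log\n",
"    lora_alpha=32,\n",
"    lora_dropout=0.1,\n",
"    target_modules=[\"query\", \"key\", \"value\"],\n",
"    total_step=2250,    # total optimizer steps; required by recent peft releases (assumed value)\n",
")\n",
"```"
]
},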
{
"cell_type": "code",
"execution_count": 6,
"id": "7d0688c6-f6c5-4ba1-94f0-ae73cc729658",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2025-03-24 21:34:14 - training - INFO - Starting training with configuration:\n",
"2025-03-24 21:34:14 - training - INFO - hidden_size: None\n",
"2025-03-24 21:34:14 - training - INFO - num_attention_head: 8\n",
"2025-03-24 21:34:14 - training - INFO - attention_probs_dropout: 0.1\n",
"2025-03-24 21:34:14 - training - INFO - plm_model: facebook/esm2_t6_8M_UR50D\n",
"2025-03-24 21:34:14 - training - INFO - pooling_method: mean\n",
"2025-03-24 21:34:14 - training - INFO - pooling_dropout: 0.1\n",
"2025-03-24 21:34:14 - training - INFO - dataset: tyang816/eSOL\n",
"2025-03-24 21:34:14 - training - INFO - dataset_config: data/eSOL/eSOL_HF.json\n",
"2025-03-24 21:34:14 - training - INFO - normalize: standard\n",
"2025-03-24 21:34:14 - training - INFO - num_labels: 1\n",
"2025-03-24 21:34:14 - training - INFO - problem_type: regression\n",
"2025-03-24 21:34:14 - training - INFO - pdb_type: None\n",
"2025-03-24 21:34:14 - training - INFO - train_file: None\n",
"2025-03-24 21:34:14 - training - INFO - valid_file: None\n",
"2025-03-24 21:34:14 - training - INFO - test_file: None\n",
"2025-03-24 21:34:14 - training - INFO - metrics: ['mse', 'spearman_corr']\n",
"2025-03-24 21:34:14 - training - INFO - seed: 3407\n",
"2025-03-24 21:34:14 - training - INFO - learning_rate: 0.0005\n",
"2025-03-24 21:34:14 - training - INFO - scheduler: None\n",
"2025-03-24 21:34:14 - training - INFO - warmup_steps: 0\n",
"2025-03-24 21:34:14 - training - INFO - num_workers: 4\n",
"2025-03-24 21:34:14 - training - INFO - batch_size: None\n",
"2025-03-24 21:34:14 - training - INFO - batch_token: 8000\n",
"2025-03-24 21:34:14 - training - INFO - num_epochs: 10\n",
"2025-03-24 21:34:14 - training - INFO - max_seq_len: -1\n",
"2025-03-24 21:34:14 - training - INFO - gradient_accumulation_steps: 8\n",
"2025-03-24 21:34:14 - training - INFO - max_grad_norm: -1\n",
"2025-03-24 21:34:14 - training - INFO - patience: 3\n",
"2025-03-24 21:34:14 - training - INFO - monitor: mse\n",
"2025-03-24 21:34:14 - training - INFO - monitor_strategy: min\n",
"2025-03-24 21:34:14 - training - INFO - training_method: plm-adalora\n",
"2025-03-24 21:34:14 - training - INFO - lora_r: 8\n",
"2025-03-24 21:34:14 - training - INFO - lora_alpha: 32\n",
"2025-03-24 21:34:14 - training - INFO - lora_dropout: 0.1\n",
"2025-03-24 21:34:14 - training - INFO - feedforward_modules: w0\n",
"2025-03-24 21:34:14 - training - INFO - lora_target_modules: ['query', 'key', 'value']\n",
"2025-03-24 21:34:14 - training - INFO - structure_seq: []\n",
"2025-03-24 21:34:14 - training - INFO - output_model_name: plm-adalora_lr_0.0005_8k_ga8.pt\n",
"2025-03-24 21:34:14 - training - INFO - output_root: ckpt\n",
"2025-03-24 21:34:14 - training - INFO - output_dir: ckpt/test_res/eSOL/esm2_t6_8M_UR50D\n",
"2025-03-24 21:34:14 - training - INFO - wandb: False\n",
"2025-03-24 21:34:14 - training - INFO - wandb_entity: None\n",
"2025-03-24 21:34:14 - training - INFO - wandb_project: VenusFactory\n",
"2025-03-24 21:34:14 - training - INFO - wandb_run_name: None\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"trainable params: 138,456 || all params: 7,978,595 || trainable%: 1.7353\n",
" Using plm adalora \n",
"2025-03-24 21:34:16 - training - INFO - ------------------------\n",
"2025-03-24 21:34:16 - training - INFO - Model Parameters Statistics:\n",
"2025-03-24 21:34:16 - training - INFO - ------------------------\n",
"2025-03-24 21:34:16 - training - INFO - Adapter Model:\n",
"2025-03-24 21:34:16 - training - INFO - Total parameters: 103.04K\n",
"2025-03-24 21:34:16 - training - INFO - Trainable parameters: 103.04K\n",
"2025-03-24 21:34:16 - training - INFO - Pre-trained Model:\n",
"2025-03-24 21:34:16 - training - INFO - Total parameters: 7.98M\n",
"2025-03-24 21:34:16 - training - INFO - Trainable parameters: 138.46K\n",
"2025-03-24 21:34:16 - training - INFO - Combined:\n",
"2025-03-24 21:34:16 - training - INFO - Total parameters: 8.08M\n",
"2025-03-24 21:34:16 - training - INFO - Trainable parameters: 241.50K\n",
"2025-03-24 21:34:16 - training - INFO - Trainable percentage: 2.99%\n",
"2025-03-24 21:34:16 - training - INFO - ------------------------\n",
"2025-03-24 21:34:28 - training - INFO - Dataset Statistics:\n",
"2025-03-24 21:34:28 - training - INFO - ------------------------\n",
"2025-03-24 21:34:28 - training - INFO - Dataset: tyang816/eSOL\n",
"2025-03-24 21:34:28 - training - INFO - Number of train samples: 2481\n",
"2025-03-24 21:34:28 - training - INFO - Number of val samples: 310\n",
"2025-03-24 21:34:28 - training - INFO - Number of test samples: 310\n",
"2025-03-24 21:34:28 - training - INFO - Sample 3 data points from train dataset:\n",
"2025-03-24 21:34:28 - training - INFO - Train data point 1: {'name': 'P0ABL8', 'aa_seq': 'MMFWRIFRLELRVAFRHSAEIANPLWFFLIVITLFPLSIGPEPQLLARIAPGIIWVAALLSSLLALERLFRDDLQDGSLEQLMLLPLPLPAVVLAKVMAHWMVTGLPLLILSPLVAMLLGMDVYGWQVMALTLLLGTPTLGFLGAPGVALTVGLKRGGVLLSILVLPLTIPLLIFATAAMDAASMHLPVDGYLAILGALLAGTATLSPFATAAALRISIQ', 'gene': 'ccmB', 'label': -1.3882626995061573}\n",
"2025-03-24 21:34:28 - training - INFO - Train data point 2: {'name': 'P77721', 'aa_seq': 'MAAKDRIQAIKQMVANDKKVTVSNLSGIFQVTEETIRRDLEKLEDEGFLTRTYGGAVLNTAMLTENIHFYKRASSFYEEKQLIARKALPFIDNKTTMAADSSSTVMELLKLLQDRSGLTLLTNSAEAIHVLAQSEIKVVSTGGELNKNTLSLQGRITKEIIRRYHVDIMVMSCKGLDINSGALDSNEAEAEIKKTMIRQATEVALLVDHSKFDRKAFVQLADFSHINYIITDKSPGAEWIAFCKDNNIQLVW', 'gene': 'ydjF', 'label': -1.3882626995061573}\n",
"2025-03-24 21:34:28 - training - INFO - Train data point 3: {'name': 'Q47152', 'aa_seq': 'MSEYRRYYIKGGTWFFTVNLRNRRSQLLTTQYQMLRHAIIKVKRDRPFEINAWVVLPEHMHCIWTLPEGDDDFSSRWREIKKQFTHACGLKNIWQPRFWEHAIRNTKDYRHHVDYIYINPVKHGWVKQVSDWPFSTFHRDVARGLYPIDWAGDVTDFSAGERIIS', 'gene': 'yafM', 'label': -0.050581678782913427}\n",
"2025-03-24 21:34:28 - training - INFO - ------------------------\n",
"/home/matwings/gejian/anaconda3/envs/venus/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs) # noqa: B028\n",
"2025-03-24 21:34:28 - training - INFO - ---------- Epoch 0 ----------\n",
"Training: 100%|β| 225/225 [01:17<00:00, 2.90it/s, grad_step=28, train_loss=0.58\n",
"2025-03-24 21:35:46 - training - INFO - Epoch 0 Train Loss: 0.8808\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:06<00:00, 4.29it/s]\n",
"2025-03-24 21:35:53 - training - INFO - Epoch 0 Val Loss: 0.7379\n",
"2025-03-24 21:35:53 - training - INFO - Epoch 0 Val mse: 0.7379\n",
"2025-03-24 21:35:53 - training - INFO - Epoch 0 Val spearman_corr: 0.5034\n",
"2025-03-24 21:35:53 - training - INFO - Saving model with best val mse: 0.7379\n",
"2025-03-24 21:35:53 - training - INFO - ---------- Epoch 1 ----------\n",
"Training: 100%|β| 225/225 [01:21<00:00, 2.75it/s, grad_step=56, train_loss=0.26\n",
"2025-03-24 21:37:15 - training - INFO - Epoch 1 Train Loss: 0.6948\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.94it/s]\n",
"2025-03-24 21:37:23 - training - INFO - Epoch 1 Val Loss: 0.6145\n",
"2025-03-24 21:37:23 - training - INFO - Epoch 1 Val mse: 0.6145\n",
"2025-03-24 21:37:23 - training - INFO - Epoch 1 Val spearman_corr: 0.5697\n",
"2025-03-24 21:37:23 - training - INFO - Saving model with best val mse: 0.6145\n",
"2025-03-24 21:37:24 - training - INFO - ---------- Epoch 2 ----------\n",
"Training: 100%|β| 225/225 [01:21<00:00, 2.77it/s, grad_step=84, train_loss=0.25\n",
"2025-03-24 21:38:45 - training - INFO - Epoch 2 Train Loss: 0.6118\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.07it/s]\n",
"2025-03-24 21:38:52 - training - INFO - Epoch 2 Val Loss: 0.5769\n",
"2025-03-24 21:38:52 - training - INFO - Epoch 2 Val mse: 0.5769\n",
"2025-03-24 21:38:52 - training - INFO - Epoch 2 Val spearman_corr: 0.5989\n",
"2025-03-24 21:38:52 - training - INFO - Saving model with best val mse: 0.5769\n",
"2025-03-24 21:38:53 - training - INFO - ---------- Epoch 3 ----------\n",
"Training: 100%|β| 225/225 [01:20<00:00, 2.78it/s, grad_step=112, train_loss=0.1\n",
"2025-03-24 21:40:14 - training - INFO - Epoch 3 Train Loss: 0.5717\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.99it/s]\n",
"2025-03-24 21:40:21 - training - INFO - Epoch 3 Val Loss: 0.5501\n",
"2025-03-24 21:40:21 - training - INFO - Epoch 3 Val mse: 0.5501\n",
"2025-03-24 21:40:21 - training - INFO - Epoch 3 Val spearman_corr: 0.6225\n",
"2025-03-24 21:40:21 - training - INFO - Saving model with best val mse: 0.5501\n",
"2025-03-24 21:40:22 - training - INFO - ---------- Epoch 4 ----------\n",
"Training: 100%|β| 225/225 [01:21<00:00, 2.77it/s, grad_step=140, train_loss=0.2\n",
"2025-03-24 21:41:43 - training - INFO - Epoch 4 Train Loss: 0.5419\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.95it/s]\n",
"2025-03-24 21:41:50 - training - INFO - Epoch 4 Val Loss: 0.5258\n",
"2025-03-24 21:41:50 - training - INFO - Epoch 4 Val mse: 0.5258\n",
"2025-03-24 21:41:50 - training - INFO - Epoch 4 Val spearman_corr: 0.6417\n",
"2025-03-24 21:41:50 - training - INFO - Saving model with best val mse: 0.5258\n",
"2025-03-24 21:41:52 - training - INFO - ---------- Epoch 5 ----------\n",
"Training: 100%|β| 225/225 [01:21<00:00, 2.75it/s, grad_step=168, train_loss=0.2\n",
"2025-03-24 21:43:13 - training - INFO - Epoch 5 Train Loss: 0.5177\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.93it/s]\n",
"2025-03-24 21:43:21 - training - INFO - Epoch 5 Val Loss: 0.5116\n",
"2025-03-24 21:43:21 - training - INFO - Epoch 5 Val mse: 0.5116\n",
"2025-03-24 21:43:21 - training - INFO - Epoch 5 Val spearman_corr: 0.6548\n",
"2025-03-24 21:43:21 - training - INFO - Saving model with best val mse: 0.5116\n",
"2025-03-24 21:43:23 - training - INFO - ---------- Epoch 6 ----------\n",
"Training: 100%|β| 225/225 [01:20<00:00, 2.81it/s, grad_step=196, train_loss=0.2\n",
"2025-03-24 21:44:43 - training - INFO - Epoch 6 Train Loss: 0.4937\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.13it/s]\n",
"2025-03-24 21:44:50 - training - INFO - Epoch 6 Val Loss: 0.5095\n",
"2025-03-24 21:44:50 - training - INFO - Epoch 6 Val mse: 0.5095\n",
"2025-03-24 21:44:50 - training - INFO - Epoch 6 Val spearman_corr: 0.6635\n",
"2025-03-24 21:44:50 - training - INFO - Saving model with best val mse: 0.5095\n",
"2025-03-24 21:44:51 - training - INFO - ---------- Epoch 7 ----------\n",
"Training: 100%|β| 225/225 [01:19<00:00, 2.83it/s, grad_step=225, train_loss=0.2\n",
"2025-03-24 21:46:10 - training - INFO - Epoch 7 Train Loss: 0.4745\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:06<00:00, 4.15it/s]\n",
"2025-03-24 21:46:17 - training - INFO - Epoch 7 Val Loss: 0.4992\n",
"2025-03-24 21:46:17 - training - INFO - Epoch 7 Val mse: 0.4992\n",
"2025-03-24 21:46:17 - training - INFO - Epoch 7 Val spearman_corr: 0.6705\n",
"2025-03-24 21:46:17 - training - INFO - Saving model with best val mse: 0.4992\n",
"2025-03-24 21:46:18 - training - INFO - ---------- Epoch 8 ----------\n",
"Training: 100%|β| 225/225 [01:19<00:00, 2.84it/s, grad_step=253, train_loss=0.2\n",
"2025-03-24 21:47:37 - training - INFO - Epoch 8 Train Loss: 0.4494\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.12it/s]\n",
"2025-03-24 21:47:44 - training - INFO - Epoch 8 Val Loss: 0.4986\n",
"2025-03-24 21:47:44 - training - INFO - Epoch 8 Val mse: 0.4986\n",
"2025-03-24 21:47:44 - training - INFO - Epoch 8 Val spearman_corr: 0.6714\n",
"2025-03-24 21:47:44 - training - INFO - Saving model with best val mse: 0.4986\n",
"2025-03-24 21:47:45 - training - INFO - ---------- Epoch 9 ----------\n",
"Training: 100%|β| 225/225 [01:18<00:00, 2.86it/s, grad_step=281, train_loss=0.2\n",
"2025-03-24 21:49:04 - training - INFO - Epoch 9 Train Loss: 0.4335\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:06<00:00, 4.20it/s]\n",
"2025-03-24 21:49:11 - training - INFO - Epoch 9 Val Loss: 0.4927\n",
"2025-03-24 21:49:11 - training - INFO - Epoch 9 Val mse: 0.4927\n",
"2025-03-24 21:49:11 - training - INFO - Epoch 9 Val spearman_corr: 0.6717\n",
"2025-03-24 21:49:11 - training - INFO - Saving model with best val mse: 0.4927\n",
"/home/matwings/lc/VenusFactory-readme/src/training/trainer.py:417: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" checkpoint = torch.load(path, map_location=\"cpu\")\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"2025-03-24 21:49:12 - training - INFO - ---------- Starting Test Phase ----------\n",
"Testing: 100%|ββββββββββββββββββββββββββββββββββ| 26/26 [00:06<00:00, 4.05it/s]\n",
"2025-03-24 21:49:19 - training - INFO - Test Results:\n",
"2025-03-24 21:49:19 - training - INFO - Test Loss: 0.4385\n",
"2025-03-24 21:49:19 - training - INFO - Test mse: 0.4385\n",
"2025-03-24 21:49:19 - training - INFO - Test spearman_corr: 0.7376\n"
]
}
],
"source": [
"# ESM model target_modules name: query key value\n",
"# Bert_base(prot_bert) model target_modules name: query key value\n",
"# T5_base(ankh, t5) model target_modules name: q k v\n",
"\n",
"!export HF_ENDPOINT=https://hf-mirror.com # if need to use HF mirror\n",
"dataset=\"eSOL\"\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"lr=5e-4\n",
"training_method=\"plm-adalora\"\n",
"sh=f\"\"\"\n",
"python src/train.py \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --dataset_config data/{dataset}/{dataset}_HF.json \\\n",
" --learning_rate {lr} \\\n",
" --gradient_accumulation_steps 8 \\\n",
" --num_epochs 10 \\\n",
" --batch_token 8000 \\\n",
" --patience 3 \\\n",
" --output_dir test_res/{dataset}/{plm_model} \\\n",
" --output_model_name {training_method}_lr_{lr}_8k_ga8.pt \\\n",
" --training_method {training_method} \\\n",
" --lora_target_modules query key value\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e43e299-c849-4464-97ac-c09c6413dfb2",
"metadata": {},
"outputs": [],
"source": [
"# Use bash script\n",
"!cp ./script/train/train_plm_adalora.sh ./train_plm_adalora.sh\n",
"!bash ./train_plm_adalora.sh"
]
},
{
"cell_type": "markdown",
"id": "889e22dc-054d-4269-a113-8430ed758360",
"metadata": {},
"source": [
"#### [QLoRA](https://arxiv.org/abs/2305.14314)"
]
},
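{
"cell_type": "markdown",
"id": "qlora-note-2e9f40cb",
"metadata": {},
"source": [
"QLoRA keeps the same LoRA adapters but loads the frozen backbone in 4-bit precision, which is why the output below prints the quantization notice about `low_cpu_mem_usage` and reports a much smaller pre-trained parameter count than the plain LoRA run. A rough sketch of the usual `bitsandbytes` + `peft` setup (an assumed stack, not VenusFactory's exact code):\n",
"\n",
"```python\n",
"# Rough QLoRA setup sketch: 4-bit NF4 base model + LoRA adapters (assumed stack).\n",
"import torch\n",
"from transformers import AutoModel, BitsAndBytesConfig\n",
"from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
"\n",
"bnb_cfg = BitsAndBytesConfig(\n",
"    load_in_4bit=True,\n",
"    bnb_4bit_quant_type=\"nf4\",\n",
"    bnb_4bit_compute_dtype=torch.bfloat16,\n",
")\n",
"base = AutoModel.from_pretrained(\"facebook/esm2_t6_8M_UR50D\", quantization_config=bnb_cfg)\n",
"base = prepare_model_for_kbit_training(base)\n",
"lora_cfg = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1,\n",
"                      target_modules=[\"query\", \"key\", \"value\"])\n",
"model = get_peft_model(base, lora_cfg)\n",
"model.print_trainable_parameters()\n",
"```"
]
},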
{
"cell_type": "code",
"execution_count": 7,
"id": "9a6d6866-9da9-4230-b7d3-df86a5201ba0",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2025-03-24 21:52:56 - training - INFO - Starting training with configuration:\n",
"2025-03-24 21:52:56 - training - INFO - hidden_size: None\n",
"2025-03-24 21:52:56 - training - INFO - num_attention_head: 8\n",
"2025-03-24 21:52:56 - training - INFO - attention_probs_dropout: 0.1\n",
"2025-03-24 21:52:56 - training - INFO - plm_model: facebook/esm2_t6_8M_UR50D\n",
"2025-03-24 21:52:56 - training - INFO - pooling_method: mean\n",
"2025-03-24 21:52:56 - training - INFO - pooling_dropout: 0.1\n",
"2025-03-24 21:52:56 - training - INFO - dataset: tyang816/eSOL\n",
"2025-03-24 21:52:56 - training - INFO - dataset_config: data/eSOL/eSOL_HF.json\n",
"2025-03-24 21:52:56 - training - INFO - normalize: standard\n",
"2025-03-24 21:52:56 - training - INFO - num_labels: 1\n",
"2025-03-24 21:52:56 - training - INFO - problem_type: regression\n",
"2025-03-24 21:52:56 - training - INFO - pdb_type: None\n",
"2025-03-24 21:52:56 - training - INFO - train_file: None\n",
"2025-03-24 21:52:56 - training - INFO - valid_file: None\n",
"2025-03-24 21:52:56 - training - INFO - test_file: None\n",
"2025-03-24 21:52:56 - training - INFO - metrics: ['mse', 'spearman_corr']\n",
"2025-03-24 21:52:56 - training - INFO - seed: 3407\n",
"2025-03-24 21:52:56 - training - INFO - learning_rate: 0.0005\n",
"2025-03-24 21:52:56 - training - INFO - scheduler: None\n",
"2025-03-24 21:52:56 - training - INFO - warmup_steps: 0\n",
"2025-03-24 21:52:56 - training - INFO - num_workers: 4\n",
"2025-03-24 21:52:56 - training - INFO - batch_size: None\n",
"2025-03-24 21:52:56 - training - INFO - batch_token: 8000\n",
"2025-03-24 21:52:56 - training - INFO - num_epochs: 10\n",
"2025-03-24 21:52:56 - training - INFO - max_seq_len: -1\n",
"2025-03-24 21:52:56 - training - INFO - gradient_accumulation_steps: 8\n",
"2025-03-24 21:52:56 - training - INFO - max_grad_norm: -1\n",
"2025-03-24 21:52:56 - training - INFO - patience: 3\n",
"2025-03-24 21:52:56 - training - INFO - monitor: mse\n",
"2025-03-24 21:52:56 - training - INFO - monitor_strategy: min\n",
"2025-03-24 21:52:56 - training - INFO - training_method: plm-qlora\n",
"2025-03-24 21:52:56 - training - INFO - lora_r: 8\n",
"2025-03-24 21:52:56 - training - INFO - lora_alpha: 32\n",
"2025-03-24 21:52:56 - training - INFO - lora_dropout: 0.1\n",
"2025-03-24 21:52:56 - training - INFO - feedforward_modules: w0\n",
"2025-03-24 21:52:56 - training - INFO - lora_target_modules: ['query', 'key', 'value']\n",
"2025-03-24 21:52:56 - training - INFO - structure_seq: []\n",
"2025-03-24 21:52:56 - training - INFO - output_model_name: plm-qlora_lr_0.0005_8k_ga8.pt\n",
"2025-03-24 21:52:56 - training - INFO - output_root: ckpt\n",
"2025-03-24 21:52:56 - training - INFO - output_dir: ckpt/test_res/eSOL/esm2_t6_8M_UR50D\n",
"2025-03-24 21:52:56 - training - INFO - wandb: False\n",
"2025-03-24 21:52:56 - training - INFO - wandb_entity: None\n",
"2025-03-24 21:52:56 - training - INFO - wandb_project: VenusFactory\n",
"2025-03-24 21:52:56 - training - INFO - wandb_run_name: None\n",
"`low_cpu_mem_usage` was None, now default to True since model is quantized.\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"trainable params: 92,160 || all params: 7,932,281 || trainable%: 1.1618\n",
"2025-03-24 21:52:59 - training - INFO - ------------------------\n",
"2025-03-24 21:52:59 - training - INFO - Model Parameters Statistics:\n",
"2025-03-24 21:52:59 - training - INFO - ------------------------\n",
"2025-03-24 21:52:59 - training - INFO - Adapter Model:\n",
"2025-03-24 21:52:59 - training - INFO - Total parameters: 103.04K\n",
"2025-03-24 21:52:59 - training - INFO - Trainable parameters: 103.04K\n",
"2025-03-24 21:52:59 - training - INFO - Pre-trained Model:\n",
"2025-03-24 21:52:59 - training - INFO - Total parameters: 4.19M\n",
"2025-03-24 21:52:59 - training - INFO - Trainable parameters: 92.16K\n",
"2025-03-24 21:52:59 - training - INFO - Combined:\n",
"2025-03-24 21:52:59 - training - INFO - Total parameters: 4.30M\n",
"2025-03-24 21:52:59 - training - INFO - Trainable parameters: 195.20K\n",
"2025-03-24 21:52:59 - training - INFO - Trainable percentage: 4.54%\n",
"2025-03-24 21:52:59 - training - INFO - ------------------------\n",
"2025-03-24 21:53:10 - training - INFO - Dataset Statistics:\n",
"2025-03-24 21:53:10 - training - INFO - ------------------------\n",
"2025-03-24 21:53:10 - training - INFO - Dataset: tyang816/eSOL\n",
"2025-03-24 21:53:10 - training - INFO - Number of train samples: 2481\n",
"2025-03-24 21:53:10 - training - INFO - Number of val samples: 310\n",
"2025-03-24 21:53:10 - training - INFO - Number of test samples: 310\n",
"2025-03-24 21:53:10 - training - INFO - Sample 3 data points from train dataset:\n",
"2025-03-24 21:53:10 - training - INFO - Train data point 1: {'name': 'P0ABL8', 'aa_seq': 'MMFWRIFRLELRVAFRHSAEIANPLWFFLIVITLFPLSIGPEPQLLARIAPGIIWVAALLSSLLALERLFRDDLQDGSLEQLMLLPLPLPAVVLAKVMAHWMVTGLPLLILSPLVAMLLGMDVYGWQVMALTLLLGTPTLGFLGAPGVALTVGLKRGGVLLSILVLPLTIPLLIFATAAMDAASMHLPVDGYLAILGALLAGTATLSPFATAAALRISIQ', 'gene': 'ccmB', 'label': -1.3882626995061573}\n",
"2025-03-24 21:53:10 - training - INFO - Train data point 2: {'name': 'P77721', 'aa_seq': 'MAAKDRIQAIKQMVANDKKVTVSNLSGIFQVTEETIRRDLEKLEDEGFLTRTYGGAVLNTAMLTENIHFYKRASSFYEEKQLIARKALPFIDNKTTMAADSSSTVMELLKLLQDRSGLTLLTNSAEAIHVLAQSEIKVVSTGGELNKNTLSLQGRITKEIIRRYHVDIMVMSCKGLDINSGALDSNEAEAEIKKTMIRQATEVALLVDHSKFDRKAFVQLADFSHINYIITDKSPGAEWIAFCKDNNIQLVW', 'gene': 'ydjF', 'label': -1.3882626995061573}\n",
"2025-03-24 21:53:10 - training - INFO - Train data point 3: {'name': 'Q47152', 'aa_seq': 'MSEYRRYYIKGGTWFFTVNLRNRRSQLLTTQYQMLRHAIIKVKRDRPFEINAWVVLPEHMHCIWTLPEGDDDFSSRWREIKKQFTHACGLKNIWQPRFWEHAIRNTKDYRHHVDYIYINPVKHGWVKQVSDWPFSTFHRDVARGLYPIDWAGDVTDFSAGERIIS', 'gene': 'yafM', 'label': -0.050581678782913427}\n",
"2025-03-24 21:53:10 - training - INFO - ------------------------\n",
"/home/matwings/gejian/anaconda3/envs/venus/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs) # noqa: B028\n",
"2025-03-24 21:53:10 - training - INFO - ---------- Epoch 0 ----------\n",
"Training: 0%| | 0/224 [00:00, ?it/s]/home/matwings/gejian/anaconda3/envs/venus/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:632: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
" return fn(*args, **kwargs)\n",
"Training: 100%|β| 224/224 [01:23<00:00, 2.67it/s, grad_step=28, train_loss=0.87\n",
"2025-03-24 21:54:34 - training - INFO - Epoch 0 Train Loss: 0.8456\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.79it/s]\n",
"2025-03-24 21:54:42 - training - INFO - Epoch 0 Val Loss: 0.6296\n",
"2025-03-24 21:54:42 - training - INFO - Epoch 0 Val mse: 0.6296\n",
"2025-03-24 21:54:42 - training - INFO - Epoch 0 Val spearman_corr: 0.5653\n",
"2025-03-24 21:54:42 - training - INFO - Saving model with best val mse: 0.6296\n",
"2025-03-24 21:54:42 - training - INFO - ---------- Epoch 1 ----------\n",
"Training: 100%|β| 224/224 [01:21<00:00, 2.75it/s, grad_step=56, train_loss=1.04\n",
"2025-03-24 21:56:04 - training - INFO - Epoch 1 Train Loss: 0.6127\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.92it/s]\n",
"2025-03-24 21:56:11 - training - INFO - Epoch 1 Val Loss: 0.5478\n",
"2025-03-24 21:56:11 - training - INFO - Epoch 1 Val mse: 0.5478\n",
"2025-03-24 21:56:11 - training - INFO - Epoch 1 Val spearman_corr: 0.6274\n",
"2025-03-24 21:56:11 - training - INFO - Saving model with best val mse: 0.5478\n",
"2025-03-24 21:56:12 - training - INFO - ---------- Epoch 2 ----------\n",
"Training: 100%|β| 224/224 [01:21<00:00, 2.76it/s, grad_step=84, train_loss=1.13\n",
"2025-03-24 21:57:33 - training - INFO - Epoch 2 Train Loss: 0.5254\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.92it/s]\n",
"2025-03-24 21:57:40 - training - INFO - Epoch 2 Val Loss: 0.5045\n",
"2025-03-24 21:57:40 - training - INFO - Epoch 2 Val mse: 0.5045\n",
"2025-03-24 21:57:40 - training - INFO - Epoch 2 Val spearman_corr: 0.6567\n",
"2025-03-24 21:57:40 - training - INFO - Saving model with best val mse: 0.5045\n",
"2025-03-24 21:57:41 - training - INFO - ---------- Epoch 3 ----------\n",
"Training: 100%|β| 224/224 [01:21<00:00, 2.74it/s, grad_step=112, train_loss=1.0\n",
"2025-03-24 21:59:03 - training - INFO - Epoch 3 Train Loss: 0.4843\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.88it/s]\n",
"2025-03-24 21:59:10 - training - INFO - Epoch 3 Val Loss: 0.4890\n",
"2025-03-24 21:59:10 - training - INFO - Epoch 3 Val mse: 0.4890\n",
"2025-03-24 21:59:10 - training - INFO - Epoch 3 Val spearman_corr: 0.6711\n",
"2025-03-24 21:59:10 - training - INFO - Saving model with best val mse: 0.4890\n",
"2025-03-24 21:59:11 - training - INFO - ---------- Epoch 4 ----------\n",
"Training: 100%|β| 224/224 [01:22<00:00, 2.72it/s, grad_step=140, train_loss=0.8\n",
"2025-03-24 22:00:33 - training - INFO - Epoch 4 Train Loss: 0.4365\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.84it/s]\n",
"2025-03-24 22:00:41 - training - INFO - Epoch 4 Val Loss: 0.5016\n",
"2025-03-24 22:00:41 - training - INFO - Epoch 4 Val mse: 0.5016\n",
"2025-03-24 22:00:41 - training - INFO - Epoch 4 Val spearman_corr: 0.6697\n",
"2025-03-24 22:00:41 - training - INFO - ---------- Epoch 5 ----------\n",
"Training: 100%|β| 224/224 [01:21<00:00, 2.74it/s, grad_step=168, train_loss=0.7\n",
"2025-03-24 22:02:03 - training - INFO - Epoch 5 Train Loss: 0.3958\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.93it/s]\n",
"2025-03-24 22:02:10 - training - INFO - Epoch 5 Val Loss: 0.5091\n",
"2025-03-24 22:02:10 - training - INFO - Epoch 5 Val mse: 0.5091\n",
"2025-03-24 22:02:10 - training - INFO - Epoch 5 Val spearman_corr: 0.6619\n",
"2025-03-24 22:02:10 - training - INFO - ---------- Epoch 6 ----------\n",
"Training: 100%|β| 224/224 [01:19<00:00, 2.81it/s, grad_step=196, train_loss=0.6\n",
"2025-03-24 22:03:30 - training - INFO - Epoch 6 Train Loss: 0.3540\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.05it/s]\n",
"2025-03-24 22:03:37 - training - INFO - Epoch 6 Val Loss: 0.5202\n",
"2025-03-24 22:03:37 - training - INFO - Epoch 6 Val mse: 0.5202\n",
"2025-03-24 22:03:37 - training - INFO - Epoch 6 Val spearman_corr: 0.6515\n",
"2025-03-24 22:03:37 - training - INFO - Early stopping triggered after 3 epochs without improvement\n",
"2025-03-24 22:03:37 - training - INFO - Early stop at Epoch 6\n",
"/home/matwings/lc/VenusFactory-readme/src/training/trainer.py:395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" checkpoint = torch.load(path, map_location=\"cpu\")\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"2025-03-24 22:03:38 - training - INFO - ---------- Starting Test Phase ----------\n",
"Testing: 100%|ββββββββββββββββββββββββββββββββββ| 26/26 [00:06<00:00, 4.06it/s]\n",
"2025-03-24 22:03:44 - training - INFO - Test Results:\n",
"2025-03-24 22:03:44 - training - INFO - Test Loss: 0.4752\n",
"2025-03-24 22:03:44 - training - INFO - Test mse: 0.4752\n",
"2025-03-24 22:03:44 - training - INFO - Test spearman_corr: 0.7293\n"
]
}
],
"source": [
"# ESM model target_modules name: query key value\n",
"# Bert_base(prot_bert) model target_modules name: query key value\n",
"# T5_base(ankh, t5) model target_modules name: q k v\n",
"\n",
"!export HF_ENDPOINT=https://hf-mirror.com # if need to use HF mirror\n",
"dataset=\"eSOL\"\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"lr=5e-4\n",
"training_method=\"plm-qlora\"\n",
"sh=f\"\"\"\n",
"python src/train.py \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --dataset_config data/{dataset}/{dataset}_HF.json \\\n",
" --learning_rate {lr} \\\n",
" --gradient_accumulation_steps 8 \\\n",
" --num_epochs 10 \\\n",
" --batch_token 8000 \\\n",
" --patience 3 \\\n",
" --output_dir test_res/{dataset}/{plm_model} \\\n",
" --output_model_name {training_method}_lr_{lr}_8k_ga8.pt \\\n",
" --training_method {training_method} \\\n",
" --lora_target_modules query key value\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7470251f-05b1-420e-82d8-188e9fcca1d4",
"metadata": {},
"outputs": [],
"source": [
"# Use bash script\n",
"!cp ./script/train/train_plm_qlora.sh ./train_plm_qlora.sh\n",
"!bash ./train_plm_qlora.sh"
]
},
{
"cell_type": "markdown",
"id": "26879ec7-77f7-4bc7-918b-4fd609d1951d",
"metadata": {},
"source": [
"#### [DoRA](https://arxiv.org/abs/2402.09353)"
]
},
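{
"cell_type": "markdown",
"id": "dora-note-6a18d3c7",
"metadata": {},
"source": [
"DoRA decomposes each pre-trained weight into a magnitude component and a direction component, applies the low-rank update only to the direction, and trains the magnitude separately, which is why the log below reports slightly more trainable backbone parameters than plain LoRA (97.92K vs. 92.16K). With the `peft` library this is typically the usual LoRA config plus `use_dora=True` (an assumed stack, not VenusFactory's exact code):\n",
"\n",
"```python\n",
"# Rough DoRA configuration sketch with peft (assumed stack).\n",
"from peft import LoraConfig\n",
"\n",
"dora_cfg = LoraConfig(\n",
"    r=8,\n",
"    lora_alpha=32,\n",
"    lora_dropout=0.1,\n",
"    target_modules=[\"query\", \"key\", \"value\"],\n",
"    use_dora=True,   # weight-decomposed LoRA; requires a recent peft release\n",
")\n",
"```"
]
},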
{
"cell_type": "code",
"execution_count": 8,
"id": "5e852af1-5d53-4ff5-a952-435c3f153e8f",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2025-03-24 22:12:08 - training - INFO - Starting training with configuration:\n",
"2025-03-24 22:12:08 - training - INFO - hidden_size: None\n",
"2025-03-24 22:12:08 - training - INFO - num_attention_head: 8\n",
"2025-03-24 22:12:08 - training - INFO - attention_probs_dropout: 0.1\n",
"2025-03-24 22:12:08 - training - INFO - plm_model: facebook/esm2_t6_8M_UR50D\n",
"2025-03-24 22:12:08 - training - INFO - pooling_method: mean\n",
"2025-03-24 22:12:08 - training - INFO - pooling_dropout: 0.1\n",
"2025-03-24 22:12:08 - training - INFO - dataset: tyang816/eSOL\n",
"2025-03-24 22:12:08 - training - INFO - dataset_config: data/eSOL/eSOL_HF.json\n",
"2025-03-24 22:12:08 - training - INFO - normalize: standard\n",
"2025-03-24 22:12:08 - training - INFO - num_labels: 1\n",
"2025-03-24 22:12:08 - training - INFO - problem_type: regression\n",
"2025-03-24 22:12:08 - training - INFO - pdb_type: None\n",
"2025-03-24 22:12:08 - training - INFO - train_file: None\n",
"2025-03-24 22:12:08 - training - INFO - valid_file: None\n",
"2025-03-24 22:12:08 - training - INFO - test_file: None\n",
"2025-03-24 22:12:08 - training - INFO - metrics: ['mse', 'spearman_corr']\n",
"2025-03-24 22:12:08 - training - INFO - seed: 3407\n",
"2025-03-24 22:12:08 - training - INFO - learning_rate: 0.0005\n",
"2025-03-24 22:12:08 - training - INFO - scheduler: None\n",
"2025-03-24 22:12:08 - training - INFO - warmup_steps: 0\n",
"2025-03-24 22:12:08 - training - INFO - num_workers: 4\n",
"2025-03-24 22:12:08 - training - INFO - batch_size: None\n",
"2025-03-24 22:12:08 - training - INFO - batch_token: 8000\n",
"2025-03-24 22:12:08 - training - INFO - num_epochs: 10\n",
"2025-03-24 22:12:08 - training - INFO - max_seq_len: -1\n",
"2025-03-24 22:12:08 - training - INFO - gradient_accumulation_steps: 8\n",
"2025-03-24 22:12:08 - training - INFO - max_grad_norm: -1\n",
"2025-03-24 22:12:08 - training - INFO - patience: 3\n",
"2025-03-24 22:12:08 - training - INFO - monitor: mse\n",
"2025-03-24 22:12:08 - training - INFO - monitor_strategy: min\n",
"2025-03-24 22:12:08 - training - INFO - training_method: plm-dora\n",
"2025-03-24 22:12:08 - training - INFO - lora_r: 8\n",
"2025-03-24 22:12:08 - training - INFO - lora_alpha: 32\n",
"2025-03-24 22:12:08 - training - INFO - lora_dropout: 0.1\n",
"2025-03-24 22:12:08 - training - INFO - feedforward_modules: w0\n",
"2025-03-24 22:12:08 - training - INFO - lora_target_modules: ['query', 'key', 'value']\n",
"2025-03-24 22:12:08 - training - INFO - structure_seq: []\n",
"2025-03-24 22:12:08 - training - INFO - output_model_name: plm-dora_lr_0.0005_8k_ga8.pt\n",
"2025-03-24 22:12:08 - training - INFO - output_root: ckpt\n",
"2025-03-24 22:12:08 - training - INFO - output_dir: ckpt/test_res/eSOL/esm2_t6_8M_UR50D\n",
"2025-03-24 22:12:08 - training - INFO - wandb: False\n",
"2025-03-24 22:12:08 - training - INFO - wandb_entity: None\n",
"2025-03-24 22:12:08 - training - INFO - wandb_project: VenusFactory\n",
"2025-03-24 22:12:08 - training - INFO - wandb_run_name: None\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"trainable params: 97,920 || all params: 7,938,041 || trainable%: 1.2336\n",
"2025-03-24 22:12:09 - training - INFO - ------------------------\n",
"2025-03-24 22:12:09 - training - INFO - Model Parameters Statistics:\n",
"2025-03-24 22:12:09 - training - INFO - ------------------------\n",
"2025-03-24 22:12:09 - training - INFO - Adapter Model:\n",
"2025-03-24 22:12:09 - training - INFO - Total parameters: 103.04K\n",
"2025-03-24 22:12:09 - training - INFO - Trainable parameters: 103.04K\n",
"2025-03-24 22:12:09 - training - INFO - Pre-trained Model:\n",
"2025-03-24 22:12:09 - training - INFO - Total parameters: 7.94M\n",
"2025-03-24 22:12:09 - training - INFO - Trainable parameters: 97.92K\n",
"2025-03-24 22:12:09 - training - INFO - Combined:\n",
"2025-03-24 22:12:09 - training - INFO - Total parameters: 8.04M\n",
"2025-03-24 22:12:09 - training - INFO - Trainable parameters: 200.96K\n",
"2025-03-24 22:12:09 - training - INFO - Trainable percentage: 2.50%\n",
"2025-03-24 22:12:09 - training - INFO - ------------------------\n",
"2025-03-24 22:12:19 - training - INFO - Dataset Statistics:\n",
"2025-03-24 22:12:19 - training - INFO - ------------------------\n",
"2025-03-24 22:12:19 - training - INFO - Dataset: tyang816/eSOL\n",
"2025-03-24 22:12:19 - training - INFO - Number of train samples: 2481\n",
"2025-03-24 22:12:19 - training - INFO - Number of val samples: 310\n",
"2025-03-24 22:12:19 - training - INFO - Number of test samples: 310\n",
"2025-03-24 22:12:19 - training - INFO - Sample 3 data points from train dataset:\n",
"2025-03-24 22:12:19 - training - INFO - Train data point 1: {'name': 'P0ABL8', 'aa_seq': 'MMFWRIFRLELRVAFRHSAEIANPLWFFLIVITLFPLSIGPEPQLLARIAPGIIWVAALLSSLLALERLFRDDLQDGSLEQLMLLPLPLPAVVLAKVMAHWMVTGLPLLILSPLVAMLLGMDVYGWQVMALTLLLGTPTLGFLGAPGVALTVGLKRGGVLLSILVLPLTIPLLIFATAAMDAASMHLPVDGYLAILGALLAGTATLSPFATAAALRISIQ', 'gene': 'ccmB', 'label': -1.3882626995061573}\n",
"2025-03-24 22:12:19 - training - INFO - Train data point 2: {'name': 'P77721', 'aa_seq': 'MAAKDRIQAIKQMVANDKKVTVSNLSGIFQVTEETIRRDLEKLEDEGFLTRTYGGAVLNTAMLTENIHFYKRASSFYEEKQLIARKALPFIDNKTTMAADSSSTVMELLKLLQDRSGLTLLTNSAEAIHVLAQSEIKVVSTGGELNKNTLSLQGRITKEIIRRYHVDIMVMSCKGLDINSGALDSNEAEAEIKKTMIRQATEVALLVDHSKFDRKAFVQLADFSHINYIITDKSPGAEWIAFCKDNNIQLVW', 'gene': 'ydjF', 'label': -1.3882626995061573}\n",
"2025-03-24 22:12:19 - training - INFO - Train data point 3: {'name': 'Q47152', 'aa_seq': 'MSEYRRYYIKGGTWFFTVNLRNRRSQLLTTQYQMLRHAIIKVKRDRPFEINAWVVLPEHMHCIWTLPEGDDDFSSRWREIKKQFTHACGLKNIWQPRFWEHAIRNTKDYRHHVDYIYINPVKHGWVKQVSDWPFSTFHRDVARGLYPIDWAGDVTDFSAGERIIS', 'gene': 'yafM', 'label': -0.050581678782913427}\n",
"2025-03-24 22:12:19 - training - INFO - ------------------------\n",
"/home/matwings/gejian/anaconda3/envs/venus/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs) # noqa: B028\n",
"2025-03-24 22:12:19 - training - INFO - ---------- Epoch 0 ----------\n",
"Training: 100%|β| 225/225 [01:26<00:00, 2.60it/s, grad_step=28, train_loss=0.77\n",
"2025-03-24 22:13:46 - training - INFO - Epoch 0 Train Loss: 0.8430\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.79it/s]\n",
"2025-03-24 22:13:53 - training - INFO - Epoch 0 Val Loss: 0.6580\n",
"2025-03-24 22:13:53 - training - INFO - Epoch 0 Val mse: 0.6580\n",
"2025-03-24 22:13:53 - training - INFO - Epoch 0 Val spearman_corr: 0.5469\n",
"2025-03-24 22:13:53 - training - INFO - Saving model with best val mse: 0.6580\n",
"2025-03-24 22:13:55 - training - INFO - ---------- Epoch 1 ----------\n",
"Training: 100%|β| 225/225 [01:25<00:00, 2.64it/s, grad_step=56, train_loss=0.66\n",
"2025-03-24 22:15:20 - training - INFO - Epoch 1 Train Loss: 0.6025\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.94it/s]\n",
"2025-03-24 22:15:27 - training - INFO - Epoch 1 Val Loss: 0.5658\n",
"2025-03-24 22:15:27 - training - INFO - Epoch 1 Val mse: 0.5658\n",
"2025-03-24 22:15:27 - training - INFO - Epoch 1 Val spearman_corr: 0.6289\n",
"2025-03-24 22:15:27 - training - INFO - Saving model with best val mse: 0.5658\n",
"2025-03-24 22:15:28 - training - INFO - ---------- Epoch 2 ----------\n",
"Training: 100%|β| 225/225 [01:25<00:00, 2.65it/s, grad_step=84, train_loss=0.78\n",
"2025-03-24 22:16:53 - training - INFO - Epoch 2 Train Loss: 0.5181\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.98it/s]\n",
"2025-03-24 22:17:00 - training - INFO - Epoch 2 Val Loss: 0.4963\n",
"2025-03-24 22:17:00 - training - INFO - Epoch 2 Val mse: 0.4963\n",
"2025-03-24 22:17:00 - training - INFO - Epoch 2 Val spearman_corr: 0.6647\n",
"2025-03-24 22:17:00 - training - INFO - Saving model with best val mse: 0.4963\n",
"2025-03-24 22:17:02 - training - INFO - ---------- Epoch 3 ----------\n",
"Training: 100%|β| 225/225 [01:24<00:00, 2.65it/s, grad_step=112, train_loss=0.6\n",
"2025-03-24 22:18:27 - training - INFO - Epoch 3 Train Loss: 0.4629\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.95it/s]\n",
"2025-03-24 22:18:34 - training - INFO - Epoch 3 Val Loss: 0.4929\n",
"2025-03-24 22:18:34 - training - INFO - Epoch 3 Val mse: 0.4929\n",
"2025-03-24 22:18:34 - training - INFO - Epoch 3 Val spearman_corr: 0.6736\n",
"2025-03-24 22:18:34 - training - INFO - Saving model with best val mse: 0.4929\n",
"2025-03-24 22:18:35 - training - INFO - ---------- Epoch 4 ----------\n",
"Training: 100%|β| 225/225 [01:24<00:00, 2.65it/s, grad_step=140, train_loss=0.4\n",
"2025-03-24 22:20:00 - training - INFO - Epoch 4 Train Loss: 0.4191\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.92it/s]\n",
"2025-03-24 22:20:07 - training - INFO - Epoch 4 Val Loss: 0.5352\n",
"2025-03-24 22:20:07 - training - INFO - Epoch 4 Val mse: 0.5352\n",
"2025-03-24 22:20:07 - training - INFO - Epoch 4 Val spearman_corr: 0.6639\n",
"2025-03-24 22:20:07 - training - INFO - ---------- Epoch 5 ----------\n",
"Training: 100%|β| 225/225 [01:25<00:00, 2.64it/s, grad_step=168, train_loss=0.4\n",
"2025-03-24 22:21:32 - training - INFO - Epoch 5 Train Loss: 0.3933\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.92it/s]\n",
"2025-03-24 22:21:40 - training - INFO - Epoch 5 Val Loss: 0.5278\n",
"2025-03-24 22:21:40 - training - INFO - Epoch 5 Val mse: 0.5278\n",
"2025-03-24 22:21:40 - training - INFO - Epoch 5 Val spearman_corr: 0.6612\n",
"2025-03-24 22:21:40 - training - INFO - ---------- Epoch 6 ----------\n",
"Training: 100%|β| 225/225 [01:25<00:00, 2.64it/s, grad_step=196, train_loss=0.2\n",
"2025-03-24 22:23:05 - training - INFO - Epoch 6 Train Loss: 0.3466\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.92it/s]\n",
"2025-03-24 22:23:12 - training - INFO - Epoch 6 Val Loss: 0.5450\n",
"2025-03-24 22:23:12 - training - INFO - Epoch 6 Val mse: 0.5450\n",
"2025-03-24 22:23:12 - training - INFO - Epoch 6 Val spearman_corr: 0.6431\n",
"2025-03-24 22:23:12 - training - INFO - Early stopping triggered after 3 epochs without improvement\n",
"2025-03-24 22:23:12 - training - INFO - Early stop at Epoch 6\n",
"/home/matwings/lc/VenusFactory-readme/src/training/trainer.py:406: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" checkpoint = torch.load(path, map_location=\"cpu\")\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"2025-03-24 22:23:13 - training - INFO - ---------- Starting Test Phase ----------\n",
"Testing: 100%|ββββββββββββββββββββββββββββββββββ| 26/26 [00:06<00:00, 3.90it/s]\n",
"2025-03-24 22:23:20 - training - INFO - Test Results:\n",
"2025-03-24 22:23:20 - training - INFO - Test Loss: 0.4153\n",
"2025-03-24 22:23:20 - training - INFO - Test mse: 0.4153\n",
"2025-03-24 22:23:20 - training - INFO - Test spearman_corr: 0.7557\n"
]
}
],
"source": [
"# ESM model target_modules name: query key value\n",
"# Bert_base(prot_bert) model target_modules name: query key value\n",
"# T5_base(ankh, t5) model target_modules name: q k v\n",
"\n",
"!export HF_ENDPOINT=https://hf-mirror.com # if need to use HF mirror\n",
"dataset=\"eSOL\"\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"lr=5e-4\n",
"training_method=\"plm-dora\"\n",
"sh=f\"\"\"\n",
"python src/train.py \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --dataset_config data/{dataset}/{dataset}_HF.json \\\n",
" --learning_rate {lr} \\\n",
" --gradient_accumulation_steps 8 \\\n",
" --num_epochs 10 \\\n",
" --batch_token 8000 \\\n",
" --patience 3 \\\n",
" --output_dir test_res/{dataset}/{plm_model} \\\n",
" --output_model_name {training_method}_lr_{lr}_8k_ga8.pt \\\n",
" --training_method {training_method} \\\n",
" --lora_target_modules query key value\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "052db453-489f-44e5-aeae-b8748ec8b5ac",
"metadata": {},
"outputs": [],
"source": [
"# Use bash script\n",
"!cp ./script/train/train_plm_dora.sh ./train_plm_dora.sh\n",
"!bash ./train_plm_dora.sh"
]
},
{
"cell_type": "markdown",
"id": "ec4bb446-cad4-421a-907a-5962edf6737e",
"metadata": {},
"source": [
"#### [IA3](https://arxiv.org/abs/2205.05638)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "718c5603-ef42-4b23-bb7c-5a27ddea083c",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2025-03-25 00:09:03 - training - INFO - Starting training with configuration:\n",
"2025-03-25 00:09:03 - training - INFO - hidden_size: None\n",
"2025-03-25 00:09:03 - training - INFO - num_attention_head: 8\n",
"2025-03-25 00:09:03 - training - INFO - attention_probs_dropout: 0.1\n",
"2025-03-25 00:09:03 - training - INFO - plm_model: facebook/esm2_t6_8M_UR50D\n",
"2025-03-25 00:09:03 - training - INFO - pooling_method: mean\n",
"2025-03-25 00:09:03 - training - INFO - pooling_dropout: 0.1\n",
"2025-03-25 00:09:03 - training - INFO - dataset: tyang816/eSOL\n",
"2025-03-25 00:09:03 - training - INFO - dataset_config: data/eSOL/eSOL_HF.json\n",
"2025-03-25 00:09:03 - training - INFO - normalize: standard\n",
"2025-03-25 00:09:03 - training - INFO - num_labels: 1\n",
"2025-03-25 00:09:03 - training - INFO - problem_type: regression\n",
"2025-03-25 00:09:03 - training - INFO - pdb_type: None\n",
"2025-03-25 00:09:03 - training - INFO - train_file: None\n",
"2025-03-25 00:09:03 - training - INFO - valid_file: None\n",
"2025-03-25 00:09:03 - training - INFO - test_file: None\n",
"2025-03-25 00:09:03 - training - INFO - metrics: ['mse', 'spearman_corr']\n",
"2025-03-25 00:09:03 - training - INFO - seed: 3407\n",
"2025-03-25 00:09:03 - training - INFO - learning_rate: 0.0005\n",
"2025-03-25 00:09:03 - training - INFO - scheduler: None\n",
"2025-03-25 00:09:03 - training - INFO - warmup_steps: 0\n",
"2025-03-25 00:09:03 - training - INFO - num_workers: 4\n",
"2025-03-25 00:09:03 - training - INFO - batch_size: None\n",
"2025-03-25 00:09:03 - training - INFO - batch_token: 8000\n",
"2025-03-25 00:09:03 - training - INFO - num_epochs: 10\n",
"2025-03-25 00:09:03 - training - INFO - max_seq_len: -1\n",
"2025-03-25 00:09:03 - training - INFO - gradient_accumulation_steps: 8\n",
"2025-03-25 00:09:03 - training - INFO - max_grad_norm: -1\n",
"2025-03-25 00:09:03 - training - INFO - patience: 3\n",
"2025-03-25 00:09:03 - training - INFO - monitor: mse\n",
"2025-03-25 00:09:03 - training - INFO - monitor_strategy: min\n",
"2025-03-25 00:09:03 - training - INFO - training_method: plm-ia3\n",
"2025-03-25 00:09:03 - training - INFO - lora_r: 8\n",
"2025-03-25 00:09:03 - training - INFO - lora_alpha: 32\n",
"2025-03-25 00:09:03 - training - INFO - lora_dropout: 0.1\n",
"2025-03-25 00:09:03 - training - INFO - feedforward_modules: w0\n",
"2025-03-25 00:09:03 - training - INFO - lora_target_modules: ['query', 'key', 'value']\n",
"2025-03-25 00:09:03 - training - INFO - structure_seq: []\n",
"2025-03-25 00:09:03 - training - INFO - output_model_name: plm-ia3_lr_0.0005_8k_ga8.pt\n",
"2025-03-25 00:09:03 - training - INFO - output_root: ckpt\n",
"2025-03-25 00:09:03 - training - INFO - output_dir: ckpt/test_res/eSOL/esm2_t6_8M_UR50D\n",
"2025-03-25 00:09:03 - training - INFO - wandb: False\n",
"2025-03-25 00:09:03 - training - INFO - wandb_entity: None\n",
"2025-03-25 00:09:03 - training - INFO - wandb_project: VenusFactory\n",
"2025-03-25 00:09:03 - training - INFO - wandb_run_name: None\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"['', 'embeddings', 'embeddings.word_embeddings', 'embeddings.dropout', 'embeddings.position_embeddings', 'encoder', 'encoder.layer', 'encoder.layer.0', 'encoder.layer.0.attention', 'encoder.layer.0.attention.self', 'encoder.layer.0.attention.self.query', 'encoder.layer.0.attention.self.key', 'encoder.layer.0.attention.self.value', 'encoder.layer.0.attention.self.dropout', 'encoder.layer.0.attention.self.rotary_embeddings', 'encoder.layer.0.attention.output', 'encoder.layer.0.attention.output.dense', 'encoder.layer.0.attention.output.dropout', 'encoder.layer.0.attention.LayerNorm', 'encoder.layer.0.intermediate', 'encoder.layer.0.intermediate.dense', 'encoder.layer.0.output', 'encoder.layer.0.output.dense', 'encoder.layer.0.output.dropout', 'encoder.layer.0.LayerNorm', 'encoder.layer.1', 'encoder.layer.1.attention', 'encoder.layer.1.attention.self', 'encoder.layer.1.attention.self.query', 'encoder.layer.1.attention.self.key', 'encoder.layer.1.attention.self.value', 'encoder.layer.1.attention.self.dropout', 'encoder.layer.1.attention.self.rotary_embeddings', 'encoder.layer.1.attention.output', 'encoder.layer.1.attention.output.dense', 'encoder.layer.1.attention.output.dropout', 'encoder.layer.1.attention.LayerNorm', 'encoder.layer.1.intermediate', 'encoder.layer.1.intermediate.dense', 'encoder.layer.1.output', 'encoder.layer.1.output.dense', 'encoder.layer.1.output.dropout', 'encoder.layer.1.LayerNorm', 'encoder.layer.2', 'encoder.layer.2.attention', 'encoder.layer.2.attention.self', 'encoder.layer.2.attention.self.query', 'encoder.layer.2.attention.self.key', 'encoder.layer.2.attention.self.value', 'encoder.layer.2.attention.self.dropout', 'encoder.layer.2.attention.self.rotary_embeddings', 'encoder.layer.2.attention.output', 'encoder.layer.2.attention.output.dense', 'encoder.layer.2.attention.output.dropout', 'encoder.layer.2.attention.LayerNorm', 'encoder.layer.2.intermediate', 'encoder.layer.2.intermediate.dense', 'encoder.layer.2.output', 'encoder.layer.2.output.dense', 'encoder.layer.2.output.dropout', 'encoder.layer.2.LayerNorm', 'encoder.layer.3', 'encoder.layer.3.attention', 'encoder.layer.3.attention.self', 'encoder.layer.3.attention.self.query', 'encoder.layer.3.attention.self.key', 'encoder.layer.3.attention.self.value', 'encoder.layer.3.attention.self.dropout', 'encoder.layer.3.attention.self.rotary_embeddings', 'encoder.layer.3.attention.output', 'encoder.layer.3.attention.output.dense', 'encoder.layer.3.attention.output.dropout', 'encoder.layer.3.attention.LayerNorm', 'encoder.layer.3.intermediate', 'encoder.layer.3.intermediate.dense', 'encoder.layer.3.output', 'encoder.layer.3.output.dense', 'encoder.layer.3.output.dropout', 'encoder.layer.3.LayerNorm', 'encoder.layer.4', 'encoder.layer.4.attention', 'encoder.layer.4.attention.self', 'encoder.layer.4.attention.self.query', 'encoder.layer.4.attention.self.key', 'encoder.layer.4.attention.self.value', 'encoder.layer.4.attention.self.dropout', 'encoder.layer.4.attention.self.rotary_embeddings', 'encoder.layer.4.attention.output', 'encoder.layer.4.attention.output.dense', 'encoder.layer.4.attention.output.dropout', 'encoder.layer.4.attention.LayerNorm', 'encoder.layer.4.intermediate', 'encoder.layer.4.intermediate.dense', 'encoder.layer.4.output', 'encoder.layer.4.output.dense', 'encoder.layer.4.output.dropout', 'encoder.layer.4.LayerNorm', 'encoder.layer.5', 'encoder.layer.5.attention', 'encoder.layer.5.attention.self', 'encoder.layer.5.attention.self.query', 'encoder.layer.5.attention.self.key', 
'encoder.layer.5.attention.self.value', 'encoder.layer.5.attention.self.dropout', 'encoder.layer.5.attention.self.rotary_embeddings', 'encoder.layer.5.attention.output', 'encoder.layer.5.attention.output.dense', 'encoder.layer.5.attention.output.dropout', 'encoder.layer.5.attention.LayerNorm', 'encoder.layer.5.intermediate', 'encoder.layer.5.intermediate.dense', 'encoder.layer.5.output', 'encoder.layer.5.output.dense', 'encoder.layer.5.output.dropout', 'encoder.layer.5.LayerNorm', 'encoder.emb_layer_norm_after', 'pooler', 'pooler.dense', 'pooler.activation', 'contact_head', 'contact_head.regression', 'contact_head.activation']\n",
"trainable params: 5,760 || all params: 7,845,881 || trainable%: 0.0734\n",
" Using plm IA3 \n",
"2025-03-25 00:09:04 - training - INFO - ------------------------\n",
"2025-03-25 00:09:04 - training - INFO - Model Parameters Statistics:\n",
"2025-03-25 00:09:04 - training - INFO - ------------------------\n",
"2025-03-25 00:09:04 - training - INFO - Adapter Model:\n",
"2025-03-25 00:09:04 - training - INFO - Total parameters: 103.04K\n",
"2025-03-25 00:09:04 - training - INFO - Trainable parameters: 103.04K\n",
"2025-03-25 00:09:04 - training - INFO - Pre-trained Model:\n",
"2025-03-25 00:09:04 - training - INFO - Total parameters: 7.85M\n",
"2025-03-25 00:09:04 - training - INFO - Trainable parameters: 5.76K\n",
"2025-03-25 00:09:04 - training - INFO - Combined:\n",
"2025-03-25 00:09:04 - training - INFO - Total parameters: 7.95M\n",
"2025-03-25 00:09:04 - training - INFO - Trainable parameters: 108.80K\n",
"2025-03-25 00:09:04 - training - INFO - Trainable percentage: 1.37%\n",
"2025-03-25 00:09:04 - training - INFO - ------------------------\n",
"2025-03-25 00:09:16 - training - INFO - Dataset Statistics:\n",
"2025-03-25 00:09:16 - training - INFO - ------------------------\n",
"2025-03-25 00:09:16 - training - INFO - Dataset: tyang816/eSOL\n",
"2025-03-25 00:09:16 - training - INFO - Number of train samples: 2481\n",
"2025-03-25 00:09:16 - training - INFO - Number of val samples: 310\n",
"2025-03-25 00:09:16 - training - INFO - Number of test samples: 310\n",
"2025-03-25 00:09:16 - training - INFO - Sample 3 data points from train dataset:\n",
"2025-03-25 00:09:16 - training - INFO - Train data point 1: {'name': 'P0ABL8', 'aa_seq': 'MMFWRIFRLELRVAFRHSAEIANPLWFFLIVITLFPLSIGPEPQLLARIAPGIIWVAALLSSLLALERLFRDDLQDGSLEQLMLLPLPLPAVVLAKVMAHWMVTGLPLLILSPLVAMLLGMDVYGWQVMALTLLLGTPTLGFLGAPGVALTVGLKRGGVLLSILVLPLTIPLLIFATAAMDAASMHLPVDGYLAILGALLAGTATLSPFATAAALRISIQ', 'gene': 'ccmB', 'label': -1.3882626995061573}\n",
"2025-03-25 00:09:16 - training - INFO - Train data point 2: {'name': 'P77721', 'aa_seq': 'MAAKDRIQAIKQMVANDKKVTVSNLSGIFQVTEETIRRDLEKLEDEGFLTRTYGGAVLNTAMLTENIHFYKRASSFYEEKQLIARKALPFIDNKTTMAADSSSTVMELLKLLQDRSGLTLLTNSAEAIHVLAQSEIKVVSTGGELNKNTLSLQGRITKEIIRRYHVDIMVMSCKGLDINSGALDSNEAEAEIKKTMIRQATEVALLVDHSKFDRKAFVQLADFSHINYIITDKSPGAEWIAFCKDNNIQLVW', 'gene': 'ydjF', 'label': -1.3882626995061573}\n",
"2025-03-25 00:09:16 - training - INFO - Train data point 3: {'name': 'Q47152', 'aa_seq': 'MSEYRRYYIKGGTWFFTVNLRNRRSQLLTTQYQMLRHAIIKVKRDRPFEINAWVVLPEHMHCIWTLPEGDDDFSSRWREIKKQFTHACGLKNIWQPRFWEHAIRNTKDYRHHVDYIYINPVKHGWVKQVSDWPFSTFHRDVARGLYPIDWAGDVTDFSAGERIIS', 'gene': 'yafM', 'label': -0.050581678782913427}\n",
"2025-03-25 00:09:16 - training - INFO - ------------------------\n",
"/home/matwings/gejian/anaconda3/envs/venus/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs) # noqa: B028\n",
"2025-03-25 00:09:16 - training - INFO - ---------- Epoch 0 ----------\n",
"Training: 100%|β| 220/220 [01:18<00:00, 2.79it/s, grad_step=27, train_loss=1.08\n",
"2025-03-25 00:10:35 - training - INFO - Epoch 0 Train Loss: 0.9042\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 3.91it/s]\n",
"2025-03-25 00:10:43 - training - INFO - Epoch 0 Val Loss: 0.7607\n",
"2025-03-25 00:10:43 - training - INFO - Epoch 0 Val mse: 0.7607\n",
"2025-03-25 00:10:43 - training - INFO - Epoch 0 Val spearman_corr: 0.4896\n",
"2025-03-25 00:10:43 - training - INFO - Saving model with best val mse: 0.7607\n",
"2025-03-25 00:10:44 - training - INFO - ---------- Epoch 1 ----------\n",
"Training: 100%|β| 220/220 [01:18<00:00, 2.82it/s, grad_step=55, train_loss=1.43\n",
"2025-03-25 00:12:02 - training - INFO - Epoch 1 Train Loss: 0.7125\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.04it/s]\n",
"2025-03-25 00:12:09 - training - INFO - Epoch 1 Val Loss: 0.6380\n",
"2025-03-25 00:12:09 - training - INFO - Epoch 1 Val mse: 0.6380\n",
"2025-03-25 00:12:09 - training - INFO - Epoch 1 Val spearman_corr: 0.5603\n",
"2025-03-25 00:12:09 - training - INFO - Saving model with best val mse: 0.6380\n",
"2025-03-25 00:12:10 - training - INFO - ---------- Epoch 2 ----------\n",
"Training: 100%|β| 220/220 [01:16<00:00, 2.89it/s, grad_step=82, train_loss=1.59\n",
"2025-03-25 00:13:26 - training - INFO - Epoch 2 Train Loss: 0.6343\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:06<00:00, 4.30it/s]\n",
"2025-03-25 00:13:32 - training - INFO - Epoch 2 Val Loss: 0.6153\n",
"2025-03-25 00:13:32 - training - INFO - Epoch 2 Val mse: 0.6153\n",
"2025-03-25 00:13:32 - training - INFO - Epoch 2 Val spearman_corr: 0.5853\n",
"2025-03-25 00:13:32 - training - INFO - Saving model with best val mse: 0.6153\n",
"2025-03-25 00:13:33 - training - INFO - ---------- Epoch 3 ----------\n",
"Training: 100%|β| 220/220 [01:16<00:00, 2.87it/s, grad_step=110, train_loss=1.2\n",
"2025-03-25 00:14:50 - training - INFO - Epoch 3 Train Loss: 0.6032\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.09it/s]\n",
"2025-03-25 00:14:57 - training - INFO - Epoch 3 Val Loss: 0.5875\n",
"2025-03-25 00:14:57 - training - INFO - Epoch 3 Val mse: 0.5875\n",
"2025-03-25 00:14:57 - training - INFO - Epoch 3 Val spearman_corr: 0.5955\n",
"2025-03-25 00:14:57 - training - INFO - Saving model with best val mse: 0.5875\n",
"2025-03-25 00:14:57 - training - INFO - ---------- Epoch 4 ----------\n",
"Training: 100%|β| 220/220 [01:18<00:00, 2.81it/s, grad_step=137, train_loss=1.3\n",
"2025-03-25 00:16:15 - training - INFO - Epoch 4 Train Loss: 0.5908\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.05it/s]\n",
"2025-03-25 00:16:23 - training - INFO - Epoch 4 Val Loss: 0.5810\n",
"2025-03-25 00:16:23 - training - INFO - Epoch 4 Val mse: 0.5810\n",
"2025-03-25 00:16:23 - training - INFO - Epoch 4 Val spearman_corr: 0.6094\n",
"2025-03-25 00:16:23 - training - INFO - Saving model with best val mse: 0.5810\n",
"2025-03-25 00:16:23 - training - INFO - ---------- Epoch 5 ----------\n",
"Training: 100%|β| 220/220 [01:17<00:00, 2.82it/s, grad_step=165, train_loss=0.8\n",
"2025-03-25 00:17:41 - training - INFO - Epoch 5 Train Loss: 0.5793\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.07it/s]\n",
"2025-03-25 00:17:48 - training - INFO - Epoch 5 Val Loss: 0.5638\n",
"2025-03-25 00:17:48 - training - INFO - Epoch 5 Val mse: 0.5638\n",
"2025-03-25 00:17:48 - training - INFO - Epoch 5 Val spearman_corr: 0.6185\n",
"2025-03-25 00:17:48 - training - INFO - Saving model with best val mse: 0.5638\n",
"2025-03-25 00:17:49 - training - INFO - ---------- Epoch 6 ----------\n",
"Training: 100%|β| 220/220 [01:17<00:00, 2.82it/s, grad_step=192, train_loss=1.1\n",
"2025-03-25 00:19:07 - training - INFO - Epoch 6 Train Loss: 0.5676\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.02it/s]\n",
"2025-03-25 00:19:14 - training - INFO - Epoch 6 Val Loss: 0.5680\n",
"2025-03-25 00:19:14 - training - INFO - Epoch 6 Val mse: 0.5680\n",
"2025-03-25 00:19:14 - training - INFO - Epoch 6 Val spearman_corr: 0.6253\n",
"2025-03-25 00:19:14 - training - INFO - ---------- Epoch 7 ----------\n",
"Training: 100%|β| 220/220 [01:18<00:00, 2.81it/s, grad_step=220, train_loss=1.0\n",
"2025-03-25 00:20:32 - training - INFO - Epoch 7 Train Loss: 0.5537\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.00it/s]\n",
"2025-03-25 00:20:40 - training - INFO - Epoch 7 Val Loss: 0.5544\n",
"2025-03-25 00:20:40 - training - INFO - Epoch 7 Val mse: 0.5544\n",
"2025-03-25 00:20:40 - training - INFO - Epoch 7 Val spearman_corr: 0.6276\n",
"2025-03-25 00:20:40 - training - INFO - Saving model with best val mse: 0.5544\n",
"2025-03-25 00:20:40 - training - INFO - ---------- Epoch 8 ----------\n",
"Training: 100%|β| 220/220 [01:18<00:00, 2.81it/s, grad_step=247, train_loss=1.1\n",
"2025-03-25 00:21:59 - training - INFO - Epoch 8 Train Loss: 0.5574\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:07<00:00, 4.06it/s]\n",
"2025-03-25 00:22:06 - training - INFO - Epoch 8 Val Loss: 0.5527\n",
"2025-03-25 00:22:06 - training - INFO - Epoch 8 Val mse: 0.5527\n",
"2025-03-25 00:22:06 - training - INFO - Epoch 8 Val spearman_corr: 0.6336\n",
"2025-03-25 00:22:06 - training - INFO - Saving model with best val mse: 0.5527\n",
"2025-03-25 00:22:06 - training - INFO - ---------- Epoch 9 ----------\n",
"Training: 100%|β| 220/220 [01:17<00:00, 2.85it/s, grad_step=275, train_loss=0.6\n",
"2025-03-25 00:23:23 - training - INFO - Epoch 9 Train Loss: 0.5469\n",
"Validating: 100%|βββββββββββββββββββββββββββββββ| 29/29 [00:06<00:00, 4.20it/s]\n",
"2025-03-25 00:23:30 - training - INFO - Epoch 9 Val Loss: 0.5440\n",
"2025-03-25 00:23:30 - training - INFO - Epoch 9 Val mse: 0.5440\n",
"2025-03-25 00:23:30 - training - INFO - Epoch 9 Val spearman_corr: 0.6338\n",
"2025-03-25 00:23:30 - training - INFO - Saving model with best val mse: 0.5440\n",
"/home/matwings/lc/VenusFactory-readme/src/training/trainer.py:427: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" checkpoint = torch.load(path, map_location=\"cpu\")\n",
"Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"2025-03-25 00:23:32 - training - INFO - ---------- Starting Test Phase ----------\n",
"Testing: 100%|ββββββββββββββββββββββββββββββββββ| 26/26 [00:06<00:00, 4.03it/s]\n",
"2025-03-25 00:23:38 - training - INFO - Test Results:\n",
"2025-03-25 00:23:38 - training - INFO - Test Loss: 0.5380\n",
"2025-03-25 00:23:38 - training - INFO - Test mse: 0.5380\n",
"2025-03-25 00:23:38 - training - INFO - Test spearman_corr: 0.6816\n"
]
}
],
"source": [
"# ESM model target_modules name: query key value\n",
"# Bert_base(prot_bert) model target_modules name: query key value\n",
"# T5_base(ankh, t5) model target_modules name: q k v\n",
"\n",
"!export HF_ENDPOINT=https://hf-mirror.com # if need to use HF mirror\n",
"dataset=\"eSOL\"\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"lr=5e-4\n",
"training_method=\"plm-ia3\"\n",
"sh=f\"\"\"\n",
"python src/train.py \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --dataset_config data/{dataset}/{dataset}_HF.json \\\n",
" --learning_rate {lr} \\\n",
" --gradient_accumulation_steps 8 \\\n",
" --num_epochs 10 \\\n",
" --batch_token 8000 \\\n",
" --patience 3 \\\n",
" --output_dir test_res/{dataset}/{plm_model} \\\n",
" --output_model_name {training_method}_lr_{lr}_8k_ga8.pt \\\n",
" --training_method {training_method} \\\n",
" --lora_target_modules query key value\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f2fed01-ab72-4381-8eb1-79346f553429",
"metadata": {},
"outputs": [],
"source": [
"# Use bash script\n",
"!cp ./script/train/train_plm_ia3.sh ./train_plm_ia3.sh\n",
"!bash ./train_plm_ia3.sh"
]
},
{
"cell_type": "markdown",
"id": "8224fc63-7f83-40c2-a37a-d91c406946f6",
"metadata": {},
"source": [
"### 2. Model Evaluation\n",
"**```--eval_method``` must be coordinated with ```--training_method``` to ensure evaluation protocol matches your training strategy.**\n",
"\n",
"**```--test_file``` specifies the evaluation dataset source, supports local custom datasets and predefined datasets. You should replace it for your model path.** \n",
"\n",
"**```--model_path``` is the path to load model weights, you should replace it for your model path.**"
]
},
{
"cell_type": "markdown",
"id": "81c36b27-f06a-4ed8-b1b9-b5bdf4912963",
"metadata": {},
"source": [
"#### LoRA Model Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "c74632c4-61fa-4dc0-835b-4194434012d8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- Load Model ----------\n",
"Number of parameter: 0.10M\n",
"---------- Start Eval ----------\n",
"Total samples: 310\n",
"100%|βββββββββββββββββββββββββββ| 20/20 [00:04<00:00, 4.12it/s, eval_loss=1.82]\n",
"spearman_corr: 0.7341052889823914\n"
]
}
],
"source": [
"!export HF_ENDPOINT=https://hf-mirror.com\n",
"problem_type=\"regression\"\n",
"num_labels=\"1\"\n",
"dataset=\"eSOL\"\n",
"eval_method=\"plm-lora\"\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"# for the predefined data\n",
"sh=f\"\"\"\n",
"python src/eval.py \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --model_path ckpt/test_res/{dataset}/{plm_model}/{eval_method}_lr_0.0005_8k_ga8.pt \\\n",
" --eval_method {eval_method} \\\n",
" --dataset {dataset} \\\n",
" --test_file tyang816/{dataset} \\\n",
" --test_result_dir ckpt/debug_result/{dataset}/{eval_method}_{plm_model} \\\n",
" --num_labels {num_labels} \\\n",
" --problem_type {problem_type} \\\n",
" --batch_size 16 \\\n",
" --metrics spearman_corr\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51afed32-a85a-4834-aea2-4ff77f87ff37",
"metadata": {},
"outputs": [],
"source": [
"# Use bash script\n",
"!cp ./script/eval/eval_plm_lora.sh ./eval_plm_lora.sh\n",
"!bash ./eval_plm_lora.sh"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "3ed451cf-5656-4fad-9d4d-01156e35996f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- Load Model ----------\n",
"Number of parameter: 0.10M\n",
"---------- Start Eval ----------\n",
"Total samples: 310\n",
"100%|βββββββββββββββββββββββββββ| 20/20 [00:04<00:00, 4.06it/s, eval_loss=1.82]\n",
"spearman_corr: 0.7341052889823914\n"
]
}
],
"source": [
"!export HF_ENDPOINT=https://hf-mirror.com\n",
"problem_type=\"regression\"\n",
"num_labels=\"1\"\n",
"dataset=\"eSOL\"\n",
"eval_method=\"plm-lora\"\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"\n",
"# for local data need ensure exist the test_file path\n",
"sh=f\"\"\"\n",
"python src/eval.py \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --model_path ckpt/test_res/{dataset}/{plm_model}/{eval_method}_lr_0.0005_8k_ga8.pt \\\n",
" --eval_method {eval_method} \\\n",
" --dataset {dataset} \\\n",
" --test_file data/eSOL_local_data/{dataset} \\\n",
" --test_result_dir ckpt/debug_result/{dataset}/{eval_method}_{plm_model} \\\n",
" --num_labels {num_labels} \\\n",
" --problem_type {problem_type} \\\n",
" --batch_size 16 \\\n",
" --metrics spearman_corr\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2c9c892-ca88-44ea-8f66-bcdada643536",
"metadata": {},
"outputs": [],
"source": [
"# Use bash script\n",
"!cp ./script/eval/eval_plm_lora.sh ./eval_plm_lora_local.sh\n",
"!bash ./eval_plm_lora_local.sh"
]
},
{
"cell_type": "markdown",
"id": "2d881159-6660-4bcf-8702-a6330d33067c",
"metadata": {},
"source": [
"#### SES-Adapter Model Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "94e9e361-37f5-4001-80d7-7040c96b9f12",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Enabled foldseek_seq based on structure_seq parameter\n",
"Enabled ss8_seq based on structure_seq parameter\n",
"---------- Load Model ----------\n",
"Number of parameter: 0.95M\n",
"---------- Start Eval ----------\n",
"Total samples: 310\n",
"100%|βββββββββββββββββββββββββββ| 20/20 [00:05<00:00, 3.87it/s, eval_loss=1.78]\n",
"spearman_corr: 0.6919500231742859\n"
]
}
],
"source": [
"!export HF_ENDPOINT=https://hf-mirror.com\n",
"problem_type=\"regression\"\n",
"num_labels=1\n",
"dataset=\"eSOL\"\n",
"pdb_type=\"AlphaFold2\" # note! ses-adapter need structure sequence\n",
"eval_method=\"ses-adapter\"\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"\n",
"# for predefined data\n",
"sh=f\"\"\"\n",
"python src/eval.py \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --model_path ckpt/test_res/{dataset}/{plm_model}/{eval_method}_{pdb_type}_lr_0.0005_bt8k_ga8.pt \\\n",
" --eval_method {eval_method} \\\n",
" --dataset {dataset} \\\n",
" --test_file tyang816/{dataset}_{pdb_type} \\\n",
" --test_result_dir ckpt/debug_result/{dataset}/{eval_method}_{plm_model} \\\n",
" --num_labels {num_labels} \\\n",
" --problem_type {problem_type} \\\n",
" --batch_size 16 \\\n",
" --structure_seq foldseek_seq,ss8_seq \\\n",
" --metrics spearman_corr\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4abd81b2-272a-43d4-80d3-c53f7a0f6ba1",
"metadata": {},
"outputs": [],
"source": [
"# Use bash script\n",
"!cp ./script/eval/eval_plm_ses-adapter.sh ./eval_plm_ses-adapter.sh\n",
"!bash ./eval_plm_ses-adapter.sh"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "138f8828-e7d4-44fb-95dc-e82cdcb41cbd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Enabled foldseek_seq based on structure_seq parameter\n",
"Enabled ss8_seq based on structure_seq parameter\n",
"---------- Load Model ----------\n",
"Number of parameter: 0.95M\n",
"---------- Start Eval ----------\n",
"Total samples: 310\n",
"100%|βββββββββββββββββββββββββββ| 20/20 [00:05<00:00, 3.84it/s, eval_loss=1.78]\n",
"spearman_corr: 0.6919500231742859\n"
]
}
],
"source": [
"!export HF_ENDPOINT=https://hf-mirror.com\n",
"problem_type=\"regression\"\n",
"num_labels=1\n",
"dataset=\"eSOL\"\n",
"pdb_type=\"AlphaFold2\" # note! ses-adapter need structure sequence\n",
"eval_method=\"ses-adapter\"\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"\n",
"# for local data need ensure exist the test_file path\n",
"sh=f\"\"\"\n",
"python src/eval.py \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --model_path ckpt/test_res/{dataset}/{plm_model}/{eval_method}_{pdb_type}_lr_0.0005_bt8k_ga8.pt \\\n",
" --eval_method {eval_method} \\\n",
" --dataset {dataset} \\\n",
" --test_file data/eSOL_local_data/{dataset}_{pdb_type} \\\n",
" --test_result_dir ckpt/debug_result/{dataset}/{eval_method}_{plm_model} \\\n",
" --num_labels {num_labels} \\\n",
" --problem_type {problem_type} \\\n",
" --batch_size 16 \\\n",
" --structure_seq foldseek_seq,ss8_seq \\\n",
" --metrics spearman_corr\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c544a7a7-88b7-4eca-be02-aa1efa72da9b",
"metadata": {},
"outputs": [],
"source": [
"# Use bash script\n",
"!cp ./script/eval/eval_plm_ses-adapter_local.sh ./eval_plm_ses-adapter_local.sh\n",
"!bash ./eval_plm_ses-adapter_local.sh"
]
},
{
"cell_type": "markdown",
"id": "2515736a-ae1c-4603-9533-f69b8ba1453d",
"metadata": {},
"source": [
"#### For more evaluation scripts, see the dedicated scripts in ```VenusFactory/script/eval/```."
]
},
{
"cell_type": "markdown",
"id": "7d4fc61a-ef6a-4185-9b2b-343793dfcd1a",
"metadata": {},
"source": [
"### 3. Model prediction\n",
"Venufactory provides two distinct prediction workflows to match your use case: single and batch.\n",
"\n",
"For single mode, you can provide one input(amino acid sequence, Foldseek sequence, secondary structure sequence).\n",
"\n",
"For batch mode, you can provide a test file(csv format).\n",
"\n",
"**```--problem_type``` specifies the current problem type in [\"single_label_classification\", \"multi_label_classification\", \"regression\"].**\n",
"\n",
"**```--aa_seq``` amino acid sequence.**\n",
"\n",
"**```--foldseek_seq``` foldseek sequence (optional).**\n",
"\n",
"**```--ss8_seq``` secondary structure sequence (optional).**\n",
"\n",
"**```--structure_seq``` structure sequence types to use (comma-separated).**\n",
"\n",
"**```--input_file``` path to input CSV file with sequences.**\n",
"\n",
"**```--output_file``` path to output CSV file for predictions.**"
]
},
{
"cell_type": "markdown",
"id": "b5a2f7aa-4cb7-40b1-8737-dc2a8ee71560",
"metadata": {},
"source": [
"#### LoRA Model Prediction"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "77695507-dafe-4742-a981-28ad56984354",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- Loading Model and Tokenizer ----------\n",
"Model config not found at ckpt/test_res/eSOL/esm2_t6_8M_UR50D/config.json. Using command line arguments.\n",
"Warning: structure_seq was None, setting to empty string\n",
"Training method: plm-lora\n",
"Structure sequence: \n",
"Use foldseek: False\n",
"Use ss8: False\n",
"Problem type: regression\n",
"Number of labels: 1\n",
"Number of attention heads: 8\n",
"---------- Processing Input Sequences ----------\n",
"Processed input sequences with keys: dict_keys(['aa_seq_input_ids', 'aa_seq_attention_mask'])\n",
"---------- Running Prediction ----------\n",
"Prediction result: 1.4336968660354614\n",
"\n",
"---------- Prediction Results ----------\n",
"{\n",
" \"prediction\": 1.4336968660354614\n",
"}\n"
]
}
],
"source": [
"# For the single prediction\n",
"!export HF_ENDPOINT=https://hf-mirror.com\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"eval_method=\"plm-lora\"\n",
"problem_type=\"regression\"\n",
"num_labels=1\n",
"aa_seq=\"MAKEDNIEMQGTVLETLPNTMFRVELENGHVVTAHISGKMRKNYIRILTGDKVTVELTPYDLSKGRIVFRSR\"\n",
"#\n",
"sh=f\"\"\"\n",
"python src/predict.py \\\n",
" --eval_method {eval_method} \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --model_path ckpt/test_res/eSOL/{plm_model}/{eval_method}_lr_0.0005_8k_ga8.pt \\\n",
" --aa_seq {aa_seq} \\\n",
" --num_labels {num_labels} \\\n",
" --problem_type {problem_type}\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55ed468e-8708-412c-95aa-2e380eaaea7c",
"metadata": {},
"outputs": [],
"source": [
"# use bash script\n",
"!cp ./script/predict/predict_plm_lora.sh ./predict_plm_lora.sh\n",
"!bash ./predict_plm_lora.sh"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "042ed7a9-0780-4aa7-9b7c-222e41854128",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- Loading Model and Tokenizer ----------\n",
"Model config not found at ckpt/test_res/eSOL/esm2_t6_8M_UR50D/config.json. Using command line arguments.\n",
"Warning: structure_seq was None, setting to empty string\n",
"Training method: plm-lora\n",
"Structure sequence: \n",
"Use foldseek: False\n",
"Use ss8: False\n",
"Problem type: regression\n",
"Number of labels: 1\n",
"Number of attention heads: 8\n",
"---------- Reading input file: data/eSOL_local_data/eSOL/test.csv ----------\n",
"Found 310 sequences in input file\n",
"---------- Processing sequences ----------\n",
"Predicting: 100%|βββββββββββββββββββββββββββββ| 310/310 [00:34<00:00, 9.07it/s]\n",
"---------- Saving results to ckpt/debug_result/eSOL/esm2_t6_8M_UR50D/prediction_batch/plm-lora/result.csv ----------\n",
"Saved 310 prediction results\n",
"---------- Batch prediction completed successfully ----------\n"
]
}
],
"source": [
"# For the batch prediction\n",
"!export HF_ENDPOINT=https://hf-mirror.com\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"eval_method=\"plm-lora\"\n",
"problem_type=\"regression\"\n",
"num_labels=1\n",
"input_file=\"data/eSOL_local_data/eSOL/test.csv\"\n",
"sh=f\"\"\"\n",
"python src/predict_batch.py \\\n",
" --eval_method {eval_method} \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --model_path ckpt/test_res/eSOL/{plm_model}/{eval_method}_lr_0.0005_8k_ga8.pt \\\n",
" --num_labels {num_labels} \\\n",
" --problem_type {problem_type} \\\n",
" --input_file {input_file} \\\n",
" --output_dir ckpt/debug_result/eSOL/{plm_model}/prediction_batch/{eval_method} \\\n",
" --output_file result.csv\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "216bd3cc-0920-4a98-b36d-1eb5e725c6b1",
"metadata": {},
"outputs": [],
"source": [
"# use bash script\n",
"!cp ./script/predict/predict_batch_plm_lora.sh ./predict_batch_plm_lora.sh\n",
"!bash ./predict_batch_plm_lora.sh"
]
},
{
"cell_type": "markdown",
"id": "f19c107b-9950-4223-8c3b-a53ac5d10867",
"metadata": {},
"source": [
"#### SES-Adapter Model Prediction"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "801b18ed-2cc4-4e01-aafe-c802eaae7dbc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- Loading Model and Tokenizer ----------\n",
"Model config not found at ckpt/test_res/eSOL/esm2_t6_8M_UR50D/config.json. Using command line arguments.\n",
"Enabled foldseek_seq based on structure_seq parameter\n",
"Enabled ss8_seq based on structure_seq parameter\n",
"Training method: ses-adapter\n",
"Structure sequence: foldseek_seq,ss8_seq\n",
"Use foldseek: True\n",
"Use ss8: True\n",
"Problem type: regression\n",
"Number of labels: 1\n",
"Number of attention heads: 8\n",
"---------- Processing Input Sequences ----------\n",
"Processed input sequences with keys: dict_keys(['aa_seq_input_ids', 'aa_seq_attention_mask', 'foldseek_seq_input_ids', 'ss8_seq_input_ids'])\n",
"---------- Running Prediction ----------\n",
"Prediction result: 1.5001569986343384\n",
"\n",
"---------- Prediction Results ----------\n",
"{\n",
" \"prediction\": 1.5001569986343384\n",
"}\n"
]
}
],
"source": [
"# For the single prediction\n",
"!export HF_ENDPOINT=https://hf-mirror.com\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"eval_method=\"ses-adapter\"\n",
"problem_type=\"regression\"\n",
"num_labels=1\n",
"aa_seq=\"MAKEDNIEMQGTVLETLPNTMFRVELENGHVVTAHISGKMRKNYIRILTGDKVTVELTPYDLSKGRIVFRSR\"\n",
"ss8_seq=\"LLLLLLEEEEEEEEEEETTTEEEEEETTSLEEEEEELHHHHHTTLLLLTTLEEEEEEETTEEEEEEEEEELL\"\n",
"foldseek_seq=\"DDPQPFDKFKWFFADADPPQWTFTQTPVRDTAIEHEDPVCVVVVDDDDGGWMFIWGHHPVDNRYTYTDDTDD\"\n",
"sh=f\"\"\"\n",
"python src/predict.py \\\n",
" --eval_method {eval_method} \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --model_path ckpt/test_res/eSOL/{plm_model}/{eval_method}_AlphaFold2_lr_0.0005_bt8k_ga8.pt \\\n",
" --aa_seq {aa_seq} \\\n",
" --foldseek_seq {foldseek_seq} \\\n",
" --ss8_seq {ss8_seq} \\\n",
" --num_labels {num_labels} \\\n",
" --problem_type {problem_type} \\\n",
" --structure_seq foldseek_seq,ss8_seq\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a3e77c4-07c6-45e2-9cf4-07ee59b515f9",
"metadata": {},
"outputs": [],
"source": [
"# use bash script\n",
"!cp ./script/predict/predict_plm_ses-adapter.sh ./predict_plm_ses-adapter.sh\n",
"!bash ./predict_plm_ses-adapter.sh"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "88a51e52-7966-409d-8218-48e28d2f79d3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- Loading Model and Tokenizer ----------\n",
"Model config not found at ckpt/test_res/eSOL/esm2_t6_8M_UR50D/config.json. Using command line arguments.\n",
"Enabled foldseek_seq based on structure_seq parameter\n",
"Enabled ss8_seq based on structure_seq parameter\n",
"Training method: ses-adapter\n",
"Structure sequence: foldseek_seq,ss8_seq\n",
"Use foldseek: True\n",
"Use ss8: True\n",
"Problem type: regression\n",
"Number of labels: 1\n",
"Number of attention heads: 8\n",
"---------- Reading input file: data/eSOL_local_data/eSOL_AlphaFold2/test.csv ----------\n",
"Found 310 sequences in input file\n",
"---------- Processing sequences ----------\n",
"Predicting: 100%|βββββββββββββββββββββββββββββ| 310/310 [00:34<00:00, 9.06it/s]\n",
"---------- Saving results to ckpt/debug_result/eSOL/esm2_t6_8M_UR50D/prediction_batch/ses-adapter/result.csv ----------\n",
"Saved 310 prediction results\n",
"---------- Batch prediction completed successfully ----------\n"
]
}
],
"source": [
"# for the batch prediction\n",
"!export HF_ENDPOINT=https://hf-mirror.com\n",
"plm_source=\"facebook\"\n",
"plm_model=\"esm2_t6_8M_UR50D\"\n",
"eval_method=\"ses-adapter\"\n",
"problem_type=\"regression\"\n",
"num_labels=1\n",
"input_file=\"data/eSOL_local_data/eSOL_AlphaFold2/test.csv\"\n",
"sh=f\"\"\"\n",
"python src/predict_batch.py \\\n",
" --eval_method {eval_method} \\\n",
" --plm_model {plm_source}/{plm_model} \\\n",
" --model_path ckpt/test_res/eSOL/{plm_model}/{eval_method}_AlphaFold2_lr_0.0005_bt8k_ga8.pt \\\n",
" --num_labels {num_labels} \\\n",
" --problem_type {problem_type} \\\n",
" --input_file {input_file} \\\n",
" --output_dir ckpt/debug_result/eSOL/{plm_model}/prediction_batch/{eval_method} \\\n",
" --output_file result.csv \\\n",
" --structure_seq foldseek_seq,ss8_seq\n",
"\"\"\"\n",
"!{sh}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a156c65a-4cd9-4041-948d-9f96e55761ee",
"metadata": {},
"outputs": [],
"source": [
"# use bash script\n",
"!cp ./script/predict/predict_batch_plm_ses-adapter.sh ./predict_batch_plm_ses-adapter.sh\n",
"!bash ./predict_batch_plm_ses-adapter.sh"
]
},
{
"cell_type": "markdown",
"id": "685ffd15-24a5-40e3-a81d-4d7bb557f5bc",
"metadata": {},
"source": [
"#### For more evaluation scripts, see the dedicated scripts in ```VenusFactory/script/predict/```."
]
},
{
"cell_type": "markdown",
"id": "50213491-6b03-48ff-a6d4-82eb1c3df65e",
"metadata": {},
"source": [
"## π Data Collection Tools: Multi-source protein data acquisition"
]
},
{
"cell_type": "markdown",
"id": "ed813fb8-eb70-487b-aa45-a3018610880c",
"metadata": {},
"source": [
"### Download Components Help Guide\n",
"\n",
"\n",
"InterPro Metadata
\n",
"\n",
"- **Description**: Downloads protein domain information from InterPro database.\n",
"\n",
"- **Source**: [InterPro Database](https://www.ebi.ac.uk/interpro/)\n",
"\n",
"- **Download Options**:\n",
" - ```--interpro_id```: Download data for a specific InterPro domain (e.g., IPR000001)\n",
" - ```--interpro_json```: Batch download using a JSON file containing multiple InterPro entries\n",
"\n",
"- **Output Format**:\n",
"\n",
" ```\n",
" download/interpro_domain/\n",
" βββ IPR000001/\n",
" βββ detail.json # Detailed protein information\n",
" βββ meta.json # Metadata including accession and protein count\n",
" βββ uids.txt # List of UniProt IDs associated with this domain\n",
" ```\n",
" \n",
"\n",
"\n",
"RCSB Metadata
\n",
"\n",
"- **Description**: Downloads structural metadata from the RCSB Protein Data Bank.\n",
"\n",
"- **Source**: [RCSB PDB](https://www.rcsb.org/)\n",
"\n",
"- **Download Options**:\n",
" - ```--pdb_id```: Download metadata for a specific PDB entry (e.g., 1a0j)\n",
" - ```--pdb_id_file```: Batch download using a text file containing PDB IDs\n",
"\n",
"- **Output Format**:\n",
" ```\n",
" download/rcsb_metadata/\n",
" βββ 1a0j.json # Contains structure metadata including:\n",
" # - Resolution\n",
" # - Experimental method\n",
" # - Publication info\n",
" # - Chain information\n",
" ```\n",
" \n",
"\n",
"\n",
"UniProt Sequences
\n",
"\n",
"- **Description**: Downloads protein sequences from UniProt database.\n",
"\n",
"- **Source**: [UniProt](https://www.uniprot.org/)\n",
"\n",
"- **Download Options**:\n",
" - ```--uniprot_id```: Download sequence for a specific UniProt entry (e.g., P00734)\n",
" - ```--file```: Batch download using a text file containing UniProt IDs\n",
" - ```--merge```: Combine all sequences into a single FASTA file (optional)\n",
"\n",
"- **Output Format**:\n",
" ```\n",
" download/uniprot_sequences/\n",
" βββ P00734.fasta # Individual FASTA files (when not merged)\n",
" βββ merged.fasta # Combined sequences (when merge option is selected)\n",
" ```\n",
" \n",
"\n",
"\n",
"RCSB Structures
\n",
" \n",
"- **Description**: Downloads 3D structure files from RCSB Protein Data Bank.\n",
"\n",
"- **Source**: [RCSB PDB](https://www.rcsb.org/)\n",
"\n",
"- **Download Options**:\n",
" - ```--pdb_id```: Download structure for a specific PDB entry\n",
" - ```--pdb_id_file```: Batch download using a text file containing PDB IDs\n",
" - ```--type``` File Types:\n",
" * cif: mmCIF format (recommended)\n",
" * pdb: Legacy PDB format\n",
" * xml: PDBML/XML format\n",
" * sf: Structure factors\n",
" * mr: NMR restraints\n",
" - ```--unzip``` Option: Automatically decompress downloaded files\n",
"\n",
"- **Output Format**:\n",
" ```\n",
" download/rcsb_structures/\n",
" βββ 1a0j.pdb # Uncompressed structure file (with unzip)\n",
" βββ 1a0j.pdb.gz # Compressed structure file (without unzip)\n",
" ```\n",
" \n",
"\n",
"\n",
"AlphaFold2 Structures
\n",
" \n",
"- **Description**: Downloads predicted protein structures from AlphaFold Protein Structure Database.\n",
"\n",
"- **Source**: [AlphaFold DB](https://alphafold.ebi.ac.uk/)\n",
"\n",
"- **Download Options**:\n",
" - ```--uniprot_id```: Download structure for a specific UniProt entry\n",
" - ```--uniprot_id_file```: Batch download using a text file containing UniProt IDs\n",
" - ```--index_level```: Organize files in subdirectories based on ID prefix\n",
"\n",
"- **Output Format**:\n",
" ```\n",
" download/alphafold2_structures/\n",
" βββ P/ # With index_level=1\n",
" βββ P0/ # With index_level=2\n",
" βββ P00734.pdb # AlphaFold predicted structure\n",
" ```\n",
" \n",
"\n",
"\n",
"Common Features
\n",
"\n",
"- **Error Handling**: All components support error file generation\n",
"- **Output Directory**: Customizable output paths\n",
"- **Batch Processing**: Support for multiple IDs via file input\n",
"- **Progress Tracking**: Real-time download progress and status updates\n",
" \n",
"\n",
"\n",
"Input File Formats
\n",
" \n",
"- **PDB ID List** (for RCSB downloads):\n",
" ```\n",
" 1a0j\n",
" 4hhb\n",
" 1hho\n",
" ```\n",
"\n",
"- **UniProt ID List** (for UniProt and AlphaFold):\n",
" ```\n",
" P00734\n",
" P61823\n",
" Q8WZ42\n",
" ```\n",
"\n",
"- **InterPro JSON** (for batch InterPro downloads):\n",
" ```json\n",
" [\n",
" {\n",
" \"metadata\": {\n",
" \"accession\": \"IPR000001\"\n",
" }\n",
" },\n",
" {\n",
" \"metadata\": {\n",
" \"accession\": \"IPR000002\"\n",
" }\n",
" }\n",
" ]\n",
" ```\n",
" \n",
"\n",
"\n",
"Error Files
\n",
" \n",
"- When enabled, failed downloads are logged to `failed.txt` in the output directory:\n",
" ```\n",
" P00734 - Download failed: 404 Not Found\n",
" 1a0j - Connection timeout\n",
" ```\n",
" "
]
},
{
"cell_type": "markdown",
"id": "329270f1-0940-4dba-8f6c-d6253e69a7f8",
"metadata": {},
"source": [
"### Download InterPro Metadata"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "b9f439fa-2fa4-400b-8203-607fd48229d5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Successfully downloaded IPR000003\n"
]
}
],
"source": [
"# download single data\n",
"!python src/crawler/metadata/download_interpro.py \\\n",
" --interpro_id IPR000003 \\\n",
" --out_dir data/interpro/meta_single \\\n",
" --error_file data/interpro/meta_single_error.csv"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "b1942fb2-abcb-4087-922c-9362894c72bf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100%|βββββββββββββββββββββββββββββββββββββββββββββ| 6/6 [06:47<00:00, 67.94s/it]\n"
]
}
],
"source": [
"# download batch data \n",
"# the JSON file template is provided in download/interpro_json.customization. You can modify to specify the Interpro IDs\n",
"!python src/crawler/metadata/download_interpro.py \\\n",
" --interpro_json data/interpro/batch.json \\\n",
" --out_dir data/interpro/meta_batch \\\n",
" --error_file data/interpro/meta_batch_error.csv"
]
},
{
"cell_type": "markdown",
"id": "a7b52cc2-e86e-498b-a8ca-a77456a2c908",
"metadata": {},
"source": [
"### Download RCSB Metadata"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "02d78d89-8111-4ad7-96b9-5b81a6609e05",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1A00 successfully downloaded\n"
]
}
],
"source": [
"# download single data\n",
"!python src/crawler/metadata/download_rcsb.py \\\n",
" --pdb_id 1A00 \\\n",
" --out_dir data/rcsb/meta_single \\\n",
" --error_file data/rcsb/meta_single_error.csv"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2510c4b8-5042-48bb-abf6-c6538d26ebc2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1A03 successfully downloaded: 100%|βββββββββββββββ| 4/4 [00:00<00:00, 6.20it/s]\n"
]
}
],
"source": [
"# download batch data\n",
"!python src/crawler/metadata/download_rcsb.py \\\n",
" --pdb_id_file download/rcsb.txt \\\n",
" --out_dir data/rcsb/meta_batch \\\n",
" --error_file data/rcsb/meta_batch_error.csv"
]
},
{
"cell_type": "markdown",
"id": "f4e8046c-50a8-4f0e-8644-e0a759ae0e06",
"metadata": {},
"source": [
"### Download UniProt Sequences"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5518cdd9-e456-4c3b-b445-cd9556b29c56",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A0A0C5B5G6.fasta successfully downloaded\n"
]
}
],
"source": [
"# download single data\n",
"!python src/crawler/sequence/download_uniprot_seq.py \\\n",
" --uniprot_id A0A0C5B5G6 \\\n",
" --out_dir data/uniprot/uniprot_single \\\n",
" --error_file data/uniprot/uniprot_single_error.csv"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "81cca464-f427-4179-9848-a74ed61e3a7f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A0JNW5.fasta successfully downloaded: 100%|βββββββ| 5/5 [00:01<00:00, 3.20it/s]\n"
]
}
],
"source": [
"# download batch data\n",
"!python src/crawler/sequence/download_uniprot_seq.py \\\n",
" --file download/uniprot.txt \\\n",
" --out_dir data/uniprot/uniprot_batch \\\n",
" --error_file data/uniprot/uniprot_batch_error.csv"
]
},
{
"cell_type": "markdown",
"id": "0695001b-5047-4a8e-ab27-4cf06dff797e",
"metadata": {},
"source": [
"### Download RCSB Structures"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b2e42ac5-4922-4e63-addb-b6669dcbcaf6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1A00.pdb.gz successfully downloaded\n"
]
}
],
"source": [
"# download single data\n",
"!python src/crawler/structure/download_rcsb.py \\\n",
" --pdb_id 1A00 \\\n",
" --out_dir data/structure/rcsb_single \\\n",
" --error_file data/structure/rcsb_single_error.csv"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0428ef96-e67e-44c7-b55d-394732b6838b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.\n",
"1A03.pdb.gz successfully downloaded: 100%|ββββββββ| 4/4 [00:01<00:00, 2.06it/s]\n"
]
}
],
"source": [
"# download batch data\n",
"!python src/crawler/structure/download_rcsb.py \\\n",
" --pdb_id_file download/rcsb.txt \\\n",
" --out_dir data/structure/rcsb_batch \\\n",
" --error_file data/structure/rcsb_batch_error.csv \\\n",
" --unzip"
]
},
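{
"cell_type": "markdown",
"id": "3d4e5f6a-7b8c-4d9e-a0f1-2b3c4d5e6f7a",
"metadata": {},
"source": [
"The single-entry run saved the structure as ```1A00.pdb.gz```; the batch run passes ```--unzip```, which presumably leaves plain ```.pdb``` files instead. A minimal sketch for peeking into the gzipped PDB and counting its ATOM records:\n",
"\n",
"```\n",
"import gzip\n",
"from pathlib import Path\n",
"\n",
"# Open the gzipped PDB downloaded above and count its ATOM records.\n",
"pdb_gz = Path('data/structure/rcsb_single/1A00.pdb.gz')\n",
"with gzip.open(pdb_gz, 'rt') as f:\n",
"    atom_lines = [line for line in f if line.startswith('ATOM')]\n",
"print(f'{pdb_gz.name}: {len(atom_lines)} ATOM records')\n",
"```"
]
},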
{
"cell_type": "markdown",
"id": "9cd3387f-b017-43a0-8919-06473847453c",
"metadata": {},
"source": [
"### Download AlphaFold2 Structures"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "33aa0af2-b935-4f0a-a452-08e9ed006eb4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A0A0C5B5G6 successfully downloaded\n"
]
}
],
"source": [
"# download single data\n",
"!python src/crawler/structure/download_alphafold.py \\\n",
" --uniprot_id A0A0C5B5G6 \\\n",
" --out_dir data/structure/af2_single \\\n",
" --error_file data/structure/af2_single_error.csv"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "99cafc99-9b10-42ad-8c8b-cc572d372793",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A0A1B0GTW7 successfully downloaded: 100%|βββββββββ| 5/5 [00:03<00:00, 1.43it/s]\n"
]
}
],
"source": [
"# download batch data\n",
"!python src/crawler/structure/download_alphafold.py \\\n",
" --uniprot_id_file download/uniprot.txt \\\n",
" --out_dir data/structure/af2_batch \\\n",
" --error_file data/structure/af2_batch_error.csv \\\n",
" --index_level 1"
]
},
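{
"cell_type": "markdown",
"id": "4e5f6a7b-8c9d-4e0f-b1a2-3c4d5e6f7a8b",
"metadata": {},
"source": [
"```--index_level 1``` presumably nests the batch output into one level of subdirectories, so a recursive search is the safest way to enumerate the downloaded structures. A minimal sketch, assuming the files are saved with a ```.pdb``` extension:\n",
"\n",
"```\n",
"from pathlib import Path\n",
"\n",
"# Recursively list the structures downloaded by the batch command above.\n",
"af2_dir = Path('data/structure/af2_batch')\n",
"structures = sorted(af2_dir.rglob('*.pdb'))\n",
"print(f'Found {len(structures)} structures')\n",
"for path in structures[:5]:\n",
"    print(path.relative_to(af2_dir))\n",
"```"
]
},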
{
"cell_type": "markdown",
"id": "bc36f723-bfe0-462d-b991-754db47585da",
"metadata": {},
"source": [
"### Structure Sequence Tools"
]
},
{
"cell_type": "markdown",
"id": "c1f5eece-1ff6-4815-a2ea-14d94e17e6b4",
"metadata": {},
"source": [
"#### ESM3 Structure Sequence\n",
"Generate structure sequences using ESM-3. You can download the ```esm3_structure_encoder_v0.pth``` in [huggingface ](https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1/tree/main/data/weights)\n",
"\n",
"```--pdb_file```: Get a specific pdb structure sequence\n",
"\n",
"```--pdb_dir```: Get batch pdb structure sequences"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6189c793-6b1e-4573-bb4d-29c5c01b20ef",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/matwings/lc/VenusFactory-readme/src/data/get_esm3_structure_seq.py:31: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" state_dict = torch.load(\n",
"/home/matwings/gejian/anaconda3/envs/venus/lib/python3.10/site-packages/esm/models/vqvae.py:286: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
" with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): # type: ignore\n"
]
}
],
"source": [
"# get a specific pdb structure sequence\n",
"!python src/data/get_esm3_structure_seq.py \\\n",
" --pdb_file download/alphafold2_structures/A0PK11.pdb\\\n",
" --out_file data/structure/esm2_ss.json"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "181384a2-61c6-4af6-a072-3e0a36fc0863",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/matwings/lc/VenusFactory-readme/src/data/get_esm3_structure_seq.py:31: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" state_dict = torch.load(\n",
" 0%| | 0/5 [00:00, ?it/s]/home/matwings/gejian/anaconda3/envs/venus/lib/python3.10/site-packages/esm/models/vqvae.py:286: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
" with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): # type: ignore\n",
" 20%|βββββββββ | 1/5 [00:02<00:08, 2.17s/it]/home/matwings/gejian/anaconda3/envs/venus/lib/python3.10/site-packages/esm/models/vqvae.py:286: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
" with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): # type: ignore\n",
"100%|βββββββββββββββββββββββββββββββββββββββββββββ| 5/5 [00:04<00:00, 1.20it/s]\n"
]
}
],
"source": [
"# get batch pdb structure sequence\n",
"!python src/data/get_esm3_structure_seq.py \\\n",
" --pdb_dir download/alphafold2_structures\\\n",
" --out_file data/structure/esm2_ss_batch.json"
]
},
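{
"cell_type": "markdown",
"id": "5f6a7b8c-9d0e-4f1a-82b3-c4d5e6f7a8b9",
"metadata": {},
"source": [
"The generated structure sequences land in the JSON file given by ```--out_file```. A minimal sketch for loading the batch result, assuming the file is either a dict keyed by protein name or a list of records (adjust to the layout the script actually produces):\n",
"\n",
"```\n",
"import json\n",
"from pathlib import Path\n",
"\n",
"# Load the ESM-3 structure sequences produced by the batch command above.\n",
"data = json.loads(Path('data/structure/esm2_ss_batch.json').read_text())\n",
"if isinstance(data, dict):\n",
"    print(f'{len(data)} entries, e.g. {list(data)[:3]}')\n",
"else:\n",
"    print(f'{len(data)} entries, first record keys: {list(data[0].keys())}')\n",
"```"
]
},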
{
"cell_type": "markdown",
"id": "cf955238-269d-432a-97d4-858de0765985",
"metadata": {},
"source": [
"#### FoldSeek Structure Sequence\n",
"Generate secondary sequences. You can install FoldSeek use ```conda install -c conda-forge -c bioconda foldseek```\n",
"\n",
"```--pdb_dir```: Get batch pdb structure sequences"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "3e00619f-82bb-4f74-a9f4-018db7c2b51e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"createdb download/alphafold2_structures tmp_db/tmp_db \n",
"\n",
"MMseqs Version: \t1.3c64211\n",
"Chain name mode \t0\n",
"Write lookup file\t1\n",
"Threads \t96\n",
"Verbosity \t3\n",
"\n",
"Output file: tmp_db/tmp_db\n",
"[=================================================================] 100.00% 5 0s 7ms \n",
"Time for merging to tmp_db_ss: 0h 0m 0s 139ms\n",
"Time for merging to tmp_db_h: 0h 0m 0s 124ms\n",
"Time for merging to tmp_db_ca: 0h 0m 0s 140ms\n",
"Time for merging to tmp_db: 0h 0m 0s 111ms\n",
"Ignore 0 out of 5.\n",
"Too short: 0, incorrect 0.\n",
"Time for processing: 0h 0m 1s 652ms\n",
"lndb tmp_db/tmp_db_h tmp_db/tmp_db_ss_h \n",
"\n",
"MMseqs Version:\t1.3c64211\n",
"Verbosity\t3\n",
"\n",
"Time for processing: 0h 0m 0s 2ms\n",
"convert2fasta tmp_db/tmp_db_ss tmp_db/tmp_db_ss.fasta \n",
"\n",
"MMseqs Version:\t1.3c64211\n",
"Use header DB\tfalse\n",
"Verbosity \t3\n",
"\n",
"Start writing file to tmp_db/tmp_db_ss.fasta\n",
"Time for processing: 0h 0m 0s 3ms\n",
"5it [00:00, 56375.05it/s]\n"
]
}
],
"source": [
"!python src/data/get_foldseek_structure_seq.py \\\n",
" --pdb_dir download/alphafold2_structures\\\n",
" --out_file data/structure/foldseek_batch.json"
]
},
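{
"cell_type": "markdown",
"id": "6a7b8c9d-0e1f-4a2b-93c4-d5e6f7a8b9c0",
"metadata": {},
"source": [
"The FoldSeek output can be inspected the same way. A minimal sketch, again assuming the JSON holds one entry per input structure:\n",
"\n",
"```\n",
"import json\n",
"from pathlib import Path\n",
"\n",
"# Load the FoldSeek structure sequences produced above.\n",
"data = json.loads(Path('data/structure/foldseek_batch.json').read_text())\n",
"entries = data if isinstance(data, list) else list(data.items())\n",
"print(f'{len(entries)} structures converted to FoldSeek sequences')\n",
"```"
]
},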
{
"cell_type": "code",
"execution_count": null,
"id": "2e4a8cc8-7170-483c-9e8d-4a2f282aa51b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venus",
"language": "python",
"name": "venus"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}