{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Finetuneing ESM-2 Models for CAFA-5" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Finetune an ESM-2 Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.utils.data import DataLoader, Dataset\n", "from transformers import AutoTokenizer, EsmForSequenceClassification\n", "from accelerate import Accelerator\n", "from sklearn.model_selection import train_test_split\n", "from torchmetrics.classification import MultilabelF1Score\n", "from sklearn.metrics import accuracy_score, precision_score, recall_score, average_precision_score\n", "import datetime\n", "import pandas as pd\n", "\n", "# Load the data\n", "data = pd.read_csv(\"C:/Users/OWO/Desktop/amelie_vscode/cafa5/data/merged_protein_data.tsv\", sep=\"\\t\")\n", "# Use only the first 100 entries\n", "# data = data.head(100)\n", "\n", "# Initialize the accelerator\n", "accelerator = Accelerator()\n", "device = accelerator.device\n", "\n", "# Data Preprocessing\n", "tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t6_8M_UR50D\")\n", "MAX_LENGTH = tokenizer.model_max_length\n", "NUM_EPOCHS = 3\n", "LR = 5e-4\n", "BATCH_SIZE = 2\n", "\n", "class ProteinDataset(Dataset):\n", " def __init__(self, sequences, labels):\n", " self.sequences = sequences\n", " self.labels = labels\n", "\n", " def __len__(self):\n", " return len(self.sequences)\n", "\n", " def __getitem__(self, idx):\n", " sequence = self.sequences[idx]\n", " label = self.labels[idx]\n", " encoding = tokenizer(sequence, return_tensors=\"pt\", padding='max_length', truncation=True, max_length=MAX_LENGTH)\n", " return {\n", " 'input_ids': encoding['input_ids'].flatten(),\n", " 'attention_mask': encoding['attention_mask'].flatten(),\n", " 'labels': torch.tensor(label, dtype=torch.float)\n", " }\n", "\n", "def encode_labels(go_terms, unique_terms):\n", " encoded = []\n", " for terms in go_terms:\n", " encoding = [1 if term in terms else 0 for term in unique_terms]\n", " encoded.append(encoding)\n", " return encoded\n", "\n", "train_sequences, val_sequences, train_labels, val_labels = train_test_split(data['sequence'], data['term'], test_size=0.1)\n", "\n", "# Reset the indices\n", "train_sequences = train_sequences.reset_index(drop=True)\n", "val_sequences = val_sequences.reset_index(drop=True)\n", "train_labels = train_labels.reset_index(drop=True)\n", "val_labels = val_labels.reset_index(drop=True)\n", "\n", "unique_terms = list(set(term for sublist in data['term'] for term in sublist))\n", "train_labels_encoded = encode_labels(train_labels, unique_terms)\n", "val_labels_encoded = encode_labels(val_labels, unique_terms)\n", "\n", "train_dataset = ProteinDataset(train_sequences, train_labels_encoded)\n", "val_dataset = ProteinDataset(val_sequences, val_labels_encoded)\n", "\n", "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n", "val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)\n", "\n", "# Model Training\n", "model = EsmForSequenceClassification.from_pretrained(\"facebook/esm2_t6_8M_UR50D\", num_labels=len(unique_terms), problem_type=\"multi_label_classification\")\n", "model = model.to(device)\n", "model.train()\n", "\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=LR)\n", "optimizer, model = accelerator.prepare(optimizer, model)\n", "\n", "# Initialize metrics\n", "f1_metric = MultilabelF1Score(num_labels=len(unique_terms), threshold=0.5)\n", "f1_metric = f1_metric.to(device)\n", "\n", "num_epochs = NUM_EPOCHS\n", "\n", "for epoch in range(num_epochs):\n", " total_loss = 0\n", " for batch in train_loader:\n", " optimizer.zero_grad()\n", " input_ids = batch['input_ids'].to(device)\n", " attention_mask = batch['attention_mask'].to(device)\n", " labels = batch['labels'].to(device)\n", "\n", " outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n", " loss = outputs.loss\n", " accelerator.backward(loss)\n", " optimizer.step()\n", "\n", " total_loss += loss.item()\n", "\n", " print(f'Epoch {epoch + 1}/{num_epochs}, Training loss: {total_loss/len(train_loader)}')\n", "\n", " model.eval()\n", " predictions = []\n", " true_labels_list = []\n", " with torch.no_grad():\n", " for batch in val_loader:\n", " input_ids = batch['input_ids'].to(device)\n", " attention_mask = batch['attention_mask'].to(device)\n", " labels = batch['labels'].to(device)\n", "\n", " outputs = model(input_ids=input_ids, attention_mask=attention_mask)\n", " logits = outputs.logits\n", " predictions.append(torch.sigmoid(logits))\n", " true_labels_list.append(labels)\n", "\n", " predictions_tensor = torch.cat(predictions, dim=0).cpu().numpy()\n", " true_labels_tensor = torch.cat(true_labels_list, dim=0).cpu().numpy()\n", "\n", " threshold = 0.5\n", " predictions_bin = (predictions_tensor > threshold).astype(int)\n", "\n", " # Compute metrics\n", " val_f1 = f1_metric(torch.tensor(predictions_tensor).to(device), torch.tensor(true_labels_tensor).to(device))\n", " val_accuracy = accuracy_score(true_labels_tensor.flatten(), predictions_bin.flatten())\n", " val_precision = precision_score(true_labels_tensor.flatten(), predictions_bin.flatten(), average='micro')\n", " val_recall = recall_score(true_labels_tensor.flatten(), predictions_bin.flatten(), average='micro')\n", " val_auc = average_precision_score(true_labels_tensor, predictions_tensor, average='micro')\n", "\n", " # Print metrics\n", " print(f'Validation F1 Score: {val_f1}')\n", " print(f'Validation Accuracy: {val_accuracy}')\n", " print(f'Validation Precision: {val_precision}')\n", " print(f'Validation Recall: {val_recall}')\n", " print(f'Validation AUC: {val_auc}')\n", "\n", " timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')\n", " model_path = f'./esm2_t6_8M_finetuned_cafa5_{timestamp}'\n", " model.save_pretrained(model_path)\n", " tokenizer.save_pretrained(model_path)\n", "\n", " print(f'Model checkpoint saved to {model_path}')\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Save the Train/Validation Split Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "\n", "# After you've created the train and validation splits:\n", "data_splits = {\n", " \"train_sequences\": train_sequences,\n", " \"val_sequences\": val_sequences,\n", " \"train_labels\": train_labels,\n", " \"val_labels\": val_labels\n", "}\n", "\n", "with open('data_splits.pkl', 'wb') as file:\n", " pickle.dump(data_splits, file)\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Reload the Data Later" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "\n", "# Load the data splits\n", "with open('data_splits.pkl', 'rb') as file:\n", " data_splits = pickle.load(file)\n", "\n", "train_sequences = data_splits[\"train_sequences\"]\n", "val_sequences = data_splits[\"val_sequences\"]\n", "train_labels = data_splits[\"train_labels\"]\n", "val_labels = data_splits[\"val_labels\"]\n", "\n", "# Now, the rest of your code can proceed as it is, \n", "# with the train and validation sets loaded from the pickle file." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Data Preprocessing" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.utils.data import DataLoader, Dataset\n", "from transformers import AutoTokenizer, EsmForSequenceClassification\n", "from accelerate import Accelerator\n", "from sklearn.model_selection import train_test_split\n", "from torchmetrics.classification import MultilabelF1Score\n", "from sklearn.metrics import accuracy_score, precision_score, recall_score, average_precision_score\n", "import datetime\n", "import pandas as pd\n", "\n", "# Load the data\n", "data = pd.read_csv(\"C:/Users/OWO/Desktop/amelie_vscode/cafa5/data/merged_protein_data.tsv\", sep=\"\\t\")\n", "# Use only the first 100 entries\n", "data = data.head(100)\n", "\n", "# Initialize the accelerator\n", "accelerator = Accelerator()\n", "device = accelerator.device\n", "\n", "# Data Preprocessing\n", "tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t6_8M_UR50D\")\n", "MAX_LENGTH = tokenizer.model_max_length\n", "\n", "class ProteinDataset(Dataset):\n", " def __init__(self, sequences, labels):\n", " self.sequences = sequences\n", " self.labels = labels\n", "\n", " def __len__(self):\n", " return len(self.sequences)\n", "\n", " def __getitem__(self, idx):\n", " sequence = self.sequences[idx]\n", " label = self.labels[idx]\n", " encoding = tokenizer(sequence, return_tensors=\"pt\", padding='max_length', truncation=True, max_length=MAX_LENGTH)\n", " return {\n", " 'input_ids': encoding['input_ids'].flatten(),\n", " 'attention_mask': encoding['attention_mask'].flatten(),\n", " 'labels': torch.tensor(label, dtype=torch.float)\n", " }\n", "\n", "def encode_labels(go_terms, unique_terms):\n", " encoded = []\n", " for terms in go_terms:\n", " encoding = [1 if term in terms else 0 for term in unique_terms]\n", " encoded.append(encoding)\n", " return encoded\n", "\n", "# train_sequences, val_sequences, train_labels, val_labels = train_test_split(data['sequence'], data['term'], test_size=0.1)\n", "\n", "# Reset the indices\n", "# train_sequences = train_sequences.reset_index(drop=True)\n", "# val_sequences = val_sequences.reset_index(drop=True)\n", "# train_labels = train_labels.reset_index(drop=True)\n", "# val_labels = val_labels.reset_index(drop=True)\n", "\n", "unique_terms = list(set(term for sublist in data['term'] for term in sublist))\n", "train_labels_encoded = encode_labels(train_labels, unique_terms)\n", "val_labels_encoded = encode_labels(val_labels, unique_terms)\n", "\n", "train_dataset = ProteinDataset(train_sequences, train_labels_encoded)\n", "val_dataset = ProteinDataset(val_sequences, val_labels_encoded)\n", "\n", "train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)\n", "val_loader = DataLoader(val_dataset, batch_size=2)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Fine-tune with LoRA" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from collections import Counter\n", "from peft import get_peft_config, get_peft_model, LoraConfig\n", "import datetime\n", "from sklearn.metrics import accuracy_score, precision_score, recall_score, hamming_loss, average_precision_score\n", "from torchmetrics.classification import MultilabelF1Score\n", "\n", "# Constants\n", "MODEL_NAME = \"facebook/esm2_t6_8M_UR50D\" # Replace with your trained model above\n", "BATCH_SIZE = 4\n", "NUM_EPOCHS = 7\n", "LR = 3e-5\n", "\n", "# Initialize model with LoRA\n", "peft_config = LoraConfig(\n", " task_type=\"SEQ_CLS\", \n", " inference_mode=False, \n", " r=16, \n", " bias=\"none\",\n", " lora_alpha=16, \n", " lora_dropout=0.1, \n", " target_modules=[\"query\", \"key\", \"value\"]\n", ")\n", "\n", "base_model = EsmForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(unique_terms), problem_type=\"multi_label_classification\")\n", "model = get_peft_model(base_model, peft_config)\n", "model = model.to(accelerator.device)\n", "\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=LR)\n", "optimizer, model = accelerator.prepare(optimizer, model)\n", "\n", "f1_metric = MultilabelF1Score(num_labels=len(unique_terms), threshold=0.5)\n", "f1_metric = f1_metric.to(device)\n", "\n", "# Compute Class Weights\n", "def compute_class_weights(terms, term_to_id):\n", " all_terms = [term for terms_list in terms for term in terms_list]\n", " term_counts = Counter(all_terms)\n", " total_terms = sum(term_counts.values())\n", " class_weights = {term: total_terms / count for term, count in term_counts.items()}\n", " weights = torch.tensor([class_weights[term] for term in term_to_id.keys()], dtype=torch.float)\n", " normalized_weights = weights / weights.sum()\n", " return normalized_weights\n", "\n", "term_to_id = {term: idx for idx, term in enumerate(unique_terms)}\n", "all_terms_combined = train_labels.tolist() + val_labels.tolist()\n", "weights = compute_class_weights(all_terms_combined, term_to_id)\n", "weights = weights.to(accelerator.device)\n", "loss_criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)\n", "\n", "# Training loop\n", "for epoch in range(NUM_EPOCHS):\n", " # Training Phase\n", " model.train()\n", " total_train_loss = 0\n", " for batch in train_loader:\n", " optimizer.zero_grad()\n", " input_ids = batch['input_ids'].to(accelerator.device)\n", " attention_mask = batch['attention_mask'].to(accelerator.device)\n", " labels = batch['labels'].to(accelerator.device)\n", "\n", " outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n", " logits = outputs.logits\n", " loss = loss_criterion(logits, labels)\n", " accelerator.backward(loss)\n", " optimizer.step()\n", "\n", " total_train_loss += loss.item()\n", "\n", " avg_train_loss = total_train_loss / len(train_loader)\n", "\n", " # Validation Phase\n", " model.eval()\n", " total_val_loss = 0\n", " predictions = []\n", " true_labels = []\n", " with torch.no_grad():\n", " for batch in val_loader:\n", " input_ids = batch['input_ids'].to(accelerator.device)\n", " attention_mask = batch['attention_mask'].to(accelerator.device)\n", " labels = batch['labels'].to(accelerator.device)\n", "\n", " outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n", " logits = outputs.logits\n", " loss = loss_criterion(logits, labels)\n", "\n", " total_val_loss += loss.item()\n", " predictions.append(torch.sigmoid(logits).detach())\n", " true_labels.append(labels.detach())\n", "\n", "\n", " avg_val_loss = total_val_loss / len(val_loader)\n", " \n", " predictions_tensor = torch.cat(predictions, dim=0).cpu().numpy()\n", " true_labels_tensor = torch.cat(true_labels, dim=0).cpu().numpy()\n", "\n", " threshold = 0.5\n", " predictions_bin = (predictions_tensor > threshold).astype(int)\n", "\n", " val_f1 = f1_metric(torch.tensor(predictions_tensor).to(device), torch.tensor(true_labels_tensor).to(device))\n", " val_accuracy = accuracy_score(true_labels_tensor.flatten(), predictions_bin.flatten())\n", " val_precision = precision_score(true_labels_tensor.flatten(), predictions_bin.flatten(), average='micro')\n", " val_recall = recall_score(true_labels_tensor.flatten(), predictions_bin.flatten(), average='micro')\n", " val_auc = average_precision_score(true_labels_tensor, predictions_tensor, average='micro')\n", "\n", " print(f\"Epoch {epoch + 1}/{NUM_EPOCHS} - Training Loss: {avg_train_loss:.4f} - Validation Loss: {avg_val_loss:.4f}\")\n", " print(f\"Validation Metrics - Accuracy: {val_accuracy:.4f} - Precision (Micro): {val_precision:.4f} - Recall (Micro): {val_recall:.4f} - AUC: {val_auc:.4f} - F1 Score: {val_f1:.4f}\")\n", "\n", " timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')\n", " # Save model and tokenizer. Note that Accelerator has a save method for models.\n", " model_path = f'./esm2_t6_8M_cafa5_lora_{timestamp}'\n", " model.save_pretrained(model_path)\n", " tokenizer.save_pretrained(model_path)\n", " model.base_model.save_pretrained(model_path)\n", " print(f'Model checkpoint saved to {model_path}')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "cafa_5", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.17" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }