Synteract freezes when run in a loop after approximately 5K pairwise PPI comparisons
I am trying to run Synteract in a while loop, as shown below: I fetch two protein sequences at a time and run synteract.py on them, over and over again. I have to do this for 6 million pairwise comparisons. I activate the conda environment, open 7 separate terminal tabs with the same environment, and run 2K jobs in each. These jobs don't overwhelm GPU memory, since only 7 jobs run in parallel at any time. However, the jobs in all the terminals freeze after a total of nearly 5K jobs have finished. Can you help me understand why this happens and how I can prevent it?
Hi @rohitsatyam ,
I'm not sure of the exact issue, but the highest throughput and most reliable (no freezing, crashing, etc.) inference will be achieved by running one process on your system - otherwise the registry can get overwhelmed. The simple inference script pasted below should run faster and more reliably than your current method. Please use something similar to this so that the model is initialized once per inference run (otherwise you are severely inflating our download count). Hopefully it is helpful. Please let me know if you have any other questions.
import torch
import re
import argparse
import pandas as pd
from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Dict
from tqdm.auto import tqdm
class PairDataset(Dataset):
    """Map-style dataset of protein sequence pairs for PPI inference.

    Item ``idx`` is the pair ``(sequences_a[idx], sequences_b[idx])``.

    Args:
        sequences_a: First sequence of each pair.
        sequences_b: Second sequence of each pair; must have the same length
            as ``sequences_a``.

    Raises:
        ValueError: If ``sequences_a`` and ``sequences_b`` differ in length.
            Without this check a shorter ``sequences_b`` fails with an
            ``IndexError`` partway through inference, and a shorter
            ``sequences_a`` silently ignores the extra ``sequences_b`` entries.
    """

    def __init__(self, sequences_a: List[str], sequences_b: List[str]):
        if len(sequences_a) != len(sequences_b):
            raise ValueError(
                "sequences_a and sequences_b must have the same length, "
                f"got {len(sequences_a)} and {len(sequences_b)}"
            )
        self.sequences_a = sequences_a
        self.sequences_b = sequences_b

    def __len__(self) -> int:
        return len(self.sequences_a)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        return self.sequences_a[idx], self.sequences_b[idx]
class PairCollator:
    """Turn a list of ``(seq_a, seq_b)`` pairs into one tokenized batch.

    Each pair is rendered as ``"<a residues> [SEP] <b residues>"`` (residues
    space-separated after sanitizing), and the whole batch is tokenized in a
    single call, padded to its longest entry and truncated to ``max_length``.

    Args:
        tokenizer: Callable tokenizer invoked with ``padding='longest'``,
            ``truncation=True``, ``max_length`` and ``return_tensors='pt'``.
        max_length: Maximum number of tokens per encoded pair.
    """

    def __init__(self, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def sanitize_seq(self, seq: str) -> str:
        """Map the letters U, Z, O and B to 'X' and space-separate every character."""
        cleaned = re.sub(r'[UZOB]', 'X', seq)
        return ' '.join(cleaned)

    def __call__(self, batch: List[Tuple[str, str]]) -> Dict[str, torch.Tensor]:
        """Tokenize ``batch`` and return only its input ids and attention mask."""
        firsts, seconds = zip(*batch)
        joined = [
            f"{self.sanitize_seq(first)} [SEP] {self.sanitize_seq(second)}"
            for first, second in zip(firsts, seconds)
        ]
        encoded = self.tokenizer(
            joined,
            padding='longest',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        return {key: encoded[key] for key in ('input_ids', 'attention_mask')}
def _predict(model, data_loader, device):
    """Run ``model`` over every batch of ``data_loader`` and collect predictions.

    Returns:
        ``(probs, preds)``: two lists in data-loader order. ``probs`` holds the
        softmax probability of class 1 (interaction) for each pair; ``preds``
        holds the argmax class index.
    """
    probs, preds = [], []
    with torch.inference_mode():
        # If the data is sorted longest-first, tqdm's initial time estimate
        # will be much longer than the actual run time.
        for batch in tqdm(data_loader, total=len(data_loader), desc="Batches processed"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits.cpu()
            # 1 - this value is the no-interaction probability
            probs.extend(torch.softmax(logits, dim=1)[:, 1].tolist())
            preds.extend(torch.argmax(logits, dim=1).tolist())
    return probs, preds


def main(args):
    """Predict interaction for each (sequences_a[i], sequences_b[i]) pair and save a CSV.

    The model and tokenizer are loaded once per call, so put all pairs into a
    single run rather than invoking the script once per pair.

    Args:
        args: Namespace with ``model_path``, ``save_path``, ``batch_size`` and
            ``num_workers`` attributes.

    Output columns: ``sequence_a``, ``sequence_b``, ``probabilities``
    (probability of class 1, rounded to 5 decimals) and ``prediction``
    (argmax class index).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Loading model from {args.model_path}")
    # When using PyTorch >= 2.5.1 on a Linux machine, SDPA attention will greatly speed up inference
    model = BertForSequenceClassification.from_pretrained(args.model_path, attn_implementation="sdpa").eval().to(device)
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    print("Tokenizer loaded")

    # Load your data into two lists of sequences, where you want the PPI for
    # each pair sequences_a[i], sequences_b[i].
    # Each encoded pair is [CLS] a [SEP] b [SEP] (the middle [SEP] is inserted
    # by PairCollator), so keep len(a) + len(b) <= 1021 to stay within
    # SYNTERACT's 1024-token limit; longer pairs are silently truncated.
    # We also recommend sorting the pairs by total length in descending order,
    # as this speeds up inference by reducing padding.
    # Example:
    #     from datasets import load_dataset
    #     data = load_dataset('Synthyra/NEGATOME', split='combined')
    #     # Filter out examples that would exceed the token limit
    #     data = data.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= 1021)
    #     # Add a 'total_length' column with the summed length of SeqA and SeqB
    #     data = data.map(lambda x: {"total_length": len(x['SeqA']) + len(x['SeqB'])})
    #     # Sort by 'total_length', longest first
    #     data = data.sort("total_length", reverse=True)
    #     # Retrieve the sorted sequences
    #     sequences_a = data['SeqA']
    #     sequences_b = data['SeqB']
    print("Loading data...")
    sequences_a = []
    sequences_b = []
    if not sequences_a:
        print("Warning: no sequence pairs were loaded; the output CSV will be empty.")

    print("Creating torch dataset...")
    pair_dataset = PairDataset(sequences_a, sequences_b)
    pair_collator = PairCollator(tokenizer, max_length=1024)
    # shuffle stays at its default (False), so predictions come back in input order
    data_loader = DataLoader(pair_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=pair_collator)

    print("Starting inference...")
    probs, preds = _predict(model, data_loader, device)

    # Because the loader preserves order, row i of the results is pair i of the input.
    results_df = pd.DataFrame({
        'sequence_a': list(sequences_a),
        'sequence_b': list(sequences_b),
        'probabilities': [round(prob, 5) for prob in probs],
        'prediction': preds,
    })
    print(f"Saving results to {args.save_path}")
    results_df.to_csv(args.save_path, index=False)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Batch PPI inference with SYNTERACT (loads the model once for all pairs)."
    )
    parser.add_argument('--model_path', type=str, default='GleghornLab/SYNTERACT',
                        help='Hugging Face model id or local model directory.')
    parser.add_argument('--save_path', type=str, default='ppi_predictions.csv',
                        help='Where to write the predictions CSV.')
    parser.add_argument('--batch_size', type=int, default=2,
                        help='Sequence pairs per forward pass (must be >= 1).')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='DataLoader worker processes; 4 is usually a good value (must be >= 0).')
    args = parser.parse_args()
    # Reject bad values up front with a usage message; otherwise DataLoader
    # raises only after the model has already been downloaded and loaded.
    if args.batch_size < 1:
        parser.error('--batch_size must be >= 1')
    if args.num_workers < 0:
        parser.error('--num_workers must be >= 0')
    main(args)
The script produces a nice CSV like this:
Synthyra will soon release a new version of Synteract, which we are internally calling SynteractTurbo. Running 6 million pairwise comparisons will take quite a while with Synteract 1.0. If you would like to collaborate, we could run some inference for you with SynteractTurbo to see whether the outputs are helpful. Reach out at lhallee@udel.edu if you are interested.