Synteract freezes when run in a loop after approx. 5K pairwise PPI comparisons

#15
by rohitsatyam - opened

Hi @colinhorger @gleghorn

I am trying to run Synteract using a while loop, as shown below, where I fetch two protein sequences at a time and invoke synteract.py over and over again. I have to do this for 6 million pairwise comparisons. I do it by activating the conda environment, opening 7 separate tabs with the same environment, and running 2K jobs in each. These jobs don't overwhelm GPU memory, since at any time only 7 jobs are running in parallel. However, I see that the jobs in all the terminals freeze after a total of nearly 5K jobs have finished. Can you help me understand why that is and how I can prevent it?
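
In essence, each tab runs something like this (an illustrative reconstruction, since the screenshots below may not render; the actual synteract.py interface may differ):

import subprocess

# Hypothetical per-pair loop: every iteration spawns a fresh Python process,
# which reloads the model from scratch before scoring a single pair
with open('pairs.tsv') as pairs:
    for line in pairs:
        seq_a, seq_b = line.rstrip('\n').split('\t')
        subprocess.run(['python', 'synteract.py', seq_a, seq_b], check=True)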

[two screenshots attached]

Gleghorn Lab org

Hi @rohitsatyam ,

I'm not sure of the exact issue, but you'll get the highest throughput and most reliable inference (no freezing, crashing, etc.) by running a single process on your system; otherwise the registry can get overwhelmed. The simple inference script pasted below should run faster and more reliably than your current method. Please use something similar to it so that the model is initialized once per inference run (otherwise you are severely inflating our download count). Hopefully it is helpful. Please let me know if you have any other questions.

import torch
import re
import argparse
import pandas as pd
from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Dict
from tqdm.auto import tqdm


class PairDataset(Dataset):
    def __init__(self, sequences_a: List[str], sequences_b: List[str]):
        self.sequences_a = sequences_a
        self.sequences_b = sequences_b

    def __len__(self):
        return len(self.sequences_a)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        return self.sequences_a[idx], self.sequences_b[idx]
    

class PairCollator:
    def __init__(self, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def sanitize_seq(self, seq: str) -> str:
        seq = ' '.join(list(re.sub(r'[UZOB]', 'X', seq)))
        return seq
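
    # Illustrative example: sanitize_seq('MKUZB') -> 'M K X X X'
    # (ambiguous residues U, Z, O, B are replaced with X, and characters are
    # space-separated so the tokenizer sees one residue per token)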

    def __call__(self, batch: List[Tuple[str, str]]) -> Dict[str, torch.Tensor]:
        seqs_a, seqs_b = zip(*batch)
        seqs = []
        for a, b in zip(seqs_a, seqs_b):
            seq = self.sanitize_seq(a) + ' [SEP] ' + self.sanitize_seq(b)
            seqs.append(seq)
        seqs = self.tokenizer(seqs, padding='longest', truncation=True, max_length=self.max_length, return_tensors='pt')
        return {
            'input_ids': seqs['input_ids'],
            'attention_mask': seqs['attention_mask'],
        }


def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Loading model from {args.model_path}")
    model = BertForSequenceClassification.from_pretrained(args.model_path, attn_implementation="sdpa").eval().to(device)
    # When using PyTorch >= 2.5.1 on a Linux machine, SDPA attention will greatly speed up inference
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    print(f"Tokenizer loaded")

    """
    Load your data into two lists of sequences, where you want the PPI for each pair sequences_a[i], sequences_b[i]
    We recommend trimmed sequence pairs that sum over 1022 tokens (for the 1024 max length limit of SYNTERACT)
    We also recommend sorting the sequences by length in descending order, as this will speed up inference by reducing padding

    Example:
        from datasets import load_dataset
        data = load_dataset('Synthyra/NEGATOME', split='combined')
        # Filter out examples where the total length exceeds 1022
        data = data.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= 1022)
        # Add a new column 'total_length' that is the sum of lengths of SeqA and SeqB
        data = data.map(lambda x: {"total_length": len(x['SeqA']) + len(x['SeqB'])})
        # Sort the dataset by 'total_length' in descending order (longest sequences first)
        data = data.sort("total_length", reverse=True)
        # Now retrieve the sorted sequences
        sequences_a = data['SeqA']
        sequences_b = data['SeqB']
    """
    print("Loading data...")
    sequences_a = []
    sequences_b = []

    print("Creating torch dataset...")
    pair_dataset = PairDataset(sequences_a, sequences_b)
    pair_collator = PairCollator(tokenizer, max_length=1024)
    data_loader = DataLoader(pair_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=pair_collator)

    all_seqs_a = []
    all_seqs_b = []
    all_probs = []
    all_preds = []

    print("Starting inference...")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader, total=len(data_loader), desc="Batches processed")):
            # Because sequences are sorted, the initial estimate for time will be much longer than the actual time it will take
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits.detach().cpu()

            prob_of_interaction = torch.softmax(logits, dim=1)[:, 1]  # probability of interaction; 1 minus this gives the no-interaction probability
            pred = torch.argmax(logits, dim=1)

            # Store results
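            # (relies on the DataLoader's default shuffle=False, so batch order matches the input lists)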
            batch_start = i * args.batch_size
            batch_end = min((i + 1) * args.batch_size, len(sequences_a))
            all_seqs_a.extend(sequences_a[batch_start:batch_end])
            all_seqs_b.extend(sequences_b[batch_start:batch_end])
            all_probs.extend(prob_of_interaction.tolist())
            all_preds.extend(pred.tolist())

    # round to 5 decimal places
    all_probs = [round(prob, 5) for prob in all_probs]

    # Create dataframe and save to CSV
    results_df = pd.DataFrame({
        'sequence_a': all_seqs_a,
        'sequence_b': all_seqs_b,
        'probabilities': all_probs,
        'prediction': all_preds
    })
    print(f"Saving results to {args.save_path}")
    results_df.to_csv(args.save_path, index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='GleghornLab/SYNTERACT')
    parser.add_argument('--save_path', type=str, default='ppi_predictions.csv')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--num_workers', type=int, default=0)  # increase to use multiprocessing in the dataloader; 4 is usually a good value
    args = parser.parse_args()

    main(args)
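
Saved as, say, synteract_inference.py (the filename is arbitrary), the script can then be invoked once for the whole run:

python synteract_inference.py --model_path GleghornLab/SYNTERACT --save_path ppi_predictions.csv --batch_size 2 --num_workers 4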

The script produces a nice CSV like this:

[image: example output CSV with columns sequence_a, sequence_b, probabilities, prediction]

Gleghorn Lab org

Synthyra will have a version of Synteract coming out soon; internally we are calling it SynteractTurbo. 6 million pairwise comparisons will take quite a while with Synteract 1.0. If you would like to collaborate, we could run some inference for you with SynteractTurbo to see if the outputs are helpful. Reach out at lhallee@udel.edu if you are interested.
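
For rough scale (illustrative numbers only): at, say, 10 pairs per second on a single GPU, 6 million pairs would take 6,000,000 / 10 = 600,000 seconds, i.e. about a week of continuous compute.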
