File size: 3,698 Bytes
107b8d2
 
 
 
45539e9
9b1cba9
f560239
107b8d2
 
 
 
302efca
 
 
eefdf2d
302efca
 
 
 
 
 
 
 
 
 
 
 
 
 
f560239
107b8d2
 
302efca
107b8d2
302efca
 
 
 
 
 
 
 
 
 
 
 
107b8d2
302efca
 
 
 
 
107b8d2
302efca
eefdf2d
302efca
 
f560239
9fec676
 
9b1cba9
 
 
 
 
 
 
 
 
 
9fec676
107b8d2
 
 
f560239
107b8d2
302efca
 
 
9b1cba9
302efca
9b1cba9
 
9fec676
 
 
 
9b1cba9
 
302efca
107b8d2
 
302efca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from torch.distributions.categorical import Categorical
import numpy as np
import pandas as pd

# Load the pretrained PepMLM checkpoint once, at import time. The module-level
# `tokenizer` and `model` are shared by every call to the functions below.
_CHECKPOINT = "TianlaiChen/PepMLM-650M"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(_CHECKPOINT)

def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    """Score a binder by the model's perplexity on it, conditioned on the target.

    The target and binder are concatenated and tokenized. Every binder
    position is replaced by the mask token at once, and the masked-LM loss is
    computed over those positions only. The returned score is ``exp(loss)``.

    Args:
        model: Masked language model whose forward call accepts ``labels=``
            and returns an object with a ``.loss`` attribute. Its ``.device``
            is used for all tensors.
        tokenizer: Tokenizer paired with ``model``. It must provide
            ``encode(..., return_tensors='pt')`` and ``mask_token_id``.
        protein_seq: Target protein sequence.
        binder_seq: Candidate binder sequence, appended after the target.

    Returns:
        The perplexity, or NaN when ``binder_seq`` is empty (there are no
        positions to score).
    """
    # With an empty binder no position carries a label, so the loss would be
    # NaN anyway. Skip the expensive forward pass.
    if not binder_seq:
        return float("nan")

    sequence = protein_seq + binder_seq
    tensor_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)

    # The binder occupies the last len(binder_seq) positions before the final
    # token. NOTE(review): this assumes one token per residue and exactly one
    # appended trailing special token (character-level ESM-style tokenizer).
    # Confirm this if the tokenizer changes.
    binder_mask = torch.zeros_like(tensor_input, dtype=torch.bool)
    binder_mask[0, -len(binder_seq) - 1:-1] = True

    # Hide the binder from the model, and label only the binder positions.
    # Hugging Face masked-LM losses ignore positions labelled -100.
    masked_input = tensor_input.masked_fill(binder_mask, tokenizer.mask_token_id)
    labels = tensor_input.masked_fill(~binder_mask, -100)

    with torch.no_grad():
        loss = model(masked_input, labels=labels).loss
    return np.exp(loss.item())
    

def generate_peptide(protein_seq, peptide_length, top_k, num_binders):
    """Sample peptide binders for a target protein and score each one.

    The target sequence is followed by ``peptide_length`` mask tokens. A
    single forward pass yields logits for every masked position. Each binder
    is then drawn by top-k sampling, independently at every position, and
    scored with ``compute_pseudo_perplexity``.

    Args:
        protein_seq: Target protein sequence.
        peptide_length: Number of masked positions to fill (cast to int).
        top_k: Number of highest-scoring tokens to sample from at each
            position (cast to int).
        num_binders: Number of binders to sample (cast to int).

    Returns:
        A ``(rows, csv_path)`` tuple. ``rows`` is a list of
        ``[binder, perplexity]`` pairs. ``csv_path`` is the CSV file the
        same rows were written to.
    """
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    # The model input is identical for every binder, so run the expensive
    # forward pass once and repeat only the sampling step.
    masked_peptide = tokenizer.mask_token * peptide_length
    input_sequence = protein_seq + masked_peptide
    inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits
    mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    logits_at_masks = logits[0, mask_token_indices]

    # Top-k sampling: keep the k best tokens per position and renormalise.
    top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
    probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
    sampler = Categorical(probabilities)

    binders_with_ppl = []
    for _ in range(num_binders):
        # One draw per masked position, mapped back to vocabulary ids.
        predicted_indices = sampler.sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)

        # Sampled special tokens are stripped here, so a binder may come out
        # shorter than peptide_length.
        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')

        ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
        binders_with_ppl.append([generated_binder, ppl_value])

    # Write the CSV once, after all binders are collected. This keeps
    # output_filename defined even when num_binders is 0.
    # NOTE(review): the fixed path is shared by concurrent requests; consider
    # a per-request temporary file.
    output_filename = "output.csv"
    df = pd.DataFrame(binders_with_ppl, columns=["Binder", "Perplexity"])
    df.to_csv(output_filename, index=False)

    return binders_with_ppl, output_filename


# Define the Gradio interface: a target sequence plus sampling controls in;
# a table of [binder, perplexity] rows and the matching CSV file out.
interface = gr.Interface(
    fn=generate_peptide,
    inputs=[
        gr.Textbox(label="Protein Sequence", info="Enter protein sequence here", type="text"),
        gr.Slider(3, 50, value=15, label="Peptide Length", step=1, info='Default value is 15'),
        gr.Slider(1, 10, value=3, label="Top K Value", step=1, info='Default value is 3'),
        gr.Dropdown(choices=[1, 2, 4, 8, 16, 32], label="Number of Binders", value=1),
    ],
    outputs=[
        gr.Dataframe(
            headers=["Binder", "Perplexity"],
            datatype=["str", "number"],
            col_count=(2, "fixed"),
        ),
        # gr.File replaces the deprecated gr.outputs.File namespace,
        # which was removed in Gradio 4.
        gr.File(label="Download CSV"),
    ],
    title="PepMLM: Target Sequence-Conditioned Generation of Peptide Binders via Masked Language Modeling",
)

interface.launch()