import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import pandas as pd



class LogisticRegressionTorch(nn.Module):

    def __init__(self,
                 input_dim: int,
                 output_dim: int):

        super(LogisticRegressionTorch, self).__init__()
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.batch_norm(x)
        out = self.linear(x)
        return out
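
# A minimal shape sanity-check for LogisticRegressionTorch (illustrative sketch,
# not part of the original app; kept commented out so startup behaviour is unchanged).
# The head expects (batch_size, 768) features from a BERT-base encoder and returns
# raw logits of shape (batch_size, output_dim):
#
#   head = LogisticRegressionTorch(input_dim=768, output_dim=38)
#   head.eval()  # BatchNorm1d requires eval() (or batch_size > 1) for a single sample
#   print(head(torch.randn(4, 768)).shape)  # torch.Size([4, 38])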

class BertClassifier(nn.Module):

    def __init__(self,
                 bert_model: AutoModel,
                 classifier: LogisticRegressionTorch,
                 num_labels: int):

        super(BertClassifier, self).__init__()
        self.bert = bert_model  # Assume bert_model is an instance of a pre-trained BertModel
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
        # Extract outputs from the BERT model
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        
        # Take the hidden states from the last layer and extract the hidden state of the first token for each element in the batch
        pooled_output = outputs.hidden_states[-1][:, 0, :]

        assert pooled_output.shape == (input_ids.shape[0], 768), f"Expected shape ({input_ids.shape[0]}, 768), but got {pooled_output.shape}"
        # TODO: derive the expected hidden size from the encoder config instead of hardcoding 768

        # Pass the pooled output to the classifier to get the logits
        logits = self.classifier(pooled_output)

        # Compute loss if labels are provided (assuming using CrossEntropyLoss for classification)
        loss = None

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            pred = logits.view(-1, self.num_labels)
            observed = labels.view(-1)
            loss = loss_fct(pred, observed)

        # Return the loss and logits
        return loss, logits
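
# Illustrative usage sketch for BertClassifier (hedged example, not executed here;
# `tokenizer`, `model` and `N_UNIQUE_CLASSES` refer to the objects constructed
# further down in this file). When no `labels` are passed, the returned loss is None:
#
#   toks = tokenizer("ACGT" * 32, return_tensors="pt", return_token_type_ids=False)
#   with torch.no_grad():
#       loss, logits = model(input_ids=toks['input_ids'], attention_mask=toks['attention_mask'])
#   assert loss is None and logits.shape == (1, N_UNIQUE_CLASSES)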



# Load the Hugging Face model and tokenizer

metadata_features = 0
N_UNIQUE_CLASSES = 38  # number of target classes

base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)

# Initialize the classifier
input_size = 768 + metadata_features # featurizer output size + metadata size
log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

# Load Weights
model_weights_path = 'gena-blastln-bs33-lr4e-05-S168.pth'
weights = torch.load(model_weights_path, map_location=torch.device('cpu'))

base_model.load_state_dict(weights['model_state_dict'])
log_reg.load_state_dict(weights['log_reg_state_dict'])

# Creating Model
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
model.eval()

# Dictionary to decode model predictions
label_to_int = pd.read_pickle('label_to_int.pkl')
int_to_label = {v: k for k, v in label_to_int.items()}
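# Note (assumption): 'label_to_int.pkl' is expected to deserialize to a plain
# {label_name: class_index} dict, so the inversion above gives
# {class_index: label_name} for decoding model predictions.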

# Define a function to process the DNA sequence
def analyze_dna(sequence):
    # Preprocess the input sequence
    inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)

    print("Tokenization done.")
    # Get model predictions (no gradients needed at inference time)
    with torch.no_grad():
        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

    print("Forward pass done.")
    
    # Convert logits to probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()

    print("Probabilities done.")
    # Get the top 5 most likely classes
    top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
    top_5_probs = [probabilities[i] for i in top_5_indices]
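    # (Equivalent alternative, not in the original: take the top 5 directly on the
    #  tensor, e.g. `top_probs, top_idx = torch.topk(torch.softmax(logits, dim=-1).squeeze(), k=5)`.)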
    
    # Map indices to label names
    top_5_labels = [int_to_label[i] for i in top_5_indices]
    
    # Prepare the output as a list of (label_name, probability) tuples
    result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]

    # Plot a horizontal bar chart of the top 5 probabilities
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(top_5_labels, top_5_probs, color='skyblue')
    ax.set_xlabel('Probability')
    ax.set_title('Top 5 Most Likely Labels')
    ax.invert_yaxis()  # Highest probabilities at the top

    return result, fig

# Create a Gradio interface
demo = gr.Interface(fn=analyze_dna, inputs="text", outputs=["json", "plot"])

# Launch the interface
demo.launch()