Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,17 +3,18 @@ import transformers
|
|
3 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
-
import matplotlib.pyplot as plt
|
7 |
import pandas as pd
|
|
|
|
|
|
|
8 |
|
9 |
-
|
|
|
|
|
10 |
|
11 |
class LogisticRegressionTorch(nn.Module):
|
12 |
|
13 |
-
def __init__(self,
|
14 |
-
input_dim: int,
|
15 |
-
output_dim: int):
|
16 |
-
|
17 |
super(LogisticRegressionTorch, self).__init__()
|
18 |
self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
|
19 |
self.linear = nn.Linear(input_dim, output_dim)
|
@@ -25,11 +26,7 @@ class LogisticRegressionTorch(nn.Module):
|
|
25 |
|
26 |
class BertClassifier(nn.Module):
|
27 |
|
28 |
-
def __init__(self,
|
29 |
-
bert_model: AutoModel,
|
30 |
-
classifier: LogisticRegressionTorch,
|
31 |
-
num_labels: int):
|
32 |
-
|
33 |
super(BertClassifier, self).__init__()
|
34 |
self.bert = bert_model # Assume bert_model is an instance of a pre-trained BertModel
|
35 |
self.classifier = classifier
|
@@ -61,22 +58,20 @@ class BertClassifier(nn.Module):
|
|
61 |
# Return the loss and logits
|
62 |
return loss, logits
|
63 |
|
64 |
-
|
65 |
-
|
66 |
# Load the Hugging Face model and tokenizer
|
67 |
|
68 |
metadata_features = 0
|
69 |
-
N_UNIQUE_CLASSES = 38
|
70 |
|
71 |
base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
|
72 |
tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
|
73 |
|
74 |
# Initialize the classifier
|
75 |
-
input_size = 768 + metadata_features
|
76 |
log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
|
77 |
|
78 |
# Load Weights
|
79 |
-
model_weights_path = 'gena-blastln-bs33-lr4e-05-S168.pth'
|
80 |
weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
|
81 |
|
82 |
base_model.load_state_dict(weights['model_state_dict'])
|
@@ -84,11 +79,6 @@ log_reg.load_state_dict(weights['log_reg_state_dict'])
|
|
84 |
|
85 |
# Creating Model
|
86 |
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
|
87 |
-
model.eval()
|
88 |
-
|
89 |
-
# Dictionary to decode model predictions
|
90 |
-
label_to_int = pd.read_pkl('label_to_int.pkl')
|
91 |
-
int_to_label = {v: k for k, v in label_to_int.items()}
|
92 |
|
93 |
# Define a function to process the DNA sequence
|
94 |
def analyze_dna(sequence):
|
@@ -113,20 +103,27 @@ def analyze_dna(sequence):
|
|
113 |
top_5_labels = [int_to_label[i] for i in top_5_indices]
|
114 |
|
115 |
# Prepare the output as a list of tuples (label_name, probability)
|
116 |
-
|
|
|
117 |
# Plot a horizontal bar chart of the top-5 label probabilities
|
|
|
118 |
fig, ax = plt.subplots(figsize=(10, 6))
|
119 |
ax.barh(top_5_labels, top_5_probs, color='skyblue')
|
120 |
ax.set_xlabel('Probability')
|
121 |
ax.set_title('Top 5 Most Likely Labels')
|
122 |
plt.gca().invert_yaxis() # Highest probabilities at the top
|
123 |
|
124 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
# Create a Gradio interface
|
127 |
-
demo = gr.Interface(fn=analyze_dna, inputs="text", outputs="json")
|
128 |
|
129 |
# Launch the interface
|
130 |
demo.launch()
|
131 |
-
|
132 |
-
|
|
|
3 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
|
4 |
import torch
|
5 |
import torch.nn as nn
|
|
|
6 |
import pandas as pd
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import io
|
9 |
+
import base64
|
10 |
|
11 |
+
# Assuming label_to_int is a dictionary with {label_name: label_index}
|
12 |
+
label_to_int = pd.read_pickle('label_to_int.pkl')
|
13 |
+
int_to_label = {v: k for k, v in label_to_int.items()}
|
14 |
|
15 |
class LogisticRegressionTorch(nn.Module):
|
16 |
|
17 |
+
def __init__(self, input_dim: int, output_dim: int):
|
|
|
|
|
|
|
18 |
super(LogisticRegressionTorch, self).__init__()
|
19 |
self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
|
20 |
self.linear = nn.Linear(input_dim, output_dim)
|
|
|
26 |
|
27 |
class BertClassifier(nn.Module):
|
28 |
|
29 |
+
def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
|
|
|
|
|
|
|
|
|
30 |
super(BertClassifier, self).__init__()
|
31 |
self.bert = bert_model # Assume bert_model is an instance of a pre-trained BertModel
|
32 |
self.classifier = classifier
|
|
|
58 |
# Return the loss and logits
|
59 |
return loss, logits
|
60 |
|
|
|
|
|
61 |
# Load the Hugging Face model and tokenizer
|
62 |
|
63 |
metadata_features = 0
|
64 |
+
N_UNIQUE_CLASSES = 38  # number of output classes; passed as output_dim of the classifier head
|
65 |
|
66 |
base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
|
67 |
tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
|
68 |
|
69 |
# Initialize the classifier
|
70 |
+
input_size = 768 + metadata_features # featurizer output size + metadata size
|
71 |
log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
|
72 |
|
73 |
# Load Weights
|
74 |
+
model_weights_path = 'model/gena-blastln-bs33-lr4e-05-S168.pth'
|
75 |
weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
|
76 |
|
77 |
base_model.load_state_dict(weights['model_state_dict'])
|
|
|
79 |
|
80 |
# Creating Model
|
81 |
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
# Inference only: switch to eval mode so the BatchNorm1d layer in the
# classifier head uses its stored running statistics (training-mode batch
# norm rejects a single-sample batch) and any dropout layers are disabled.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
# Define a function to process the DNA sequence
|
84 |
def analyze_dna(sequence):
|
|
|
103 |
top_5_labels = [int_to_label[i] for i in top_5_indices]
|
104 |
|
105 |
# Prepare the output as a list of tuples (label_name, probability)
|
106 |
+
result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]
|
107 |
+
|
108 |
# Plot a horizontal bar chart of the top-5 label probabilities
|
109 |
+
|
110 |
fig, ax = plt.subplots(figsize=(10, 6))
|
111 |
ax.barh(top_5_labels, top_5_probs, color='skyblue')
|
112 |
ax.set_xlabel('Probability')
|
113 |
ax.set_title('Top 5 Most Likely Labels')
|
114 |
plt.gca().invert_yaxis() # Highest probabilities at the top
|
115 |
|
116 |
+
# Save plot to a PNG image in memory
|
117 |
+
buf = io.BytesIO()
|
118 |
+
plt.savefig(buf, format='png')
|
119 |
+
buf.seek(0)
|
120 |
+
image_base64 = base64.b64encode(buf.read()).decode('utf-8')
|
121 |
+
buf.close()
|
122 |
+
|
123 |
+
return result, f'<img src="data:image/png;base64,{image_base64}" />'
|
124 |
|
125 |
# Create a Gradio interface
|
126 |
+
demo = gr.Interface(fn=analyze_dna, inputs="text", outputs=["json", "html"])
|
127 |
|
128 |
# Launch the interface
|
129 |
demo.launch()
|
|
|
|