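# Gradio demo: classify a DNA sequence with a GENA-LM backbone plus a
# logistic-regression head, then plot the top-5 predicted labels.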
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import pandas as pd
class LogisticRegressionTorch(nn.Module):
    def __init__(self,
                 input_dim: int,
                 output_dim: int):
        super(LogisticRegressionTorch, self).__init__()
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.batch_norm(x)
        out = self.linear(x)
        return out
class BertClassifier(nn.Module):
    def __init__(self,
                 bert_model: AutoModel,
                 classifier: LogisticRegressionTorch,
                 num_labels: int):
        super(BertClassifier, self).__init__()
        self.bert = bert_model  # Assume bert_model is an instance of a pre-trained BertModel
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None,
                token_type_ids: torch.Tensor = None, labels: torch.Tensor = None):
        # Extract outputs from the BERT model
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Take the hidden states from the last layer and extract the hidden state of the first token for each element in the batch
        pooled_output = outputs.hidden_states[-1][:, 0, :]
        assert pooled_output.shape == (input_ids.shape[0], 768), f"Expected shape ({input_ids.shape[0]}, 768), but got {pooled_output.shape}"
        # to-do later!
        # Pass the pooled output to the classifier to get the logits
        logits = self.classifier(pooled_output)
        # Compute loss if labels are provided (assuming CrossEntropyLoss for classification)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            pred = logits.view(-1, self.num_labels)
            observed = labels.view(-1)
            loss = loss_fct(pred, observed)
        # Return the loss and logits
        return loss, logits
# Load the Hugging Face model and tokenizer
metadata_features = 0
N_UNIQUE_CLASSES = 38
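# GENA-LM (gena-lm-bert-base-lastln-t2t) is a BERT-base-sized DNA language model,
# so its [CLS] embedding is 768-dimensional (see input_size below).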
base_model = AutoModel.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained('AIRI-Institute/gena-lm-bert-base-lastln-t2t', trust_remote_code=True)
# Initialize the classifier
input_size = 768 + metadata_features # featurizer output size + metadata size
log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
# Load Weights
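# The checkpoint is expected to hold the backbone weights under 'model_state_dict'
# and the classifier-head weights under 'log_reg_state_dict'.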
model_weights_path = 'gena-blastln-bs33-lr4e-05-S168.pth'
weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
base_model.load_state_dict(weights['model_state_dict'])
log_reg.load_state_dict(weights['log_reg_state_dict'])
# Creating Model
model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
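# Put BatchNorm (classifier head) and dropout (BERT) layers into inference mode before serving predictions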
model.eval()
# Dictionary to decode model predictions
label_to_int = pd.read_pickle('label_to_int.pkl')
int_to_label = {v: k for k, v in label_to_int.items()}
# Define a function to process the DNA sequence
def analyze_dna(sequence):
    # Preprocess the input sequence
    inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
    print("Tokenization done.")
    # Get model predictions (no gradients needed at inference time)
    with torch.no_grad():
        _, logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    print("Forward pass done.")
    # Convert logits to probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
    print("Probabilities done.")
    # Get the top 5 most likely classes
    top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
    top_5_probs = [probabilities[i] for i in top_5_indices]
    # Map indices to label names
    top_5_labels = [int_to_label[i] for i in top_5_indices]
    # Alternative JSON-style output as a list of (label_name, probability) tuples:
    # result = [(label, prob) for label, prob in zip(top_5_labels, top_5_probs)]
    # Plot a horizontal bar chart of the top 5 labels and their probabilities
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(top_5_labels, top_5_probs, color='skyblue')
    ax.set_xlabel('Probability')
    ax.set_title('Top 5 Most Likely Labels')
    ax.invert_yaxis()  # Highest probabilities at the top
    return fig
# Create a Gradio interface that renders the returned matplotlib figure
demo = gr.Interface(fn=analyze_dna, inputs="text", outputs=gr.Plot())
# Launch the interface
demo.launch()
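# Local smoke test (hypothetical sequence, for illustration only):
#   fig = analyze_dna("ATGGCGTACGGTTAGCCATG")
#   fig.savefig("top5_labels.png")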