Spaces:
Sleeping
Sleeping
File size: 5,639 Bytes
c1e6692 088c2ad 996a1ec 3f3c29c dc7d693 66990c3 9bf3f2b 65391f8 6dddebe dc7d693 9bf3f2b 66990c3 3f3c29c db842cf 3f3c29c 66990c3 3f3c29c 996a1ec 3f3c29c 66990c3 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b 996a1ec 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b 3f3c29c 306f08b 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b 3f3c29c 9bf3f2b db9f444 9bf3f2b 1f65033 5935bca b1eabde 5935bca c214f12 5935bca 6dddebe 356d0ee b1eabde 356d0ee fe4ceb1 dc3cae8 356d0ee fe4ceb1 356d0ee b1eabde 356d0ee 5935bca 9bf3f2b 356d0ee a948bde 356d0ee 9bf3f2b fe4ceb1 5f8dde1 6dddebe 5935bca a948bde 6dddebe 5f8dde1 c1e6692 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
import os
import huggingface_hub
from huggingface_hub import hf_hub_download, login
# Load the label -> class-index mapping saved alongside the app and invert it so
# predicted class indices can be turned back into label names.
# NOTE(review): pd.read_pickle unpickles arbitrary objects; this is only safe
# because the file ships with the app — never point it at user-supplied data.
label_to_int = pd.read_pickle('label_to_int.pkl')


def _short_label(name):
    """Return a shortened display form of *name* (presumably a country name).

    Names containing "KOREA", "KINGDOM" or "RUSSIAN" collapse to "KOREA", "UK"
    and "RUSSIA" respectively (checked in that order); anything else is
    returned unchanged.
    """
    if "KOREA" in name:
        return "KOREA"
    if "KINGDOM" in name:
        return "UK"
    if "RUSSIAN" in name:
        return "RUSSIA"
    return name


# Class index -> display label, with the shortening above applied.
int_to_label = {index: _short_label(name) for name, index in label_to_int.items()}
class LogisticRegressionTorch(nn.Module):
    """Classification head: batch-normalise the input features, then apply one linear layer.

    Returns raw logits (no softmax is applied here).

    The submodule attribute names ``batch_norm`` and ``linear`` determine the
    state-dict keys, so they must match the saved ``log_reg_state_dict``.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # Created in this order: batch norm first, then the linear projection.
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # assumes x is (batch, input_dim) — what BatchNorm1d/Linear expect here.
        return self.linear(self.batch_norm(x))
class BertClassifier(nn.Module):
    """Encoder plus classification head.

    The encoder is run with hidden states enabled. The vector at position 0 of
    the last hidden-state tensor is fed to ``classifier``, whose output (logits)
    is returned.
    """

    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
        super().__init__()
        self.bert = bert_model
        self.classifier = classifier
        # Stored for reference only; forward() does not read it.
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        encoder_output = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        last_layer = encoder_output.hidden_states[-1]
        # assumes last_layer is (batch, seq_len, hidden) — index 0 on axis 1 is
        # then the first (CLS-style) token. TODO confirm against the encoder.
        first_token = last_layer[:, 0, :]
        return self.classifier(first_token)
def load_model():
    """Build the DNA classifier and its tokenizer, loading fine-tuned weights from the Hub.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is a ``BertClassifier`` in
        eval mode with weights mapped to CPU, and ``tokenizer`` is the matching
        GENA-LM tokenizer.

    Raises:
        ValueError: if the ``HUGGINGFACE_TOKEN`` environment variable is unset.
    """
    # Check credentials before pulling the (potentially large) base model, so a
    # misconfigured deployment fails immediately rather than after a download.
    token = os.getenv('HUGGINGFACE_TOKEN')
    if token is None:
        raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

    base_model_id = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'
    metadata_features = 0  # extra input features beyond the encoder embedding (none)
    N_UNIQUE_CLASSES = 38  # must match the output size of the saved head

    base_model = AutoModel.from_pretrained(base_model_id, trust_remote_code=True, output_hidden_states=True)
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)

    # 768 is presumably the encoder's hidden size (bert-base) — must match the checkpoint.
    input_size = 768 + metadata_features
    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)

    login(token=token)
    file_path = hf_hub_download(
        repo_id="mawairon/noo_test",
        filename="gena-blastln-bs33-lr4e-05-S168.pth",
        token=token,  # `use_auth_token` is the deprecated spelling of this argument
    )
    # NOTE(review): torch.load unpickles the checkpoint; acceptable only because
    # it comes from the project's own private repo — never load untrusted files.
    weights = torch.load(file_path, map_location=torch.device('cpu'))
    base_model.load_state_dict(weights['model_state_dict'])
    log_reg.load_state_dict(weights['log_reg_state_dict'])

    model = BertClassifier(base_model, log_reg, num_labels=N_UNIQUE_CLASSES)
    model.eval()  # inference only: fixes BatchNorm statistics, disables dropout
    return model, tokenizer
# Load the classifier and tokenizer once at import time; analyze_dna reads these
# module-level names on every request. Raises ValueError at import if
# HUGGINGFACE_TOKEN is not set.
model, tokenizer = load_model()
# Input / inference parameters (same values as the original inline literals).
_MIN_SEQUENCE_LENGTH = 300   # shortest sequence accepted, in nucleotides
_MAX_TOKENS = 512            # tokenizer truncation / padding length
_ROTATION_THRESHOLD = 3000   # sequences longer than this are scored over rotations
_ROTATION_STEP = 1000        # rotation offset step, in nucleotides
_TOP_K = 5                   # number of predictions returned and plotted
_VALID_NUCLEOTIDES = frozenset('ACTGN')


def _credentials_ok(username, password) -> bool:
    """Return True if *username*/*password* match the USERNAME/PASSWORD env vars.

    USERNAME is a comma-separated allow-list. Entries are stripped and blank
    ones ignored, so an unset or empty variable rejects everyone instead of
    raising or admitting an empty username. An unset PASSWORD rejects everyone.
    """
    import hmac

    allowed = {name.strip() for name in (os.getenv('USERNAME') or '').split(',') if name.strip()}
    expected = os.getenv('PASSWORD')
    if expected is None:
        return False
    # Constant-time comparison so response timing does not leak the password.
    password_ok = hmac.compare_digest((password or '').encode('utf-8'), expected.encode('utf-8'))
    return username in allowed and password_ok


def _clean_sequence(sequence) -> str:
    """Remove every whitespace character from *sequence* and upper-case it."""
    return ''.join((sequence or '').split()).upper()


def _predict_logits(sequence: str) -> torch.Tensor:
    """Return classifier logits for *sequence* using the module-level model/tokenizer.

    Inputs are truncated to ``_MAX_TOKENS`` tokens. For sequences longer than
    ``_ROTATION_THRESHOLD``, the logits are averaged over
    ``len // _ROTATION_STEP`` circular rotations. Each truncated pass therefore
    starts at a different offset.
    """
    def logits_for(seq):
        inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=_MAX_TOKENS,
                           return_tensors="pt", return_token_type_ids=False)
        with torch.no_grad():
            return model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

    if len(sequence) <= _ROTATION_THRESHOLD:
        return logits_for(sequence)

    num_shifts = len(sequence) // _ROTATION_STEP
    logits_sum = None
    for i in range(num_shifts):
        offset = i * _ROTATION_STEP
        logits = logits_for(sequence[offset:] + sequence[:offset])
        logits_sum = logits if logits_sum is None else logits_sum + logits
    return logits_sum / num_shifts


def _top_predictions(logits, k=_TOP_K) -> list:
    """Return the *k* most probable ``(label, probability)`` pairs, highest first."""
    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
    top_indices = sorted(range(len(probabilities)), key=probabilities.__getitem__, reverse=True)[:k]
    return [(int_to_label[i], probabilities[i]) for i in top_indices]


def _bar_chart_html(labels, probs) -> str:
    """Render a horizontal bar chart of *probs* as an inline base64 PNG ``<img>`` tag.

    The chart uses a standalone ``Figure`` rather than pyplot. Nothing is
    registered in pyplot's global figure list, so nothing leaks per request,
    and concurrent requests don't share pyplot state.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    # Explicit positions: with categorical string y-values, two identical labels
    # (e.g. two classes both shortened to "KOREA") would be drawn on one bar.
    positions = list(range(len(labels)))
    ax.barh(positions, probs, color='skyblue')
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)
    ax.set_xlabel('Probability')
    ax.set_title('Top 5 Most Likely Labels')
    ax.invert_yaxis()  # most probable label at the top
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png')
        image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f'<img src="data:image/png;base64,{image_base64}" />'


def analyze_dna(username, password, sequence):
    """Authenticate the caller, then classify a DNA sequence.

    Args:
        username: must appear in the comma-separated USERNAME env var.
        password: must equal the PASSWORD env var.
        sequence: nucleotide string (A/C/T/G/N, any case). Whitespace anywhere
            is ignored. It must have at least 300 nucleotides after cleaning.

    Returns:
        tuple: ``(result, html)``. On success, ``result`` is a list of the top 5
        ``(label, probability)`` pairs and ``html`` is an ``<img>`` bar chart.
        On any failure, ``result`` is ``{"error": message}`` and ``html`` is "".
    """
    if not _credentials_ok(username, password):
        return {"error": "Invalid username or password"}, ""
    try:
        sequence = _clean_sequence(sequence)
        if not set(sequence) <= _VALID_NUCLEOTIDES:
            return {"error": "Sequence contains invalid characters"}, ""
        if len(sequence) < _MIN_SEQUENCE_LENGTH:
            return {"error": f"Sequence needs to be at least {_MIN_SEQUENCE_LENGTH} nucleotides long"}, ""

        result = _top_predictions(_predict_logits(sequence))
        labels = [label for label, _ in result]
        probs = [prob for _, prob in result]
        return result, _bar_chart_html(labels, probs)
    except Exception as e:
        # UI boundary: report the failure in the JSON output instead of failing the request.
        return {"error": str(e)}, ""
# Gradio form for analyze_dna: the three text boxes supply its three positional
# arguments in order (username, password, sequence); its two return values are
# displayed as JSON and raw HTML respectively.
_dna_form_inputs = [
    gr.Textbox(label="Username"),
    gr.Textbox(label="Password", type="password"),
    gr.Textbox(label="DNA Sequence"),
]
demo = gr.Interface(fn=analyze_dna, inputs=_dna_form_inputs, outputs=["json", "html"])

# Start serving the app.
demo.launch()
|