Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -81,25 +81,44 @@ def load_model():
|
|
81 |
model, tokenizer = load_model()
|
82 |
|
83 |
def analyze_dna(username, password, sequence):
|
84 |
-
|
|
|
85 |
env_password = os.getenv('PASSWORD')
|
86 |
|
87 |
if username not in valid_usernames or password != env_password:
|
88 |
return {"error": "Invalid username or password"}, ""
|
89 |
|
90 |
try:
|
91 |
-
|
92 |
-
sequence = sequence.replace(" ", "")
|
|
|
93 |
if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
|
94 |
return {"error": "Sequence contains invalid characters"}, ""
|
95 |
|
96 |
if len(sequence) < 300:
|
97 |
return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
|
104 |
top_5_probs = [probabilities[i] for i in top_5_indices]
|
105 |
top_5_labels = [int_to_label[i] for i in top_5_indices]
|
|
|
81 |
model, tokenizer = load_model()
|
82 |
|
83 |
def analyze_dna(username, password, sequence):
|
84 |
+
|
85 |
+
valid_usernames = os.getenv('USERNAME').split(',')
|
86 |
env_password = os.getenv('PASSWORD')
|
87 |
|
88 |
if username not in valid_usernames or password != env_password:
|
89 |
return {"error": "Invalid username or password"}, ""
|
90 |
|
91 |
try:
|
92 |
+
# Remove all whitespace characters
|
93 |
+
sequence = sequence.replace(" ", "").replace("\n", "").replace("\t", "").replace("\r", "")
|
94 |
+
|
95 |
if not all(nucleotide in 'ACTGN' for nucleotide in sequence):
|
96 |
return {"error": "Sequence contains invalid characters"}, ""
|
97 |
|
98 |
if len(sequence) < 300:
|
99 |
return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""
|
100 |
|
101 |
+
def get_logits(seq):
|
102 |
+
inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
|
103 |
+
with torch.no_grad():
|
104 |
+
logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
|
105 |
+
return logits
|
106 |
+
|
107 |
+
if len(sequence) > 3000:
|
108 |
+
num_shifts = len(sequence) // 1000
|
109 |
+
logits_sum = None
|
110 |
+
for i in range(num_shifts):
|
111 |
+
shifted_sequence = sequence[i*1000:] + sequence[:i*1000]
|
112 |
+
logits = get_logits(shifted_sequence)
|
113 |
+
if logits_sum is None:
|
114 |
+
logits_sum = logits
|
115 |
+
else:
|
116 |
+
logits_sum += logits
|
117 |
+
logits_avg = logits_sum / num_shifts
|
118 |
+
else:
|
119 |
+
logits_avg = get_logits(sequence)
|
120 |
+
|
121 |
+
probabilities = torch.nn.functional.softmax(logits_avg, dim=-1).squeeze().tolist()
|
122 |
top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
|
123 |
top_5_probs = [probabilities[i] for i in top_5_indices]
|
124 |
top_5_labels = [int_to_label[i] for i in top_5_indices]
|