Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -146,7 +146,7 @@ def analyze_dna(username, password, sequence, model_name):
|
|
146 |
|
147 |
def get_logits(seq, model_name):
|
148 |
|
149 |
-
if model_name == '
|
150 |
|
151 |
inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
|
152 |
with torch.no_grad():
|
@@ -165,10 +165,13 @@ def analyze_dna(username, password, sequence, model_name):
|
|
165 |
seq = seq.ljust(SEQUENCE_LENGTH, pad_char)[:SEQUENCE_LENGTH]
|
166 |
|
167 |
# Apply one-hot encoding to the sequence
|
168 |
-
input_tensor = one_hot_encode(seq).unsqueeze(0)
|
169 |
with torch.no_grad():
|
170 |
logits = model(input_tensor)
|
171 |
return logits
|
|
|
|
|
|
|
172 |
|
173 |
|
174 |
# if (len(sequence) > 3000 and model_name == 'gena-bert') or (len(sequence) > 10000 and model_name == 'CNN'):
|
|
|
146 |
|
147 |
def get_logits(seq, model_name):
|
148 |
|
149 |
+
if model_name == 'GENA-Bert':
|
150 |
|
151 |
inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
|
152 |
with torch.no_grad():
|
|
|
165 |
seq = seq.ljust(SEQUENCE_LENGTH, pad_char)[:SEQUENCE_LENGTH]
|
166 |
|
167 |
# Apply one-hot encoding to the sequence
|
168 |
+
input_tensor = one_hot_encode(seq).unsqueeze(0).float()
|
169 |
with torch.no_grad():
|
170 |
logits = model(input_tensor)
|
171 |
return logits
|
172 |
+
|
173 |
+
else:
|
174 |
+
raise ValueError("Invalid model name")
|
175 |
|
176 |
|
177 |
# if (len(sequence) > 3000 and model_name == 'gena-bert') or (len(sequence) > 10000 and model_name == 'CNN'):
|