mawairon commited on
Commit
809ae7d
·
verified ·
1 Parent(s): b312d27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -146,7 +146,7 @@ def analyze_dna(username, password, sequence, model_name):
146
 
147
  def get_logits(seq, model_name):
148
 
149
- if model_name == 'gena-bert':
150
 
151
  inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
152
  with torch.no_grad():
@@ -165,10 +165,13 @@ def analyze_dna(username, password, sequence, model_name):
165
  seq = seq.ljust(SEQUENCE_LENGTH, pad_char)[:SEQUENCE_LENGTH]
166
 
167
  # Apply one-hot encoding to the sequence
168
- input_tensor = one_hot_encode(seq).unsqueeze(0)
169
  with torch.no_grad():
170
  logits = model(input_tensor)
171
  return logits
 
 
 
172
 
173
 
174
  # if (len(sequence) > 3000 and model_name == 'gena-bert') or (len(sequence) > 10000 and model_name == 'CNN'):
 
146
 
147
  def get_logits(seq, model_name):
148
 
149
+ if model_name == 'GENA-Bert':
150
 
151
  inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
152
  with torch.no_grad():
 
165
  seq = seq.ljust(SEQUENCE_LENGTH, pad_char)[:SEQUENCE_LENGTH]
166
 
167
  # Apply one-hot encoding to the sequence
168
+ input_tensor = one_hot_encode(seq).unsqueeze(0).float()
169
  with torch.no_grad():
170
  logits = model(input_tensor)
171
  return logits
172
+
173
+ else:
174
+ raise ValueError("Invalid model name")
175
 
176
 
177
  # if (len(sequence) > 3000 and model_name == 'gena-bert') or (len(sequence) > 10000 and model_name == 'CNN'):