mawairon committed on
Commit
93aa4b0
1 Parent(s): 809ae7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -11,7 +11,7 @@ import os
11
  import huggingface_hub
12
  from huggingface_hub import hf_hub_download, login
13
  import model_archs
14
- from model_archs import BertClassifier, LogisticRegressionTorch, SimpleCNN, MLP, Pool2BN
15
  import tangermeme
16
  from tangermeme.utils import one_hot_encode
17
 
@@ -144,11 +144,11 @@ def analyze_dna(username, password, sequence, model_name):
144
 
145
  model, tokenizer = load_model(model_name)
146
 
147
- def get_logits(seq, model_name):
148
 
149
  if model_name == 'GENA-Bert':
150
 
151
- inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
152
  with torch.no_grad():
153
  logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
154
  return logits
@@ -159,13 +159,15 @@ def analyze_dna(username, password, sequence, model_name):
159
  pad_char = 'N'
160
 
161
  # Truncate sequence
162
- seq = seq[:SEQUENCE_LENGTH]
163
 
164
  # Pad sequences to the desired length
165
- seq = seq.ljust(SEQUENCE_LENGTH, pad_char)[:SEQUENCE_LENGTH]
166
 
167
  # Apply one-hot encoding to the sequence
168
- input_tensor = one_hot_encode(seq).unsqueeze(0).float()
 
 
169
  with torch.no_grad():
170
  logits = model(input_tensor)
171
  return logits
 
11
  import huggingface_hub
12
  from huggingface_hub import hf_hub_download, login
13
  import model_archs
14
+ from model_archs import BertClassifier, LogisticRegressionTorch, SimpleCNN, MLP, Pool2BN, ResNet1d
15
  import tangermeme
16
  from tangermeme.utils import one_hot_encode
17
 
 
144
 
145
  model, tokenizer = load_model(model_name)
146
 
147
+ def get_logits(sequence, model_name):
148
 
149
  if model_name == 'GENA-Bert':
150
 
151
+ inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
152
  with torch.no_grad():
153
  logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
154
  return logits
 
159
  pad_char = 'N'
160
 
161
  # Truncate sequence
162
+ sequence = sequence[:SEQUENCE_LENGTH]
163
 
164
  # Pad sequences to the desired length
165
+ sequence = sequence.ljust(SEQUENCE_LENGTH, pad_char)[:SEQUENCE_LENGTH]
166
 
167
  # Apply one-hot encoding to the sequence
168
+ input_tensor = one_hot_encode(sequence).unsqueeze(0).float()
169
+ print(f'shape of input tensor{input_tensor.shape}')
170
+ model.eval()
171
  with torch.no_grad():
172
  logits = model(input_tensor)
173
  return logits