mawairon committed on
Commit
8c5a0b0
·
verified ·
1 Parent(s): b7c4acd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -85,7 +85,7 @@ def load_model(model_name: str):
85
  return model, None
86
 
87
  else:
88
- return {"error": "Invalid model name"}
89
 
90
 
91
 
@@ -112,24 +112,29 @@ def analyze_dna(username, password, sequence, model_name):
112
  model, tokenizer = load_model(model_name)
113
 
114
  def get_logits(seq, model_name):
 
115
  if model_name == 'gena-bert':
 
116
  inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
117
  with torch.no_grad():
118
  logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
119
  return logits
120
 
121
  elif model_name == 'CNN':
122
- # Truncate sequence
123
  SEQUENCE_LENGTH = 8000
 
 
 
124
  seq = seq[:SEQUENCE_LENGTH]
125
 
126
  # Pad sequences to the desired length
127
- seq = seq.ljust(length, pad_char)[:SEQUENCE_LENGTH]
128
 
129
- # Apply one-hot encoding to the 'sequence' column
130
- input = seq.one_hot_encode()
131
  with torch.no_grad():
132
- logits = model(input)
133
  return logits
134
 
135
 
 
85
  return model, None
86
 
87
  else:
88
+ raise ValueError("Invalid model name")
89
 
90
 
91
 
 
112
  model, tokenizer = load_model(model_name)
113
 
114
  def get_logits(seq, model_name):
115
+
116
  if model_name == 'gena-bert':
117
+
118
  inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
119
  with torch.no_grad():
120
  logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
121
  return logits
122
 
123
  elif model_name == 'CNN':
124
+
125
  SEQUENCE_LENGTH = 8000
126
+ pad_char = 'N'
127
+
128
+ # Truncate sequence
129
  seq = seq[:SEQUENCE_LENGTH]
130
 
131
  # Pad sequences to the desired length
132
+ seq = seq.ljust(SEQUENCE_LENGTH, pad_char)[:SEQUENCE_LENGTH]
133
 
134
+ # Apply one-hot encoding to the sequence
135
+ input_tensor = one_hot_encode(seq).unsqueeze(0)
136
  with torch.no_grad():
137
+ logits = model(input_tensor)
138
  return logits
139
 
140