mawairon commited on
Commit
513b115
·
verified ·
1 Parent(s): 9cef3cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -93,9 +93,10 @@ model = BertClassifier(base_model, log_reg, num_labels = N_UNIQUE_CLASSES)
93
  # Define a function to process the DNA sequence
94
  def analyze_dna(sequence):
95
  # Preprocess the input sequence
96
- inputs = tokenizer(sequence, return_tensors='pt')
 
97
  # Get model predictions
98
- outputs = model(**inputs)
99
 
100
  # Convert logits to probabilities
101
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze().tolist()
@@ -107,7 +108,7 @@ def analyze_dna(sequence):
107
  # Prepare the output as a list of tuples (class_index, probability)
108
  result = [(index, prob) for index, prob in zip(top_5_indices, top_5_probs)]
109
 
110
- return result
111
 
112
  # Create a Gradio interface
113
  demo = gr.Interface(fn=analyze_dna, inputs="text", outputs="json")
 
93
  # Define a function to process the DNA sequence
94
  def analyze_dna(sequence):
95
  # Preprocess the input sequence
96
+ inputs = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
97
+
98
  # Get model predictions
99
+ _, outputs = model(**inputs)
100
 
101
  # Convert logits to probabilities
102
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze().tolist()
 
108
  # Prepare the output as a list of tuples (class_index, probability)
109
  result = [(index, prob) for index, prob in zip(top_5_indices, top_5_probs)]
110
 
111
+ return probabilities
112
 
113
  # Create a Gradio interface
114
  demo = gr.Interface(fn=analyze_dna, inputs="text", outputs="json")