wangjin2000 commited on
Commit
8aefe80
·
verified ·
1 Parent(s): dfde78e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -220,9 +220,6 @@ inputs = tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_l
220
  with torch.no_grad():
221
  logits = loaded_model(**inputs).logits
222
 
223
- # train
224
- saved_path = train_function_no_sweeps(base_model_path,train_dataset, test_dataset)
225
-
226
  # Get predictions
227
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
228
  predictions = torch.argmax(logits, dim=2)
@@ -237,7 +234,10 @@ id2label = {
237
  for token, prediction in zip(tokens, predictions[0].numpy()):
238
  if token not in ['<pad>', '<cls>', '<eos>']:
239
  print((token, id2label[prediction]))
240
-
 
 
 
241
  # debug result
242
  dubug_result = saved_path #predictions #class_weights
243
 
 
220
  with torch.no_grad():
221
  logits = loaded_model(**inputs).logits
222
 
 
 
 
223
  # Get predictions
224
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
225
  predictions = torch.argmax(logits, dim=2)
 
234
  for token, prediction in zip(tokens, predictions[0].numpy()):
235
  if token not in ['<pad>', '<cls>', '<eos>']:
236
  print((token, id2label[prediction]))
237
+
238
+ # train
239
+ saved_path = train_function_no_sweeps(base_model_path,train_dataset, test_dataset)
240
+
241
  # debug result
242
  dubug_result = saved_path #predictions #class_weights
243