wangjin2000 commited on
Commit
ae3b0b0
·
verified ·
1 Parent(s): cc251cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -295,6 +295,10 @@ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y
295
  accelerator = Accelerator()
296
  class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
297
 
 
 
 
 
298
  '''
299
  # inference
300
  # Path to the saved LoRA model
@@ -323,14 +327,13 @@ with torch.no_grad():
323
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
324
  predictions = torch.argmax(logits, dim=2)
325
 
326
- '''
327
  # Define labels
328
  id2label = {
329
  0: "No binding site",
330
  1: "Binding site"
331
  }
332
 
333
- '''
334
  # Print the predicted labels for each token
335
  for token, prediction in zip(tokens, predictions[0].numpy()):
336
  if token not in ['<pad>', '<cls>', '<eos>']:
 
295
  accelerator = Accelerator()
296
  class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
297
 
298
+ # Define labels and model
299
+ id2label = {0: "No binding site", 1: "Binding site"}
300
+ label2id = {v: k for k, v in id2label.items()}
301
+
302
  '''
303
  # inference
304
  # Path to the saved LoRA model
 
327
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
328
  predictions = torch.argmax(logits, dim=2)
329
 
330
+
331
  # Define labels
332
  id2label = {
333
  0: "No binding site",
334
  1: "Binding site"
335
  }
336
 
 
337
  # Print the predicted labels for each token
338
  for token, prediction in zip(tokens, predictions[0].numpy()):
339
  if token not in ['<pad>', '<cls>', '<eos>']: