wangjin2000 committed on
Commit caa306e · verified · 1 Parent(s): 45ee12b

Update app.py

Files changed (1)
  1. app.py +39 -3
app.py CHANGED
@@ -81,7 +81,39 @@ class WeightedTrainer(Trainer):
         loss = compute_loss(model, inputs)
         return (loss, outputs) if return_outputs else loss
 
-#
+# Predict binding sites with the fine-tuned PEFT model
+def predict_bind(base_model_path, PEFT_model_path, input_seq):
+    # Load the base model and wrap it with the PEFT adapter
+    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
+    loaded_model = PeftModel.from_pretrained(base_model, PEFT_model_path)
+
+    # Ensure the model is in evaluation mode
+    loaded_model.eval()
+
+    # Tokenization
+    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
+
+    # Tokenize the sequence
+    inputs = tokenizer(input_seq, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
+
+    # Run the model
+    with torch.no_grad():
+        logits = loaded_model(**inputs).logits
+
+    # Get predictions
+    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
+    predictions = torch.argmax(logits, dim=2)
+
+    binding_site = []
+    # Print the predicted label for each token and collect predicted binding sites
+    for n, (token, prediction) in enumerate(zip(tokens, predictions[0].numpy())):
+        if token not in ['<pad>', '<cls>', '<eos>']:
+            print((token, id2label[prediction]))
+            if prediction == 1:
+                print((n + 1, token, id2label[prediction]))
+                binding_site.append((n + 1, token, id2label[prediction]))
+    return binding_site
+
 # fine-tuning function
 def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset):
 
@@ -102,8 +134,10 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
     #base_model_path = "facebook/esm2_t12_35M_UR50D"
 
     # Define labels and model
-    id2label = {0: "No binding site", 1: "Binding site"}
-    label2id = {v: k for k, v in id2label.items()}
+    #id2label = {0: "No binding site", 1: "Binding site"}
+    #label2id = {v: k for k, v in id2label.items()}
+
+
     base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id)
 
     '''
@@ -289,12 +323,14 @@ with torch.no_grad():
 tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
 predictions = torch.argmax(logits, dim=2)
 
+'''
 # Define labels
 id2label = {
     0: "No binding site",
     1: "Binding site"
 }
 
+'''
 # Print the predicted labels for each token
 for token, prediction in zip(tokens, predictions[0].numpy()):
     if token not in ['<pad>', '<cls>', '<eos>']:
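Note: train_function_no_sweeps still passes id2label and label2id to from_pretrained after this change, and the new predict_bind reads id2label as well, so both mappings presumably now live once at module scope. A minimal sketch of that assumption:

# Assumed module-level definitions in app.py (not shown in this diff);
# both train_function_no_sweeps and predict_bind reference these names.
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}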
 
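For reference, a minimal usage sketch of the new predict_bind helper. The base checkpoint comes from the default commented out in train_function_no_sweeps; the adapter path and protein sequence are placeholders, not values from this commit:

# Hypothetical call to predict_bind (adapter path and sequence are placeholders)
base_model_path = "facebook/esm2_t12_35M_UR50D"       # default noted in train_function_no_sweeps
peft_model_path = "path/to/finetuned-peft-adapter"    # placeholder adapter directory or Hub repo
input_seq = "MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEIS"  # placeholder protein sequence

binding_sites = predict_bind(base_model_path, peft_model_path, input_seq)
# predict_bind prints (token, label) for every residue and returns the predicted
# binding sites as (position, token, label) tuples, e.g. [(7, 'K', 'Binding site'), ...]
print(binding_sites)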