Update app.py
app.py
CHANGED
@@ -81,7 +81,39 @@ class WeightedTrainer(Trainer):
         loss = compute_loss(model, inputs)
         return (loss, outputs) if return_outputs else loss

-#
+# Predict binding site with finetuned PEFT model
+def predict_bind(base_model_path,PEFT_model_path,input_seq):
+    # Load the model
+    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
+    loaded_model = PeftModel.from_pretrained(base_model, PEFT_model_path)
+
+    # Ensure the model is in evaluation mode
+    loaded_model.eval()
+
+    # Tokenization
+    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
+
+    # Tokenize the sequence
+    inputs = tokenizer(input_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
+
+    # Run the model
+    with torch.no_grad():
+        logits = loaded_model(**inputs).logits
+
+    # Get predictions
+    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
+    predictions = torch.argmax(logits, dim=2)
+
+    binding_site=[]
+    # Print the predicted labels for each token
+    for n, token, prediction in enumerate(zip(tokens, predictions[0].numpy())):
+        if token not in ['<pad>', '<cls>', '<eos>']:
+            print((token, id2label[prediction]))
+            if prediction == 1:
+                print((n+1,token, id2label[prediction]))
+                binding_site.append(n+1,token, id2label[prediction])
+    return binding_site
+
 # fine-tuning function
 def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset):

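Note: as committed, predict_bind will not run. enumerate(zip(tokens, predictions[0].numpy())) yields (index, (token, prediction)) pairs, so the three-name unpacking raises a ValueError; the tokenizer call references an undefined input_sequence instead of the input_seq argument; id2label is not defined inside the function; and list.append takes a single object, not three arguments. A minimal corrected sketch, assuming a module-level id2label mapping and the same imports app.py already uses:

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel

# Assumed module-level label mapping (see the note after the next hunk).
id2label = {0: "No binding site", 1: "Binding site"}

def predict_bind(base_model_path, PEFT_model_path, input_seq):
    # Load the base model and attach the fine-tuned PEFT adapter
    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
    loaded_model = PeftModel.from_pretrained(base_model, PEFT_model_path)
    loaded_model.eval()  # inference mode

    # Tokenize the input sequence (use the input_seq argument)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    inputs = tokenizer(input_seq, return_tensors="pt", truncation=True,
                       max_length=1024, padding="max_length")

    # Forward pass without gradients
    with torch.no_grad():
        logits = loaded_model(**inputs).logits

    # Per-token predictions
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    predictions = torch.argmax(logits, dim=2)

    binding_site = []
    # enumerate() yields (index, item); unpack the (token, prediction) pair
    for n, (token, prediction) in enumerate(zip(tokens, predictions[0].numpy())):
        if token not in ["<pad>", "<cls>", "<eos>"]:
            if prediction == 1:
                # append a single tuple per predicted binding residue;
                # n is the index in the token sequence (position 0 is <cls>), as in the committed code
                binding_site.append((n + 1, token, id2label[int(prediction)]))
    return binding_site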
@@ -102,8 +134,10 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
     #base_model_path = "facebook/esm2_t12_35M_UR50D"

     # Define labels and model
-    id2label = {0: "No binding site", 1: "Binding site"}
-    label2id = {v: k for k, v in id2label.items()}
+    #id2label = {0: "No binding site", 1: "Binding site"}
+    #label2id = {v: k for k, v in id2label.items()}
+
+
     base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id)

     '''
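Commenting out the local label maps in train_function_no_sweeps only works if id2label and label2id still exist at module scope: the from_pretrained call kept in this hunk, the new predict_bind above, and the demo loop in the last hunk below all keep referencing them as globals. A sketch of the assumed module-level definitions (they are not shown in this diff):

# Assumed module-level label maps; not part of this commit's visible hunks.
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}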
@@ -289,12 +323,14 @@ with torch.no_grad():
 tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
 predictions = torch.argmax(logits, dim=2)

+'''
 # Define labels
 id2label = {
     0: "No binding site",
     1: "Binding site"
 }

+'''
 # Print the predicted labels for each token
 for token, prediction in zip(tokens, predictions[0].numpy()):
     if token not in ['<pad>', '<cls>', '<eos>']:
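For reference, a hypothetical call to the new function, assuming the corrected sketch above; the base checkpoint matches the one commented in app.py, while the adapter path and sequence are placeholders:

# Placeholder adapter path and protein sequence, for illustration only.
sites = predict_bind(
    "facebook/esm2_t12_35M_UR50D",        # base ESM-2 checkpoint referenced in app.py
    "path/to/peft_adapter",               # placeholder PEFT adapter directory
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # placeholder sequence
)
print(sites)  # [(position, token, label), ...] for predicted binding residues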