wangjin2000 commited on
Commit
30609c9
·
verified ·
1 Parent(s): 1821c6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -3
app.py CHANGED
@@ -26,6 +26,8 @@ from transformers import (
26
  Trainer
27
  )
28
 
 
 
29
  from datasets import Dataset
30
  from accelerate import Accelerator
31
  # Imports specific to the custom peft lora model
@@ -105,10 +107,39 @@ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y
105
  accelerator = Accelerator()
106
  class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
107
 
108
- dubug_result = class_weights
109
- demo = gr.Blocks(title="DEMO FOR ESMBind")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  with demo:
112
- gr.Markdown("# DEMO FOR ESMBind")
113
  gr.Textbox(dubug_result)
114
  demo.launch()
 
26
  Trainer
27
  )
28
 
29
+ from peft import PeftModel
30
+
31
  from datasets import Dataset
32
  from accelerate import Accelerator
33
  # Imports specific to the custom peft lora model
 
107
  accelerator = Accelerator()
108
  class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
109
 
110
+ # inference
111
+ # Path to the saved LoRA model
112
+ model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
113
+ # ESM2 base model
114
+ base_model_path = "facebook/esm2_t12_35M_UR50D"
115
+
116
+ # Load the model
117
+ base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
118
+ loaded_model = PeftModel.from_pretrained(base_model, model_path)
119
+
120
+ # Ensure the model is in evaluation mode
121
+ loaded_model.eval()
122
+
123
+ # Protein sequence for inference
124
+ protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence
125
+
126
+ # Tokenize the sequence
127
+ inputs = tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
128
+
129
+ # Run the model
130
+ with torch.no_grad():
131
+ logits = loaded_model(**inputs).logits
132
+
133
+ # Get predictions
134
+ tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
135
+ predictions = torch.argmax(logits, dim=2)
136
+
137
+ # debug result
138
+ dubug_result = predictions #class_weights
139
+
140
+ demo = gr.Blocks(title="DEMO FOR ESM2Bind")
141
 
142
  with demo:
143
+ gr.Markdown("# DEMO FOR ESM2Bind")
144
  gr.Textbox(dubug_result)
145
  demo.launch()