wangjin2000 committed
Commit a8846d6 · verified · 1 Parent(s): 5bbc76e

Update app.py

Files changed (1): app.py (+38 -2)
app.py CHANGED
@@ -98,13 +98,49 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
         # Add other hyperparameters as needed
     }
     # The base model you will train a LoRA on top of
-    base_model_path = "facebook/esm2_t12_35M_UR50D"
+    #base_model_path = "facebook/esm2_t12_35M_UR50D"
 
     # Define labels and model
     id2label = {0: "No binding site", 1: "Binding site"}
     label2id = {v: k for k, v in id2label.items()}
     base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id)
 
+
+    # Load the data from pickle files (replace with your local paths)
+    with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
+        train_sequences = pickle.load(f)
+
+    with open("./datasets/test_sequences_chunked_by_family.pkl", "rb") as f:
+        test_sequences = pickle.load(f)
+
+    with open("./datasets/train_labels_chunked_by_family.pkl", "rb") as f:
+        train_labels = pickle.load(f)
+
+    with open("./datasets/test_labels_chunked_by_family.pkl", "rb") as f:
+        test_labels = pickle.load(f)
+
+    # Tokenization
+    tokenizer = AutoTokenizer.from_pretrained(base_model_path) #("facebook/esm2_t12_35M_UR50D")
+    max_sequence_length = 1000
+
+    train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
+    test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
+
+    # Directly truncate the entire list of labels
+    train_labels = truncate_labels(train_labels, max_sequence_length)
+    test_labels = truncate_labels(test_labels, max_sequence_length)
+
+    train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
+    test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
+
+
+    # Compute Class Weights
+    classes = [0, 1]
+    flat_train_labels = [label for sublist in train_labels for label in sublist]
+    class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
+    accelerator = Accelerator()
+    class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
+
     # Convert the model into a PeftModel
     peft_config = LoraConfig(
         task_type=TaskType.TOKEN_CLS,
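The added block calls `truncate_labels`, which is defined elsewhere in app.py and not visible in this hunk. A minimal sketch consistent with its call sites here, assuming it simply clips each per-sequence label list to the tokenizer's maximum length:

```python
# Hypothetical sketch of the truncate_labels helper invoked above; the real
# definition lives outside this diff. Assumed behavior: clip each list of
# per-residue labels so it stays aligned with the truncated input_ids.
def truncate_labels(labels, max_length):
    return [label[:max_length] for label in labels]
```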
@@ -178,6 +214,7 @@ MODEL_OPTIONS = [
     "facebook/esm2_t33_650M_UR50D",
 ] # models users can choose from
 
+'''
 # Load the data from pickle files (replace with your local paths)
 with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
     train_sequences = pickle.load(f)
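The `'''` added here, together with the one removed in the next hunk, shifts which module-level region is wrapped in a triple-quoted string; the data prep now runs inside `train_function_no_sweeps` instead. The trick relies on Python evaluating and discarding a bare string literal, so the wrapped code never executes at import time. A toy illustration of the mechanism (not the app's code):

```python
# A bare triple-quoted string at module level is an ordinary expression
# statement: Python builds the string and throws it away, so the "code"
# inside it never runs.
'''
print("this line never executes")
'''
print("only this line runs")
```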
@@ -213,7 +250,6 @@ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
 accelerator = Accelerator()
 class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
 
-'''
 # inference
 # Path to the saved LoRA model
 model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
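For intuition about the `'balanced'` class weights computed in the context lines above, here is a toy computation with illustrative numbers, not the app's data:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Toy flat label list: 8 non-binding residues (0) vs. 2 binding residues (1).
flat_train_labels = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
weights = compute_class_weight(class_weight='balanced', classes=np.array([0, 1]), y=flat_train_labels)
# 'balanced' weight per class = n_samples / (n_classes * class_count):
# class 0 -> 10 / (2 * 8) = 0.625, class 1 -> 10 / (2 * 2) = 2.5
print(weights.tolist())  # [0.625, 2.5]
```

The rarer binding-site class receives the larger weight, which is what makes a weighted loss pay more attention to it.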
 
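The tail of the diff points at a saved LoRA checkpoint for inference; the loading code itself falls outside these hunks. A minimal sketch of the usual peft loading pattern, with the paths and labels taken from the diff and everything else assumed:

```python
from peft import PeftModel
from transformers import AutoModelForTokenClassification, AutoTokenizer

base_model_path = "facebook/esm2_t12_35M_UR50D"
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"

# Rebuild the token-classification head the same way as in training.
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}
base_model = AutoModelForTokenClassification.from_pretrained(
    base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id
)

# Attach the trained LoRA adapter weights on top of the frozen base model.
model = PeftModel.from_pretrained(base_model, model_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
model.eval()
```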