Update app.py
app.py CHANGED
@@ -98,13 +98,49 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
         # Add other hyperparameters as needed
     }
     # The base model you will train a LoRA on top of
-    base_model_path = "facebook/esm2_t12_35M_UR50D"
+    #base_model_path = "facebook/esm2_t12_35M_UR50D"
 
     # Define labels and model
     id2label = {0: "No binding site", 1: "Binding site"}
     label2id = {v: k for k, v in id2label.items()}
     base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id)
 
+
+    # Load the data from pickle files (replace with your local paths)
+    with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
+        train_sequences = pickle.load(f)
+
+    with open("./datasets/test_sequences_chunked_by_family.pkl", "rb") as f:
+        test_sequences = pickle.load(f)
+
+    with open("./datasets/train_labels_chunked_by_family.pkl", "rb") as f:
+        train_labels = pickle.load(f)
+
+    with open("./datasets/test_labels_chunked_by_family.pkl", "rb") as f:
+        test_labels = pickle.load(f)
+
+    # Tokenization
+    tokenizer = AutoTokenizer.from_pretrained(base_model_path) #("facebook/esm2_t12_35M_UR50D")
+    max_sequence_length = 1000
+
+    train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
+    test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
+
+    # Directly truncate the entire list of labels
+    train_labels = truncate_labels(train_labels, max_sequence_length)
+    test_labels = truncate_labels(test_labels, max_sequence_length)
+
+    train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
+    test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
+
+
+    # Compute Class Weights
+    classes = [0, 1]
+    flat_train_labels = [label for sublist in train_labels for label in sublist]
+    class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
+    accelerator = Accelerator()
+    class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
+
     # Convert the model into a PeftModel
     peft_config = LoraConfig(
         task_type=TaskType.TOKEN_CLS,
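Note: truncate_labels is called in the new code but its definition is not part of this diff. A minimal sketch consistent with the call sites above (the exact implementation in app.py may differ):

def truncate_labels(labels, max_length):
    # Clip each per-residue label list to the tokenized max length;
    # assumes `labels` is a list of lists of 0/1 ints, one per sequence.
    return [label[:max_length] for label in labels]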
@@ -178,6 +214,7 @@ MODEL_OPTIONS = [
     "facebook/esm2_t33_650M_UR50D",
 ] # models users can choose from
 
+'''
 # Load the data from pickle files (replace with your local paths)
 with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
     train_sequences = pickle.load(f)
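The new code computes class_weights and moves it to accelerator.device, but the training loop that consumes it lies outside these hunks. A common way to apply such weights with Transformers, sketched under that assumption (the WeightedTrainer name and its class_weights argument are hypothetical, not taken from app.py):

import torch
from transformers import Trainer

class WeightedTrainer(Trainer):
    # Trainer subclass that applies per-class weights to the token-classification loss.
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights  # e.g. the tensor built above

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits  # shape: (batch, seq_len, num_labels)
        loss_fct = torch.nn.CrossEntropyLoss(weight=self.class_weights, ignore_index=-100)
        # Flatten tokens so every non-ignored position contributes one weighted term
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss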
@@ -213,7 +250,6 @@ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
 accelerator = Accelerator()
 class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
 
-'''
 # inference
 # Path to the saved LoRA model
 model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
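The removed ''' used to open the triple-quoted block around the inference code; after this commit that block opens before the module-level data loading instead (see the previous hunk), so both sections are commented out together. Only the first lines of the inference block appear in the diff. For reference, a minimal sketch of loading such a LoRA checkpoint on top of its ESM-2 base with peft (the example sequence is illustrative):

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel

base_model_path = "facebook/esm2_t12_35M_UR50D"
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"

base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=2)
model = PeftModel.from_pretrained(base_model, model_path)  # attach the LoRA adapter
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # illustrative protein sequence
inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1000)
with torch.no_grad():
    logits = model(**inputs).logits
predictions = logits.argmax(dim=-1)  # per token: 0 = no binding site, 1 = binding site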
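Taken together, the commit moves data loading, tokenization, and class-weight computation inside train_function_no_sweeps and comments out the module-level copies, so the function is self-contained and driven by its base_model_path argument. An illustrative call (how the app wires this to the UI is not shown in the diff):

chosen_model = "facebook/esm2_t12_35M_UR50D"  # one of MODEL_OPTIONS, as picked by the user
train_function_no_sweeps(chosen_model)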