Update app.py
app.py CHANGED
@@ -106,7 +106,7 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
     label2id = {v: k for k, v in id2label.items()}
     base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id)
 
-
+    '''
     # Load the data from pickle files (replace with your local paths)
     with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
         train_sequences = pickle.load(f)
@@ -119,22 +119,23 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
 
     with open("./datasets/test_labels_chunked_by_family.pkl", "rb") as f:
         test_labels = pickle.load(f)
+    '''
 
     # Tokenization
     tokenizer = AutoTokenizer.from_pretrained(base_model_path) #("facebook/esm2_t12_35M_UR50D")
-    max_sequence_length = 1000
+    #max_sequence_length = 1000
 
     train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
     test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
 
     # Directly truncate the entire list of labels
-    train_labels = truncate_labels(train_labels, max_sequence_length)
-    test_labels = truncate_labels(test_labels, max_sequence_length)
+    #train_labels = truncate_labels(train_labels, max_sequence_length)
+    #test_labels = truncate_labels(test_labels, max_sequence_length)
 
     train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
     test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
 
-
+    '''
     # Compute Class Weights
     classes = [0, 1]
     flat_train_labels = [label for sublist in train_labels for label in sublist]
@@ -142,6 +143,7 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
 
     accelerator = Accelerator()
     class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
     print(" class_weights:", class_weights)
+    '''
 
     # Convert the model into a PeftModel
     peft_config = LoraConfig(
@@ -188,7 +190,7 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
         fp16=True,
         #report_to='wandb'
         report_to=None,
-        hub_token =
+        hub_token = HF_TOKEN, #jw 20240701
     )
 
     # Initialize Trainer
@@ -211,7 +213,7 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
     return save_path
 
 # Constants & Globals
-
+HF_TOKEN = os.environ.get("HF_TOKEN")
 
 MODEL_OPTIONS = [
     "facebook/esm2_t6_8M_UR50D",
@@ -233,19 +235,19 @@ with open("./datasets/train_labels_chunked_by_family.pkl", "rb") as f:
 with open("./datasets/test_labels_chunked_by_family.pkl", "rb") as f:
     test_labels = pickle.load(f)
 
-
-tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
+## Tokenization
+#tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
 max_sequence_length = 1000
 
-train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
-test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
+#train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
+#test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
 
 # Directly truncate the entire list of labels
 train_labels = truncate_labels(train_labels, max_sequence_length)
 test_labels = truncate_labels(test_labels, max_sequence_length)
 
-train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
-test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
+#train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
+#test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
 
 
 # Compute Class Weights