AmelieSchreiber committed
Commit 6b4ca98
1 Parent(s): 42f78dd

Update README.md

Files changed (1)
README.md +107 -0
README.md CHANGED

---
license: mit
---

## Metrics:

```python
Train:
({'accuracy': 0.9406146072672105,
  ...
Test:
  ...
  'auc': 0.8547531675185658,
  'mcc': 0.3930283737012391})
```

## Using the Model

Head over to [the dataset repository](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family) to download the dataset first. Once you have the pickle files downloaded locally, run the following:

```python
from datasets import Dataset
from transformers import AutoTokenizer
import pickle

# Load the ESM-2 tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Function to truncate labels
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]

# Set the maximum sequence length
max_sequence_length = 1000

# Load the data from pickle files
with open("train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
with open("test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)
with open("train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)
with open("test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# Tokenize the sequences
train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

# Truncate the labels to match the tokenized sequence lengths
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Create train and test datasets
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
```
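
As a quick, optional sanity check, you can confirm that the prepared datasets and labels line up before running evaluation; something along these lines should work:

```python
# Optional sanity check: dataset sizes and the fields of one example
print(len(train_dataset), len(test_dataset))
example = train_dataset[0]
print(example.keys())
print(len(example["input_ids"]), len(example["labels"]))
```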

Then run the following to get the train/test metrics:

```python
import numpy as np
from sklearn.metrics import (
    matthews_corrcoef,
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score
)
from peft import PeftModel
from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification, Trainer
from accelerate import Accelerator

# Instantiate the accelerator
accelerator = Accelerator()

# Define paths to the base model and the LoRA adapter
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"  # replace with the path to your own LoRA model if different

# Load the base model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)

# Load the LoRA adapter on top of the base model
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model)  # Prepare the model using the accelerator

# Define label mappings
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

# Create a data collator for token classification
data_collator = DataCollatorForTokenClassification(tokenizer)

# Define a function to compute the metrics
def compute_metrics(dataset):
    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)

    # Remove padding and special tokens
    mask = labels != -100
    true_labels = labels[mask].flatten()
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Matthews correlation coefficient

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}

# Get the metrics for the training and test datasets
train_metrics = compute_metrics(train_dataset)
test_metrics = compute_metrics(test_dataset)

train_metrics, test_metrics
```
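
The steps above only reproduce the metrics. As a rough illustration of predicting binding sites for a single new sequence, a minimal sketch along the following lines should work, reusing the `tokenizer`, `model`, `max_sequence_length`, and `id2label` objects defined above; the example sequence is arbitrary, and the printed positions include the special tokens added by the tokenizer:

```python
import torch

# Arbitrary example protein sequence (replace with a sequence of interest)
sequence = "MDRLLALVLVLVSLSLAGAKTSYIVREGDSLSSISQKFGVSV"

# Tokenize and run the fine-tuned model
model.eval()
device = next(model.parameters()).device
inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=max_sequence_length)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
    logits = model(**inputs).logits

# Convert per-token logits to binding-site labels
predicted_ids = torch.argmax(logits, dim=2)[0].tolist()
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
for token, pred in zip(tokens, predicted_ids):
    print(token, id2label[pred])
```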