AmelieSchreiber committed on
Commit
542ec7f
1 Parent(s): 9ddf524

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +100 -1
README.md CHANGED
@@ -32,4 +32,103 @@ Micro
32
  ```
33
  Validation Precision: 0.9822020821532512
34
  Validation Recall: 0.9999363677941498
35
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ```
33
  Validation Precision: 0.9822020821532512
34
  Validation Recall: 0.9999363677941498
35
+ ```
36
+
37
+ ## Using the model
38
+
39
+ First, download the `train_sequences.fasta` file and the `train_terms.tsv` file, and provide the local paths in the code below:
40
+
41
+ ```python
42
+ import os
43
+ import numpy as np
44
+ import torch
45
+ from transformers import AutoTokenizer, EsmForSequenceClassification, AdamW
46
+ from torch.nn.functional import binary_cross_entropy_with_logits
47
+ from sklearn.model_selection import train_test_split
48
+ from sklearn.metrics import f1_score, precision_score, recall_score
49
+ # from accelerate import Accelerator
50
+ from Bio import SeqIO
51
+
52
# Step 1: Data Preprocessing (Replace with your local paths)
fasta_file = "/Users/amelieschreiber/.cursor-tutor/projects/python/cafa5/cafa-5-protein-function-prediction/Train/train_sequences.fasta"
tsv_file = "/Users/amelieschreiber/.cursor-tutor/projects/python/cafa5/cafa-5-protein-function-prediction/Train/train_terms.tsv"

# Map every FASTA record id to its sequence as a plain string.
fasta_data = {entry.id: str(entry.seq) for entry in SeqIO.parse(fasta_file, "fasta")}

# Map the first tab-separated column of every row to the list of remaining
# columns. Rows sharing the same first column overwrite each other, so only
# the last such row is kept.
# NOTE(review): no header row is skipped here -- confirm that is intended.
tsv_data = {}
with open(tsv_file, 'r') as f:
    for row in f:
        key, *rest = row.strip().split("\t")
        tsv_data[key] = rest

# tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
seq_length = 1022
# tokenized_data = tokenizer(list(fasta_data.values()), padding=True, truncation=True, return_tensors="pt", max_length=seq_length)

# De-duplicated list of every value stored in tsv_data.
# NOTE(review): the list order comes from iterating a set of strings, which is
# not guaranteed to be the same between interpreter runs, while
# predict_protein_function indexes this list by model output position --
# confirm the order matches the one used for training.
unique_terms = list({term for terms in tsv_data.values() for term in terms})
72
+ ```
73
+
74
+
75
+ Second, download the file `go-basic.obo` [from here](https://huggingface.co/datasets/AmelieSchreiber/cafa_5)
76
+ and store the file locally, then provide the local path in the code below:
77
+
78
+ ```python
79
+ import torch
80
+ from transformers import AutoTokenizer, EsmForSequenceClassification
81
+ from sklearn.metrics import precision_recall_fscore_support
82
+
83
+ # 1. Parsing the go-basic.obo file
84
def parse_obo_file(file_path):
    """Parse the ``[Term]`` stanzas of an OBO file into a list of dicts.

    The file is read line by line. A line of the form ``[Something]`` starts a
    new stanza; only ``[Term]`` stanzas are collected, so other stanzas (for
    example ``[Typedef]``) and the file header are ignored and can never
    overwrite the fields of a preceding term.

    Args:
        file_path: Path to the OBO file (read as UTF-8).

    Returns:
        One dict per ``[Term]`` stanza, in file order. Each dict holds
        whichever of the keys ``"id"``, ``"name"``, ``"namespace"`` and
        ``"definition"`` were present in the stanza. ``"definition"`` is the
        text between the first pair of double quotes on the ``def:`` line, or
        the stripped remainder of the line when it contains no double quote.
    """
    # Tag prefix -> key used in the returned dicts ("def:" is handled apart).
    simple_tags = (("id:", "id"), ("name:", "name"), ("namespace:", "namespace"))

    terms = []
    current = None  # dict of the [Term] stanza being read; None outside one

    with open(file_path, 'r', encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()

            # Stanza header: close the previous term and open a new one only
            # when the stanza really is a [Term].
            if line.startswith("[") and line.endswith("]"):
                if current is not None:
                    terms.append(current)
                current = {} if line == "[Term]" else None
                continue

            if current is None:
                continue

            if line.startswith("def:"):
                remainder = line[len("def:"):]
                pieces = remainder.split('"')
                # Without any quote there is no second piece; fall back to the
                # raw text instead of raising IndexError.
                current["definition"] = pieces[1] if len(pieces) > 1 else remainder.strip()
                continue

            for prefix, key in simple_tags:
                if line.startswith(prefix):
                    # Slice after the prefix so a value that itself contains
                    # the tag text is not truncated.
                    current[key] = line[len(prefix):].strip()
                    break

    if current is not None:
        terms.append(current)
    return terms
103
+
104
+ parsed_terms = parse_obo_file("go-basic.obo") # Replace `go-basic.obo` with your path
105
+
106
+ # 2. Load the saved model and tokenizer
107
+ model_path = "AmelieSchreiber/esm2_t6_8M_finetuned_cafa5"
108
+ loaded_model = EsmForSequenceClassification.from_pretrained(model_path)
109
+ loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)
110
+
111
+ # 3. The predict_protein_function function
112
def predict_protein_function(sequence, model, tokenizer, go_terms, threshold=0.05, term_ids=None):
    """Return the names of the GO terms predicted for a protein sequence.

    The sequence is tokenized (truncated to ``max_length=1022``), passed
    through ``model`` without gradient tracking, and every output position
    whose sigmoid score is above ``threshold`` is translated to a term id and
    then to the term's name.

    Args:
        sequence: Protein sequence passed to ``tokenizer``.
        model: Callable model exposing ``eval()`` and returning an object
            with a ``logits`` tensor indexed as (row, label).
        tokenizer: Callable returning the keyword inputs for ``model``.
        go_terms: Iterable of dicts with ``"id"`` and ``"name"`` keys, such as
            the output of ``parse_obo_file``. Dicts missing either key are
            ignored; for duplicate ids the first entry wins.
        threshold: Minimum sigmoid score for a label to count as predicted.
        term_ids: Sequence mapping a label index to a term id. Defaults to the
            module-level ``unique_terms`` list built by the preprocessing code.

    Returns:
        List of term names, in the order of the predicted label indices.
        Predicted ids that do not occur in ``go_terms`` are skipped.
    """
    if term_ids is None:
        term_ids = unique_terms  # Use the unique_terms list from your training script

    inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=1022)
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.sigmoid(outputs.logits)
        # torch.where yields (row indices, column indices); the column is the
        # label index.
        predicted_indices = torch.where(predictions > threshold)[1].tolist()

    # Build the id -> name table once instead of scanning go_terms for every
    # predicted label.
    name_by_id = {}
    for term in go_terms:
        term_id = term.get("id")
        if term_id is not None and "name" in term:
            name_by_id.setdefault(term_id, term["name"])

    functions = []
    for idx in predicted_indices:
        name = name_by_id.get(term_ids[idx])
        if name is not None:
            functions.append(name)

    return functions
129
+
130
+ # 4. Predicting protein function for an example sequence
131
+ example_sequence = "MAYLGSLVQRRLELASGDRLEASLGVGSELDVRGDRVKAVGSLDLEEGRLEQAGVSMA" # Replace with your protein sequence
132
+ predicted_functions = predict_protein_function(example_sequence, loaded_model, loaded_tokenizer, parsed_terms)
133
+ print(predicted_functions)
134
+ ```