Aarya Venkat commited on
Commit
d1ca73b
1 Parent(s): bc978c9

Update -- need to add new model

Browse files
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime as a parent image
2
+ FROM python:3.9
3
+
4
+ # Set the working directory in the container
5
+ WORKDIR /usr/src/app
6
+
7
+ # Install any needed packages specified in requirements.txt
8
+ COPY requirements.txt ./
9
+ RUN pip install --no-cache-dir -r requirements.txt
10
+
11
+ # Install BLAST
12
+ RUN apt-get update && apt-get install -y ncbi-blast+
13
+
14
+ # Copy the current directory contents into the container at /usr/src/app
15
+ COPY . .
16
+
17
+ # Set up a new user named "user" with user ID 1000
18
+ RUN useradd -m -u 1000 user
19
+ # Switch to the "user" user
20
+ USER user
21
+ # Set home to the user's home directory
22
+ ENV HOME=/home/user \\
23
+ PATH=/home/user/.local/bin:$PATH
24
+
25
+ # Define environment variable
26
+ ENV NAME Glydentify
27
+
28
+ # Run app.py when the container launches
29
+ CMD ["python", "app.py", "--host", "0.0.0.0", "--port", "7860"]
30
+
app.py CHANGED
@@ -8,6 +8,8 @@ from tqdm import tqdm
8
  import numpy as np
9
  import seaborn as sns
10
  from sklearn.model_selection import train_test_split
 
 
11
  import matplotlib.pyplot as plt
12
  import pickle
13
  import torch.nn.functional as F
@@ -16,164 +18,196 @@ import io
16
  from PIL import Image
17
  import Bio
18
  from Bio import SeqIO
 
 
19
  import zipfile
20
  import os
21
 
22
- # Load the model from the file
23
- with open('family_labels.pkl', 'rb') as filefam:
24
- yfam = pickle.load(filefam)
25
-
26
- tokenizerfam = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D") #facebook/esm2_t33_650M_UR50D
27
-
28
- device = 'cpu'
29
- device
30
-
31
- modelfam = EsmForSequenceClassification.from_pretrained("facebook/esm2_t12_35M_UR50D", num_labels=len(yfam.classes_))
32
- modelfam = modelfam.to('cpu')
33
-
34
- modelfam.load_state_dict(torch.load("family.pth", map_location=torch.device('cpu')))
35
- modelfam.eval()
36
-
37
- x_testfam = ["""MAEVLRTLAGKPKCHALRPMILFLIMLVLVLFGYGVLSPRSLMPGSLERGFCMAVREPDH
38
- LQRVSLPRMVYPQPKVLTPCRKDVLVVTPWLAPIVWEGTFNIDILNEQFRLQNTTIGLTV
39
- FAIKKYVAFLKLFLETAEKHFMVGHRVHYYVFTDQPAAVPRVTLGTGRQLSVLEVRAYKR
40
- WQDVSMRRMEMISDFCERRFLSEVDYLVCVDVDMEFRDHVGVEILTPLFGTLHPGFYGSS
41
- REAFTYERRPQSQAYIPKDEGDFYYLGGFFGGSVQEVQRLTRACHQAMMVDQANGIEAVW
42
- HDESHLNKYLLRHKPTKVLSPEYLWDQQLLGWPAVLRKLRFTAVPKNHQAVRNP
43
- """]
44
-
45
- encoded_inputfam = tokenizerfam(x_testfam, padding=True, truncation=True, max_length=512, return_tensors="pt")
46
- input_idsfam = encoded_inputfam["input_ids"]
47
- attention_maskfam = encoded_inputfam["attention_mask"]
48
-
49
- with torch.no_grad():
50
- outputfam = modelfam(input_idsfam, attention_mask=attention_maskfam)
51
- logitsfam = outputfam.logits
52
- probabilitiesfam = F.softmax(logitsfam, dim=1)
53
- _, predicted_labelsfam = torch.max(logitsfam, dim=1)
54
- probabilitiesfam[0]
55
-
56
- decoded_labelsfam = yfam.inverse_transform(predicted_labelsfam.tolist())
57
- decoded_labelsfam
58
-
59
-
60
-
61
- #Load donor model from file
62
- tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
63
-
64
- with open('donor_labels.pkl', 'rb') as file:
65
- label_encoder = pickle.load(file)
66
-
67
- # encoded_labels = label_encoder.fit(y)
68
- # labels = torch.tensor(encoded_labels)
69
-
70
- model = EsmForSequenceClassification.from_pretrained("facebook/esm2_t12_35M_UR50D", num_labels=len(label_encoder.classes_))
71
- model = model.to('cpu')
72
 
73
- model.load_state_dict(torch.load("best_model_35M_t12_5v5.pth", map_location=torch.device('cpu'))) #model_best_35v2M.pth
74
- model.eval()
75
 
76
- x_test = ["""MAEVLRTLAGKPKCHALRPMILFLIMLVLVLFGYGVLSPRSLMPGSLERGFCMAVREPDH
77
- LQRVSLPRMVYPQPKVLTPCRKDVLVVTPWLAPIVWEGTFNIDILNEQFRLQNTTIGLTV
78
- FAIKKYVAFLKLFLETAEKHFMVGHRVHYYVFTDQPAAVPRVTLGTGRQLSVLEVRAYKR
79
- WQDVSMRRMEMISDFCERRFLSEVDYLVCVDVDMEFRDHVGVEILTPLFGTLHPGFYGSS
80
- REAFTYERRPQSQAYIPKDEGDFYYLGGFFGGSVQEVQRLTRACHQAMMVDQANGIEAVW
81
- HDESHLNKYLLRHKPTKVLSPEYLWDQQLLGWPAVLRKLRFTAVPKNHQAVRNP
82
- """]
 
 
 
 
 
83
 
84
- encoded_input = tokenizer(x_test, padding=True, truncation=True, max_length=512, return_tensors="pt")
85
- input_ids = encoded_input["input_ids"]
86
- attention_mask = encoded_input["attention_mask"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- with torch.no_grad():
89
- output = model(input_ids, attention_mask=attention_mask)
90
- logits = output.logits
91
- probabilities = F.softmax(logits, dim=1)
92
- _, predicted_labels = torch.max(logits, dim=1)
93
- probabilities[0]
94
 
95
- decoded_labels = label_encoder.inverse_transform(predicted_labels.tolist())
96
- decoded_labels
 
 
 
 
 
 
 
 
 
97
 
 
98
 
99
  glycosyltransferase_db = {
100
- "GT31-chsy" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
101
- "GT2-CesA2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT2.html' },
102
- "GT43-arath" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
103
- "GT8-Met1" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT8.html' },
104
- "GT32-higher" : {'CAZy Name': 'GT32', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT32.html'},
105
  "GT40" : {'CAZy Name': 'GT40', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT40.html'},
106
  "GT16" : {'CAZy Name': 'GT16', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT16.html'},
107
  "GT27" : {'CAZy Name': 'GT27', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT27.html'},
108
  "GT55" : {'CAZy Name': 'GT55', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT55.html'},
109
- "GT8-Glycogenin" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT8.html' },
110
- "GT8-1" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT8.html' },
111
  "GT25" : {'CAZy Name': 'GT25', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT25.html'},
112
- "GT2-DPM_like" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT2.html' },
113
- "GT31-fringe" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
114
- "GT2-Bact_puta" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT2.html' },
115
  "GT84" : {'CAZy Name': 'GT84', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT84.html'},
116
  "GT13" : {'CAZy Name': 'GT13', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT13.html'},
117
- "GT43-cele" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
118
- "GT2-Bact_LPS1" : {'CAZy Name': 'GT92', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT2.html' },
119
- "GT2-Bact_Oant" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
120
  "GT67" : {'CAZy Name': 'GT67', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT67.html'},
121
- "GT2-HAS" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT2.html' },
122
  "GT82" : {'CAZy Name': 'GT82', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT82.html'},
123
  "GT24" : {'CAZy Name': 'GT24', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT24.html'},
124
- "GT31-plant" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
125
- "GT81-Bact" : {'CAZy Name': 'GT81', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT81.html'},
126
- "GT2-Bact_gt25Me": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
127
- "GT2-B3GntL" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '4 ', 'More Info': 'http://www.cazy.org/GT2.html' },
128
  "GT49" : {'CAZy Name': 'GT49', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT49.html'},
129
  "GT34" : {'CAZy Name': 'GT34', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT34.html'},
130
  "GT45" : {'CAZy Name': 'GT45', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT45.html'},
131
- "GT32-lower" : {'CAZy Name': 'GT32', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT32.html'},
132
  "GT88" : {'CAZy Name': 'GT88', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT88.html'},
133
  "GT21" : {'CAZy Name': 'GT21', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT21.html'},
134
- "GT2-DPG_synt" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT2.html' },
135
- "GT43-b3gat2" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
136
- "GT2-Chitin_synt": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT2.html' },
137
- "GT8-Bact" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT8.html' },
138
- "GT8-Met2" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT8.html' },
139
- "GT2-Bact_Chlor1": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
140
  "GT54" : {'CAZy Name': 'GT54', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT54.html'},
141
- "GT2-Cel_bre3" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT2.html' },
142
- "GT2-Bact_Rham" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT2.html' },
143
  "GT6" : {'CAZy Name': 'GT6 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT6.html' },
144
- "GT2-Bact_puta2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
145
- "GT7-1" : {'CAZy Name': 'GT7 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT7.html' },
146
- "GT2-Csl" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '4 ', 'More Info': 'http://www.cazy.org/GT2.html' },
147
- "GT2-ExoU" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
148
- "GT2-Csl2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '4 ', 'More Info': 'http://www.cazy.org/GT2.html' },
149
  "GT64" : {'CAZy Name': 'GT64', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT64.html'},
150
- "GT2-Bact_Chlor2": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
151
  "GT78" : {'CAZy Name': 'GT78', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT78.html'},
152
  "GT12" : {'CAZy Name': 'GT12', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT12.html'},
153
- "GT31-gnt" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
154
- "GT2-Bact_CHS" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT2.html' },
155
  "GT62" : {'CAZy Name': 'GT62', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '3 ', 'More Info': 'http://www.cazy.org/GT62.html'},
156
- "GT8-Met_Pla" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT8.html' },
157
  "GT15" : {'CAZy Name': 'GT15', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT15.html'},
158
- "GT43-b3gat1" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
159
- "GT31-b3glt" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
160
- "GT2-CesA1" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT2.html' },
161
  "GT60" : {'CAZy Name': 'GT60', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT60.html'},
162
  "GT14" : {'CAZy Name': 'GT14', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT14.html'},
163
- "GT2-Bact_DPM_sy": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT2.html' },
164
  "GT17" : {'CAZy Name': 'GT17', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT17.html'},
165
- "GT2-Bact_LPS2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '3 ', 'More Info': 'http://www.cazy.org/GT2.html' },
166
  "GT77" : {'CAZy Name': 'GT77', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT77.html'},
167
- "GT2-Bact_EpsO" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
168
- "GT43-b3gat3" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
169
- "GT8-Fun" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT8.html' },
170
  "GT75" : {'CAZy Name': 'GT75', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT75.html'},
171
- "GT2-Bact_GlfT" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT2.html' },
172
 
173
  }
174
 
175
-
176
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
 
179
  def get_family_info(family_name):
@@ -201,26 +235,46 @@ def fig_to_img(fig):
201
 
202
  def preprocess_protein_sequence(protein_fasta):
203
  lines = protein_fasta.split('\n')
204
-
205
  headers = [line for line in lines if line.startswith('>')]
206
  if len(headers) > 1:
207
- return None, "Multiple fasta sequences detected. Please upload a fasta file with only one sequence."
208
 
209
  protein_sequence = ''.join(line for line in lines if not line.startswith('>'))
 
210
 
211
- # Check for invalid characters
212
- valid_characters = set("ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwy") # the 20 standard amino acids
213
- if not set(protein_sequence).issubset(valid_characters):
214
- return None, "Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids. Does your sequence contain gaps?"
 
 
 
 
215
 
216
- return protein_sequence, None
 
217
 
 
 
 
 
218
 
219
- def process_family_sequence(protein_fasta):
220
- protein_sequence, error_msg = preprocess_protein_sequence(protein_fasta)
221
- if error_msg:
222
- return None, None, None, error_msg
 
 
 
 
 
 
223
 
 
 
 
 
 
224
  encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
225
  input_idsfam = encoded_input["input_ids"]
226
  attention_maskfam = encoded_input["attention_mask"]
@@ -231,28 +285,26 @@ def process_family_sequence(protein_fasta):
231
  probabilitiesfam = F.softmax(logitsfam, dim=1)
232
  _, predicted_labelsfam = torch.max(logitsfam, dim=1)
233
 
234
- decoded_labelsfam = yfam.inverse_transform(predicted_labelsfam.tolist())
235
- family_info = get_family_info(decoded_labelsfam[0])
236
-
237
- figfam = plt.figure(figsize=(10, 5))
238
- labelsfam = yfam.classes_
239
- probabilitiesfam = probabilitiesfam.tolist()
240
 
241
- # Convert the nested list to a flat list of probabilities
242
- probabilitiesfam_flat = probabilitiesfam[0] if probabilitiesfam else []
243
 
244
- # Sort labels and probabilities by probability
245
- labels_probsfam = list(zip(labelsfam, probabilitiesfam_flat))
246
- labels_probsfam.sort(key=lambda x: x[1], reverse=True)
247
 
248
- # Select the top 5 fams
249
- labels_probs_top5fam = labels_probsfam[:5]
250
- labels_top5, probabilities_top5 = zip(*labels_probs_top5fam)
 
251
 
252
- y_posfam = np.arange(len(labels_top5))
 
 
253
 
254
- plt.barh(y_posfam, [prob*100 for prob in probabilities_top5], align='center', alpha=0.5)
255
- plt.yticks(y_posfam, labels_top5)
 
256
  plt.xlabel('Probability (%)')
257
  plt.title('Top 5 Family Class Probabilities')
258
  plt.xlim(0, 100)
@@ -261,171 +313,96 @@ def process_family_sequence(protein_fasta):
261
  img = fig_to_img(figfam)
262
 
263
  if len(protein_sequence) < 100:
264
- return decoded_labelsfam[0], img, None, f"**Warning:** The sequence is relatively short. Fragmentary and partial sequences may result in incorrect predictions. \n\n {family_info}"
265
-
266
 
267
- return decoded_labelsfam[0], img, None, family_info
268
 
269
 
270
- def process_single_sequence(protein_fasta): #, protein_file
271
- protein_sequence, error_msg = preprocess_protein_sequence(protein_fasta)
272
- if error_msg:
273
- return None, None, None, error_msg
274
-
275
  encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
276
- input_ids = encoded_input["input_ids"]
277
- attention_mask = encoded_input["attention_mask"]
278
 
279
  with torch.no_grad():
280
- output = model(input_ids, attention_mask=attention_mask)
281
- logits = output.logits
282
- dprobabilities = F.softmax(logits, dim=1)[0]
283
- _, predicted_labels = torch.max(logits, dim=1)
284
 
285
- decoded_labels = label_encoder.inverse_transform(predicted_labels.tolist())
286
- family_info = get_family_info(decoded_labels[0])
287
 
288
- fig = plt.figure(figsize=(10, 5))
289
- labels = label_encoder.classes_
290
- dprobabilities = dprobabilities.tolist()
291
 
292
- # Sort labels and probabilities by probability
293
- labels_probs = list(zip(labels, dprobabilities))
294
- labels_probs.sort(key=lambda x: x[1], reverse=True)
 
295
 
296
- # Select the top 3 donors
297
- labels_probs_top3 = labels_probs[:3]
298
- labels_top3, probabilities_top3 = zip(*labels_probs_top3)
299
-
300
- y_pos = np.arange(len(labels_top3))
301
-
302
- plt.barh(y_pos, [prob*100 for prob in probabilities_top3], align='center', alpha=0.5)
303
- plt.yticks(y_pos, labels_top3)
304
  plt.xlabel('Probability (%)')
305
  plt.title('Top 3 Donor Class Probabilities')
306
- plt.xlim(0, 100)
307
- plt.close(fig)
308
 
309
- img = fig_to_img(fig)
310
 
311
  if len(protein_sequence) < 100:
312
- return decoded_labels[0], img, None, f"**Warning:** The sequence is relatively short. Fragmentary and partial sequences may result in incorrect predictions. \n\n {family_info}"
313
-
314
-
315
- return decoded_labels[0], img, None, None
316
-
317
- def process_sequence_file(protein_file): # added progress parameter that is displayed in gradio #, progress=gr.Progress()
318
- try:
319
- records = list(SeqIO.parse(protein_file.name, "fasta"))
320
- except Exception as e:
321
- return str(e)
322
-
323
- if not os.path.exists('results'):
324
- os.makedirs('results')
325
-
326
- total = len(records)
327
-
328
- for idx, record in enumerate(records):
329
- protein_sequence = str(record.seq)
330
-
331
- valid_characters = set("ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwy")
332
- if not set(protein_sequence).issubset(valid_characters):
333
- with open(f'results/result_{idx+1}.txt', 'w') as file:
334
- file.write("Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids. Does your sequence contain gaps?")
335
- continue
336
-
337
- label, img, _, info = process_single_sequence(protein_sequence)
338
- img.save(f'results/result_{idx+1}.png')
339
- with open(f'results/result_{idx+1}.txt', 'w') as file:
340
- file.write(f'Predicted Donor: {label}\n\n{info}')
341
-
342
- # progress(idx/total) # Update the progress bar
343
-
344
- # Create a zip file w/ results -- To Do: Figure out how to improve compression for large files
345
- with zipfile.ZipFile('predicted_results.zip', 'w', zipfile.ZIP_DEFLATED) as zipf:
346
- for root, dirs, files in os.walk('results/'):
347
- for file in files:
348
- zipf.write(os.path.join(root, file))
349
 
350
- return 'predicted_results.zip' #Provide indication of how to interpret downloaded zip file? f"**Warning:** The sequence is relatively short. Fragmentary and partial sequences may result in incorrect predictions.
351
 
352
- # Function to mask a residue at a particular position
353
- def mask_residue(sequence, position):
354
- return sequence[:position] + 'X' + sequence[position+1:]
355
-
356
- def generate_heatmap(protein_fasta):
357
- protein_sequence, error_msg = preprocess_protein_sequence(protein_fasta)
358
-
359
- # Tokenize and predict for original sequence
360
- encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
361
- with torch.no_grad():
362
- original_output = model(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"])
363
- original_probabilities = F.softmax(original_output.logits, dim=1).cpu().numpy()[0]
364
-
365
- # Define the size of each group
366
- group_size = 10 # allow user to change this
367
-
368
- # Calculate the number of groups
369
- num_groups = len(protein_sequence) // group_size + (len(protein_sequence) % group_size > 0)
370
-
371
- # Initialize an array to hold the importance scores
372
- importance_scores = np.zeros((num_groups, len(original_probabilities)))
373
-
374
- # Initialize tqdm progress bar
375
- # with tqdm(total=num_groups, desc="Processing groups", position=0, leave=True) as pbar:
376
- # # Loop through each group of residues in the sequence
377
- for i in range(0, len(protein_sequence), group_size):
378
- # Mask the residues in the group at positions [i, i + group_size)
379
- masked_sequence = protein_sequence[:i] + 'X' * min(group_size, len(protein_sequence) - i) + protein_sequence[i + group_size:]
380
-
381
- # Tokenize and predict for the masked sequence
382
- encoded_input = tokenizer([masked_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
383
- with torch.no_grad():
384
- masked_output = model(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"])
385
- masked_probabilities = F.softmax(masked_output.logits, dim=1).cpu().numpy()[0]
386
-
387
- # Calculate the change in probabilities and store it as the importance score
388
- group_index = i // group_size
389
- importance_scores[group_index, :] = np.abs(original_probabilities - masked_probabilities)
390
-
391
- progress = (i // group_size + 1) / num_groups * 100
392
- print(f"Progress: {progress:.2f}%")
393
 
394
- figmap, ax = plt.subplots(figsize=(20, 20))
395
- sns.heatmap(importance_scores, annot=True, cmap="coolwarm", xticklabels=label_encoder.classes_, yticklabels=[f"{i}-{i+group_size-1}" for i in range(0, len(protein_sequence), group_size)], ax=ax)
396
- ax.set_xlabel("Predicted Labels")
397
- ax.set_ylabel("Residue Position Groups")
398
-
399
- img = fig_to_img(figmap)
400
-
401
- return img
402
 
403
 
404
- def main_function_single(sequence, show_explanation):
405
- # Process seq, and return outputs for both fam and don
406
- family_label, family_img, _, family_info = process_family_sequence(sequence)
407
- donor_label, donor_img, *_ = process_single_sequence(sequence)
408
- figmap = None
409
- if show_explanation:
410
- figmap = generate_heatmap(sequence)
411
- return family_label, family_img, family_info, donor_label, donor_img, figmap
412
-
413
- def main_function_upload(protein_file): #, progress=gr.Progress()
414
- return process_sequence_file(protein_file) #, progress
415
-
416
  prediction_imagefam = gr.outputs.Image(type='pil', label="Family prediction graph")
417
  prediction_imagedonor = gr.outputs.Image(type='pil', label="Donor prediction graph")
418
- prediction_explain = gr.outputs.Image(type='pil', label="Donor prediction explanation")
419
-
420
 
421
  with gr.Blocks() as app:
422
- gr.Markdown("# Glydentify (alpha v0.3)")
423
 
424
  with gr.Tab("Single Sequence Prediction"):
425
  with gr.Row().style(equal_height=True):
426
  with gr.Column():
427
  sequence = gr.inputs.Textbox(lines=16, placeholder='Enter Protein Sequence Here...', label="Protein Sequence")
428
- explanation_checkbox = gr.inputs.Checkbox(label="Show Explanation", default=False)
429
  with gr.Column():
430
  with gr.Accordion("Example:"):
431
  gr.Markdown("""
@@ -443,37 +420,19 @@ with gr.Blocks() as app:
443
  with gr.Row().style(equal_height=True):
444
  with gr.Column():
445
  predict_button = gr.Button("Predict")
446
- predict_button.click(main_function_single, inputs=[sequence, explanation_checkbox],
447
  outputs=[family_prediction, prediction_imagefam, info_markdown,
448
- donor_prediction, prediction_imagedonor, prediction_explain])
449
 
450
  # Family & Donor Section
451
  with gr.Row().style(equal_height=True):
452
  with gr.Column():
453
- with gr.Accordion("Prediction Bar Graphs:"):
454
  prediction_imagefam.render() # = gr.outputs.Image(type='pil', label="Family prediction graph")
455
- prediction_imagedonor.render() # = gr.outputs.Image(type='pil', label="Donor prediction graph")
456
-
457
- # Explain Section
458
  with gr.Column():
459
- if explanation_checkbox: # Only render if the checkbox is checked
460
- with gr.Accordion("Donor explanation"):
461
- prediction_explain.render() # = gr.outputs.Image(type='pil', label="Donor prediction explaination")
462
 
463
- with gr.Tab("Multiple Sequence Prediction"):
464
- with gr.Row().style(equal_height=True):
465
- with gr.Column():
466
- protein_file = gr.inputs.File(label="Upload FASTA file")
467
- with gr.Column():
468
- result_file = gr.outputs.File(label="Download predictions of uploaded sequences")
469
- with gr.Row().style(equal_height=True):
470
- with gr.Column():
471
- process_button = gr.Button("Process")
472
- process_button.click(main_function_upload, inputs=protein_file, outputs=[result_file])
473
- with gr.Column():
474
- clear = gr.Button("Clear")
475
- clear.click(lambda: None)
476
- # clear.click()
477
 
478
  app.launch(show_error=True)
479
 
 
8
  import numpy as np
9
  import seaborn as sns
10
  from sklearn.model_selection import train_test_split
11
+ import matplotlib
12
+ matplotlib.use('Agg') # Use the non-interactive Agg backend
13
  import matplotlib.pyplot as plt
14
  import pickle
15
  import torch.nn.functional as F
 
18
  from PIL import Image
19
  import Bio
20
  from Bio import SeqIO
21
+ from Bio.Blast import NCBIXML
22
+ import subprocess
23
  import zipfile
24
  import os
25
 
26
+ GTA_fam_dict = {
27
+ 0: "GT116",
28
+ 1: "GT12",
29
+ 2: "GT13",
30
+ 3: "GT14",
31
+ 4: "GT15",
32
+ 5: "GT16",
33
+ 6: "GT17",
34
+ 7: "GT2-clade1",
35
+ 8: "GT2-clade2",
36
+ 9: "GT2-clade3",
37
+ 10: "GT2-clade4",
38
+ 11: "GT2-clade5",
39
+ 12: "GT2-related",
40
+ 13: "GT21",
41
+ 14: "GT24",
42
+ 15: "GT25",
43
+ 16: "GT27",
44
+ 17: "GT31",
45
+ 18: "GT32",
46
+ 19: "GT34",
47
+ 20: "GT40",
48
+ 21: "GT43",
49
+ 22: "GT45",
50
+ 23: "GT49",
51
+ 24: "GT54",
52
+ 25: "GT55",
53
+ 26: "GT6",
54
+ 27: "GT60",
55
+ 28: "GT62",
56
+ 29: "GT64",
57
+ 30: "GT67",
58
+ 31: "GT7",
59
+ 32: "GT75",
60
+ 33: "GT77",
61
+ 34: "GT78",
62
+ 35: "GT8",
63
+ 36: "GT81",
64
+ 37: "GT82",
65
+ 38: "GT84",
66
+ 39: "GT88",
67
+ 40: "GT92"
68
+ }
 
 
 
 
 
 
 
69
 
 
 
70
 
71
+ GTA_don_dict = {
72
+ 0: "N-Acetyl Galactosamine",
73
+ 1: "N-Acetyl Glucosamine",
74
+ 2: "Arabinose",
75
+ 3: "Galactose",
76
+ 4: "Galacturonic Acid",
77
+ 5: "Glucose",
78
+ 6: "Glucuronic Acid",
79
+ 7: "Mannose",
80
+ 8: "Rhamnose",
81
+ 9: "Xylose"
82
+ }
83
 
84
+ GTB_fam_dict = {
85
+ 0: "GT1",
86
+ 1: "GT10",
87
+ 2: "GT104",
88
+ 3: "GT11",
89
+ 4: "GT18",
90
+ 5: "GT19",
91
+ 6: "GT20",
92
+ 7: "GT23",
93
+ 8: "GT28",
94
+ 9: "GT3",
95
+ 10: "GT30",
96
+ 11: "GT35",
97
+ 12: "GT37",
98
+ 13: "GT38",
99
+ 14: "GT4",
100
+ 15: "GT41",
101
+ 16: "GT5",
102
+ 17: "GT52",
103
+ 18: "GT63",
104
+ 19: "GT65",
105
+ 20: "GT68",
106
+ 21: "GT70",
107
+ 22: "GT72",
108
+ 23: "GT80",
109
+ 24: "GT9",
110
+ 25: "GT90",
111
+ 26: "GT99"
112
+ }
113
 
 
 
 
 
 
 
114
 
115
+ GTB_don_dict = {
116
+ 0: "Fucose",
117
+ 1: "Galactose",
118
+ 2: "N-Acetyl Galactosamine",
119
+ 3: "Glucuronic Acid",
120
+ 4: "N-Acetyl Glucosamine",
121
+ 5: "Glucose",
122
+ 6: "Mannose",
123
+ 7: "Other",
124
+ 8: "Xylose"
125
+ }
126
 
127
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D") #facebook/esm2_t33_650M_UR50D
128
 
129
  glycosyltransferase_db = {
 
 
 
 
 
130
  "GT40" : {'CAZy Name': 'GT40', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT40.html'},
131
  "GT16" : {'CAZy Name': 'GT16', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT16.html'},
132
  "GT27" : {'CAZy Name': 'GT27', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT27.html'},
133
  "GT55" : {'CAZy Name': 'GT55', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT55.html'},
 
 
134
  "GT25" : {'CAZy Name': 'GT25', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT25.html'},
135
+ "GT2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT2.html' },
 
 
136
  "GT84" : {'CAZy Name': 'GT84', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT84.html'},
137
  "GT13" : {'CAZy Name': 'GT13', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT13.html'},
 
 
 
138
  "GT67" : {'CAZy Name': 'GT67', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT67.html'},
 
139
  "GT82" : {'CAZy Name': 'GT82', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT82.html'},
140
  "GT24" : {'CAZy Name': 'GT24', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT24.html'},
141
+ "GT81" : {'CAZy Name': 'GT81', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT81.html'},
 
 
 
142
  "GT49" : {'CAZy Name': 'GT49', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT49.html'},
143
  "GT34" : {'CAZy Name': 'GT34', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT34.html'},
144
  "GT45" : {'CAZy Name': 'GT45', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT45.html'},
145
+ "GT32" : {'CAZy Name': 'GT32', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT32.html'},
146
  "GT88" : {'CAZy Name': 'GT88', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT88.html'},
147
  "GT21" : {'CAZy Name': 'GT21', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT21.html'},
 
 
 
 
 
 
148
  "GT54" : {'CAZy Name': 'GT54', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT54.html'},
 
 
149
  "GT6" : {'CAZy Name': 'GT6 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT6.html' },
150
+ "GT7" : {'CAZy Name': 'GT7 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT7.html' },
 
 
 
 
151
  "GT64" : {'CAZy Name': 'GT64', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT64.html'},
 
152
  "GT78" : {'CAZy Name': 'GT78', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT78.html'},
153
  "GT12" : {'CAZy Name': 'GT12', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT12.html'},
154
+ "GT31" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
 
155
  "GT62" : {'CAZy Name': 'GT62', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '3 ', 'More Info': 'http://www.cazy.org/GT62.html'},
156
+ "GT8" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT8.html' },
157
  "GT15" : {'CAZy Name': 'GT15', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT15.html'},
158
+ "GT43" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
 
 
159
  "GT60" : {'CAZy Name': 'GT60', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT60.html'},
160
  "GT14" : {'CAZy Name': 'GT14', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT14.html'},
 
161
  "GT17" : {'CAZy Name': 'GT17', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT17.html'},
 
162
  "GT77" : {'CAZy Name': 'GT77', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT77.html'},
 
 
 
163
  "GT75" : {'CAZy Name': 'GT75', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT75.html'},
 
164
 
165
  }
166
 
167
+ def parse_blast_output_for_best_evalue(output_file):
168
+ with open(output_file) as result_handle:
169
+ blast_record = NCBIXML.read(result_handle)
170
+
171
+ if len(blast_record.alignments) == 0:
172
+ # Handle the case where no alignments are found
173
+ # You might return a high e-value or None to indicate no match
174
+ return None
175
+
176
+ best_hit = blast_record.alignments[0]
177
+ best_evalue = best_hit.hsps[0].expect
178
+ print(best_evalue)
179
+ return best_evalue
180
+
181
+ def run_local_blast(sequence, database):
182
+ # Temporarily save the query sequence to a file
183
+ query_file = "temp_query.fasta"
184
+ with open(query_file, "w") as file:
185
+ file.write(">Query\n" + sequence)
186
+
187
+ # Specify the output file for BLAST results
188
+ output_file = "blast_results.xml"
189
+
190
+ # Construct the BLAST command
191
+ blast_cmd = [
192
+ "blastp",
193
+ "-query", query_file,
194
+ "-db", database,
195
+ "-out", output_file,
196
+ "-outfmt", "5", # Output format 5 is XML
197
+ "-evalue", "1e-2" # Set your desired E-value threshold here
198
+ ]
199
+
200
+ # Execute the BLAST search
201
+ subprocess.run(blast_cmd, check=True)
202
+
203
+ # Parse the BLAST output to find the best E-value
204
+ best_evalue = parse_blast_output_for_best_evalue(output_file)
205
+
206
+ # Clean up temporary files
207
+ os.remove(query_file)
208
+ os.remove(output_file)
209
+
210
+ return best_evalue
211
 
212
 
213
  def get_family_info(family_name):
 
235
 
236
  def preprocess_protein_sequence(protein_fasta):
237
  lines = protein_fasta.split('\n')
 
238
  headers = [line for line in lines if line.startswith('>')]
239
  if len(headers) > 1:
240
+ return None, None, None, "Multiple fasta sequences detected. Please upload a fasta file with only one sequence."
241
 
242
  protein_sequence = ''.join(line for line in lines if not line.startswith('>'))
243
+ valid_characters = set("ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwy")
244
 
245
+ # Check if every character in the sequence is in the set of valid characters.
246
+ if any(char.upper() not in valid_characters for char in protein_sequence):
247
+ return None, None, None, "Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids."
248
+
249
+ print("Running Blast.")
250
+
251
+ gta_db_path = "blast_data/GTA/GTA.db"
252
+ gtb_db_path = "blast_data/GTB/GTB.db"
253
 
254
+ evalue_gta = run_local_blast(protein_sequence, gta_db_path)
255
+ evalue_gta = evalue_gta if evalue_gta is not None else 1e+100
256
 
257
+ evalue_gtb = run_local_blast(protein_sequence, gtb_db_path)
258
+ evalue_gtb = evalue_gtb if evalue_gtb is not None else 1e+100
259
+ print("E-value GT-A:", evalue_gta, "E-value GT-B:", evalue_gtb)
260
+ print("Blast finished running. Checking sequence against known data.")
261
 
262
+ # Determine which models to use based on the best E-value
263
+ model_fam = "GTA_fam.pth" if evalue_gta < evalue_gtb else "GTB_fam.pth"
264
+ model_don = "GTA_don.pth" if evalue_gta < evalue_gtb else "GTB_don.pth"
265
+ print("Selected model for family:", model_fam, "and donor:", model_don)
266
+
267
+
268
+ # Adjust your existing condition to check if both E-values exceed the threshold
269
+ if evalue_gta > 1e-2 and evalue_gtb > 1e-2:
270
+ # If both E-values are above the threshold, it suggests the sequence does not match well with either database
271
+ return None, None, None, "**Warning:** The sequence does not appear to be a GT-A or GT-B. Please ensure you are submitting a sequence from these families."
272
 
273
+ return protein_sequence, model_fam, model_don, None
274
+
275
+
276
+
277
+ def process_family_sequence(protein_sequence, modelfam, label_dict):
278
  encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
279
  input_idsfam = encoded_input["input_ids"]
280
  attention_maskfam = encoded_input["attention_mask"]
 
285
  probabilitiesfam = F.softmax(logitsfam, dim=1)
286
  _, predicted_labelsfam = torch.max(logitsfam, dim=1)
287
 
288
+ predicted_label_index_fam = predicted_labelsfam.item() # Assuming single sample prediction
289
+ decoded_label_fam = label_dict.get(predicted_label_index_fam, "Unknown Label") # Decoding label using the dictionary
 
 
 
 
290
 
291
+ family_info = get_family_info(decoded_label_fam)
 
292
 
293
+ figfam = plt.figure(figsize=(10, 5))
294
+ # probabilitiesfam_flat = probabilitiesfam.squeeze().tolist() # Flatten probabilities
 
295
 
296
+ # Extract and sort top 5 label probabilities
297
+ top5_probs, top5_labels = torch.topk(probabilitiesfam, 5)
298
+ top5_labels = top5_labels.squeeze().tolist()
299
+ top5_decoded_labels = [label_dict.get(label, "Unknown") for label in top5_labels]
300
 
301
+ # For debugging
302
+ print("Top 5 labels:", top5_labels)
303
+ print("Available keys in label_dict:", label_dict.keys())
304
 
305
+ y_posfam = np.arange(len(top5_decoded_labels))
306
+ plt.barh(y_posfam, [prob * 100 for prob in top5_probs.squeeze().tolist()], align='center', alpha=0.5)
307
+ plt.yticks(y_posfam, top5_decoded_labels)
308
  plt.xlabel('Probability (%)')
309
  plt.title('Top 5 Family Class Probabilities')
310
  plt.xlim(0, 100)
 
313
  img = fig_to_img(figfam)
314
 
315
  if len(protein_sequence) < 100:
316
+ return decoded_label_fam, img, None, "**Warning:** The sequence is relatively short. Fragmentary and partial sequences may result in incorrect predictions. \n\n {family_info}"
 
317
 
318
+ return decoded_label_fam, img, None, family_info
319
 
320
 
321
+ def process_donor_sequence(protein_sequence, modeldon, label_dict):
 
 
 
 
322
  encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
323
+ input_idsdon = encoded_input["input_ids"]
324
+ attention_maskdon = encoded_input["attention_mask"]
325
 
326
  with torch.no_grad():
327
+ outputdon = modeldon(input_idsdon, attention_mask=attention_maskdon)
328
+ logitsdon = outputdon.logits
329
+ probabilitiesdon = F.softmax(logitsdon, dim=1)
330
+ _, predicted_labelsdon = torch.max(logitsdon, dim=1)
331
 
332
+ predicted_label_index_don = predicted_labelsdon.item() # Assuming single sample prediction
333
+ decoded_label_don = label_dict.get(predicted_label_index_don, "Unknown Label") # Decoding label using the dictionary
334
 
335
+ figdon = plt.figure(figsize=(10, 5))
336
+ probabilitiesdon_flat = probabilitiesdon.squeeze().tolist() # Flatten probabilities
 
337
 
338
+ # Extract and sort top 5 label probabilities
339
+ top3_probs, top3_labels = torch.topk(probabilitiesdon, 3)
340
+ top3_labels = top3_labels.squeeze().tolist()
341
+ top3_decoded_labels = [label_dict.get(label, "Unknown") for label in top3_labels]
342
 
343
+ y_posdon = np.arange(len(top3_decoded_labels))
344
+ plt.barh(y_posdon, [prob * 100 for prob in top3_probs.squeeze().tolist()], align='center', alpha=0.5)
345
+ plt.yticks(y_posdon, top3_decoded_labels)
 
 
 
 
 
346
  plt.xlabel('Probability (%)')
347
  plt.title('Top 3 Donor Class Probabilities')
348
+ plt.xlim(0, 100)
349
+ plt.close(figdon)
350
 
351
+ img = fig_to_img(figdon)
352
 
353
  if len(protein_sequence) < 100:
354
+ return decoded_label_don, img, None, "**Warning:** The sequence is relatively short. Fragmentary and partial sequences may result in incorrect predictions. \n\n {family_info}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
+ return decoded_label_don, img, None
357
 
358
+ def main_function_single(sequence):
359
+ # Initial preprocessing including BLAST-based model selection
360
+ protein_sequence, model_fam_path, model_don_path, error_msg = preprocess_protein_sequence(sequence)
361
+ if error_msg:
362
+ print(error_msg)
363
+ return None, None, error_msg, None, None
364
+
365
+ model_config = {
366
+ "GTA_fam.pth": {"num_labels": 41, "label_dict": GTA_fam_dict},
367
+ "GTB_fam.pth": {"num_labels": 27, "label_dict": GTB_fam_dict},
368
+ "GTA_don.pth": {"num_labels": 10, "label_dict": GTA_don_dict},
369
+ "GTB_don.pth": {"num_labels": 9, "label_dict": GTB_don_dict},
370
+ }
371
+
372
+ # Load the model for family classification
373
+ config_fam = model_config[model_fam_path]
374
+ model_fam = EsmForSequenceClassification.from_pretrained("facebook/esm2_t12_35M_UR50D", num_labels=config_fam["num_labels"])
375
+ model_fam.load_state_dict(torch.load(model_fam_path, map_location=torch.device('cpu')), strict=False)
376
+ model_fam.eval()
377
+ model_fam.to('cpu')
378
+
379
+ # Load the model for donor classification
380
+ config_don = model_config[model_don_path]
381
+ model_don = EsmForSequenceClassification.from_pretrained("facebook/esm2_t12_35M_UR50D", num_labels=config_don["num_labels"])
382
+ model_don.load_state_dict(torch.load(model_don_path, map_location=torch.device('cpu')), strict=False)
383
+ model_don.eval()
384
+ model_don.to('cpu')
385
+
386
+ print(config_fam["label_dict"])
387
+
388
+ # Pass the label dictionary along with the model to the processing functions
389
+ family_label, family_img, _, family_info = process_family_sequence(protein_sequence, model_fam, config_fam["label_dict"])
390
+ donor_label, donor_img, _ = process_donor_sequence(protein_sequence, model_don, config_don["label_dict"])
 
 
 
 
 
 
 
 
391
 
392
+ return family_label, family_img, family_info, donor_label, donor_img
 
 
 
 
 
 
 
393
 
394
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  prediction_imagefam = gr.outputs.Image(type='pil', label="Family prediction graph")
396
  prediction_imagedonor = gr.outputs.Image(type='pil', label="Donor prediction graph")
 
 
397
 
398
  with gr.Blocks() as app:
399
+ gr.Markdown("# Glydentify (alpha v0.5)")
400
 
401
  with gr.Tab("Single Sequence Prediction"):
402
  with gr.Row().style(equal_height=True):
403
  with gr.Column():
404
  sequence = gr.inputs.Textbox(lines=16, placeholder='Enter Protein Sequence Here...', label="Protein Sequence")
405
+ # explanation_checkbox = gr.inputs.Checkbox(label="Show Explanation", default=False)
406
  with gr.Column():
407
  with gr.Accordion("Example:"):
408
  gr.Markdown("""
 
420
  with gr.Row().style(equal_height=True):
421
  with gr.Column():
422
  predict_button = gr.Button("Predict")
423
+ predict_button.click(main_function_single, inputs=[sequence],
424
  outputs=[family_prediction, prediction_imagefam, info_markdown,
425
+ donor_prediction, prediction_imagedonor])
426
 
427
  # Family & Donor Section
428
  with gr.Row().style(equal_height=True):
429
  with gr.Column():
430
+ with gr.Accordion("Family Prediction:"):
431
  prediction_imagefam.render() # = gr.outputs.Image(type='pil', label="Family prediction graph")
 
 
 
432
  with gr.Column():
433
+ with gr.Accordion("Donor Prediction:"):
434
+ prediction_imagedonor.render() # = gr.outputs.Image(type='pil', label="Donor prediction graph")
 
435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
 
437
  app.launch(show_error=True)
438
 
backup/family.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:03dcff847ada129cd2889ea3f62071b666b009087829dabf85301210c7fe8382
3
- size 136199341
 
 
 
 
backup/family_labels.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c3e8fe9ddb883008ab377fba3837200626ee609fbe892950b1fada9ff078eca4
3
- size 4559
 
 
 
 
best_model_35M_t12_5v5.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1621bc6500a0dc3510af6d53cc405d4ac7cc8e0827e23b74f488867d321bc0e8
3
- size 136069629
 
 
 
 
donor_labels.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:504b291ac3a1de0e767117935a5546dd8d38b1150bc9183c9a3a8fbce3897a96
3
- size 679
 
 
 
 
family.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:725b2904a82171be55bf702f10e01d6185806e2556578d4cc99e1af9711b3952
3
- size 136163661
 
 
 
 
family_labels.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9f0cac818ec047a4e6c0ca9a9d3026bd3d224a8d492ea533dae107a4a8269db5
3
- size 3419
 
 
 
 
requirements.txt CHANGED
@@ -14,4 +14,5 @@ transformers==4.31.0
14
  scikit-learn==1.3.0
15
  torch==2.0.1
16
  torchaudio==2.0.2
17
- torchvision==0.15.2
 
 
14
  scikit-learn==1.3.0
15
  torch==2.0.1
16
  torchaudio==2.0.2
17
+ torchvision==0.15.2
18
+ accelerate==0.29.1