vshulev commited on
Commit
2d43134
·
1 Parent(s): ae3ea59

Implement genus classification

Browse files
Files changed (2) hide show
  1. app.py +80 -55
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import re
2
  import PIL.Image
3
  import pandas as pd
@@ -12,14 +14,18 @@ import torch
12
  from torch import nn
13
  from transformers import BertConfig, BertForMaskedLM, PreTrainedTokenizerFast
14
  from huggingface_hub import PyTorchModelHubMixin
 
15
 
16
- from config import DEFAULT_INPUTS, MODELS, DATASETS
17
 
18
  # We need this for the eco layers because they are too big
19
  PIL.Image.MAX_IMAGE_PIXELS = None
20
 
21
  torch.set_grad_enabled(False)
22
 
 
 
 
23
 
24
  # Load models
25
  class DNASeqClassifier(nn.Module, PyTorchModelHubMixin):
@@ -60,10 +66,8 @@ def set_default_inputs():
60
  DEFAULT_INPUTS["longitude"])
61
 
62
 
63
- def preprocess(dna_sequence: str, latitude: str, longitude: str):
64
- """
65
- Prepares app input for downsteram tasks
66
- """
67
 
68
  # Preprocess the DNA sequence turning it into an embedding
69
  dna_seq_preprocessed: str = re.sub(r"[^ACGT]", "N", dna_sequence)
@@ -80,58 +84,65 @@ def preprocess(dna_sequence: str, latitude: str, longitude: str):
80
  # Preprocess the location data
81
  coords = (float(latitude), float(longitude))
82
 
83
- return dna_embedding, coords
84
- # ecolayer_data = ecolayers_ds # TODO something something...
85
 
86
- # # format lat and lon into coords
87
- # coords = (inp_lat, inp_lng)
88
- # # Grab rasters from the tifs
89
- # ecoLayers = load_dataset("LofiAmazon/Global-Ecolayers")
90
- # temp = pd.DataFrame([coords, embed], columns = ['coord', 'embeddings'])
91
- # data = pd.merge(temp, ecoLayers, on='coord', how='left')
92
 
93
- # return data
94
-
95
- # def predict_genus():
96
- # data = preprocess()
97
- # out = infer.infer_dna(data)
98
-
99
- # results = []
100
-
101
- # genuses = infer.infer()
102
-
103
- # results.append({
104
- # "sequence": dna_df['nucraw'],
105
- # # "predictions": pd.concat([dna_genuses, envdna_genuses], axis=0)
106
- # 'predictions': genuses})
107
-
108
- # return results
109
 
110
- # def tsne_DNA(data, genuses):
111
- # data["embeddings"] = data["embeddings"].apply(lambda x: np.array(list(map(float, x[1:-1].split()))))
112
 
113
- # # Pick genuses with most samples
114
- # top_k = 5
115
- # genus_counts = df["genus"].value_counts()
116
- # top_genuses = genus_counts.head(top_k).index
117
- # df = df[df["genus"].isin(top_genuses)]
118
 
119
- # # Create a t-SNE plot of the embeddings
120
- # n_genus = len(df["genus"].unique())
121
- # tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, n_iter=1000, random_state=0)
122
 
123
- # X = np.stack(df["embeddings"].tolist())
124
- # y = df["genus"].tolist()
 
 
125
 
126
- # X_tsne = tsne.fit_transform(X)
127
 
128
- # label_encoder = LabelEncoder()
129
- # y_encoded = label_encoder.fit_transform(y)
130
 
131
- # plot = plt.figure(figsize=(6, 5))
132
- # scatter = plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_encoded, cmap="viridis", alpha=0.7)
133
 
134
- # return plot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
 
137
  with gr.Blocks() as demo:
@@ -156,20 +167,34 @@ with gr.Blocks() as demo:
156
 
157
  with gr.Row():
158
  btn_run = gr.Button("Predict")
159
- btn_run.click(fn=preprocess, inputs=[inp_dna, inp_lat, inp_lng])
 
 
 
160
 
161
  btn_defaults = gr.Button("I'm feeling lucky")
162
  btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])
163
 
164
-
165
  with gr.Tab("Genus Prediction"):
166
- with gr.Row():
167
- gr.Markdown("Make plot or table for Top 5 species")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
- with gr.Row():
170
- genus_out = gr.Dataframe(headers=["DNA Only Pred Genus", "DNA Only Prob", "DNA & Env Pred Genus", "DNA & Env Prob"])
171
- # btn_run.click(fn=predict_genus, inputs=[inp_dna, inp_lat, inp_lng], outputs=genus_out)
172
-
173
  with gr.Tab('DNA Embedding Space Visualizer'):
174
  gr.Markdown("If the highest genus probability is very low for your DNA sequence, we can still examine the DNA embedding of the sequence in relation to known samples for clues.")
175
 
 
1
+ from io import BytesIO
2
+ import os
3
  import re
4
  import PIL.Image
5
  import pandas as pd
 
14
  from torch import nn
15
  from transformers import BertConfig, BertForMaskedLM, PreTrainedTokenizerFast
16
  from huggingface_hub import PyTorchModelHubMixin
17
+ from pinecone import Pinecone
18
 
19
+ from config import DEFAULT_INPUTS, MODELS, DATASETS, ID_TO_GENUS_MAP
20
 
21
  # We need this for the eco layers because they are too big
22
  PIL.Image.MAX_IMAGE_PIXELS = None
23
 
24
  torch.set_grad_enabled(False)
25
 
26
+ # Configure pinecone
27
+ pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
28
+ pc_index = pc.Index("amazon")
29
 
30
  # Load models
31
  class DNASeqClassifier(nn.Module, PyTorchModelHubMixin):
 
66
  DEFAULT_INPUTS["longitude"])
67
 
68
 
69
+ def preprocess(dna_sequence: str, latitude: float, longitude: float):
70
+ """Prepares app input for downsteram tasks"""
 
 
71
 
72
  # Preprocess the DNA sequence turning it into an embedding
73
  dna_seq_preprocessed: str = re.sub(r"[^ACGT]", "N", dna_sequence)
 
84
  # Preprocess the location data
85
  coords = (float(latitude), float(longitude))
86
 
87
+ return dna_embedding, coords[0], coords[1]
 
88
 
 
 
 
 
 
 
89
 
90
+ def tokenize(dna_sequence: str) -> dict[str, torch.Tensor]:
91
+ dna_seq_preprocessed: str = re.sub(r"[^ACGT]", "N", dna_sequence)
92
+ dna_seq_preprocessed: str = re.sub(r"N+$", "", dna_sequence)
93
+ dna_seq_preprocessed = dna_seq_preprocessed[:660]
94
+ dna_seq_preprocessed = " ".join([
95
+ dna_seq_preprocessed[i:i+4] for i in range(0, len(dna_seq_preprocessed), 4)
96
+ ])
 
 
 
 
 
 
 
 
 
97
 
98
+ return tokenizer(dna_seq_preprocessed, return_tensors="pt")
 
99
 
 
 
 
 
 
100
 
 
 
 
101
 
102
+ def get_embedding(dna_sequence: str) -> torch.Tensor:
103
+ dna_embedding: torch.Tensor = embeddings_model(
104
+ **tokenize(dna_sequence)
105
+ ).hidden_states[-1].mean(1).squeeze()
106
 
107
+ return dna_embedding
108
 
 
 
109
 
110
+ def predict_genus(method: str, dna_sequence: str, latitude: str, longitude: str):
111
+ coords = (float(latitude), float(longitude))
112
 
113
+ if method == "cosine":
114
+ embedding = get_embedding(dna_sequence)
115
+ result = pc_index.query(
116
+ namespace="all",
117
+ vector=embedding.tolist(),
118
+ top_k=100,
119
+ include_metadata=True,
120
+ )
121
+ top_k = [m["metadata"]["genus"] for m in result["matches"]]
122
+
123
+ top_k = pd.Series(top_k).value_counts()
124
+ top_k = top_k / top_k.sum()
125
+
126
+ if method == "fine_tuned_model":
127
+ bert_inputs = tokenize(dna_sequence)
128
+ logits = classification_model(bert_inputs, torch.zeros(1, 7))
129
+ probs = torch.softmax(logits, dim=1).squeeze()
130
+ top_k = torch.topk(probs, 10)
131
+ top_k = pd.Series(
132
+ top_k.values.detach().numpy(),
133
+ index=[ID_TO_GENUS_MAP[i] for i in top_k.indices.detach().numpy()]
134
+ )
135
+ # top_k = pd.Series(top_k.values.detach().numpy(), index=top_k.indices.detach().numpy())
136
+
137
+ fig, ax = plt.subplots()
138
+ ax.bar(top_k.index.astype(str), top_k.values)
139
+ ax.set_title("Genus Prediction")
140
+ ax.set_xlabel("Genus")
141
+ ax.set_ylabel("Probability")
142
+ ax.set_xticklabels(top_k.index.astype(str), rotation=90)
143
+ fig.canvas.draw()
144
+
145
+ return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
146
 
147
 
148
  with gr.Blocks() as demo:
 
167
 
168
  with gr.Row():
169
  btn_run = gr.Button("Predict")
170
+ btn_run.click(
171
+ fn=preprocess,
172
+ inputs=[inp_dna, inp_lat, inp_lng],
173
+ )
174
 
175
  btn_defaults = gr.Button("I'm feeling lucky")
176
  btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])
177
 
 
178
  with gr.Tab("Genus Prediction"):
179
+ gr.Interface(
180
+ fn=predict_genus,
181
+ inputs=[
182
+ gr.Dropdown(choices=["cosine", "fine_tuned_model"], value="fine_tuned_model"),
183
+ inp_dna,
184
+ inp_lat,
185
+ inp_lng,
186
+ ],
187
+ outputs=["image"],
188
+ )
189
+
190
+ # with gr.Row():
191
+
192
+ # gr.Markdown("Make plot or table for Top 5 species")
193
+
194
+ # with gr.Row():
195
+ # genus_out = gr.Dataframe(headers=["DNA Only Pred Genus", "DNA Only Prob", "DNA & Env Pred Genus", "DNA & Env Prob"])
196
+ # # btn_run.click(fn=predict_genus, inputs=[inp_dna, inp_lat, inp_lng], outputs=genus_out)
197
 
 
 
 
 
198
  with gr.Tab('DNA Embedding Space Visualizer'):
199
  gr.Markdown("If the highest genus probability is very low for your DNA sequence, we can still examine the DNA embedding of the sequence in relation to known samples for clues.")
200
 
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  huggingface-hub==0.23.2
2
  pandas==2.2.2
 
3
  rasterio==1.3.10
4
  torch==2.3.0
5
  tqdm==4.66.4
 
1
  huggingface-hub==0.23.2
2
  pandas==2.2.2
3
+ pinecone_client==4.1.0
4
  rasterio==1.3.10
5
  torch==2.3.0
6
  tqdm==4.66.4