import os
import pickle
import re

import PIL.Image
import pandas as pd
import numpy as np
import gradio as gr
from datasets import load_dataset
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder
import torch
from torch import nn
from transformers import BertConfig, BertForMaskedLM, PreTrainedTokenizerFast
from huggingface_hub import PyTorchModelHubMixin
from pinecone import Pinecone
import rasterio
from rasterio.sample import sample_gen

from config import DEFAULT_INPUTS, MODELS, DATASETS, ID_TO_GENUS_MAP, LAYER_NAMES

# We need this for the eco layers because they are too big
PIL.Image.MAX_IMAGE_PIXELS = None

torch.set_grad_enabled(False)

# Configure Pinecone
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
pc_index = pc.Index("amazon")
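# Assumption, based on how the index is queried in predict_genus below: the
# "amazon" index holds precomputed DNA embeddings of known samples, each with
# a "genus" metadata field.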
# Load models
class DNASeqClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, bert_model, env_dim, num_classes):
        super(DNASeqClassifier, self).__init__()
        self.bert = bert_model
        self.env_dim = env_dim
        self.num_classes = num_classes
        self.fc = nn.Linear(768 + env_dim, num_classes)

    def forward(self, bert_inputs, env_data):
        outputs = self.bert(**bert_inputs)
        dna_embeddings = outputs.hidden_states[-1].mean(1)
        combined = torch.cat((dna_embeddings, env_data), dim=1)
        logits = self.fc(combined)
        return logits
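# Minimal usage sketch (comments only, not executed). The classifier mean-pools
# BERT's last hidden layer into a 768-dim DNA embedding, concatenates it with a
# vector of environmental features, and maps the result to genus logits.
# Assumed for illustration: env_dim matches len(LAYER_NAMES) and num_classes
# matches len(ID_TO_GENUS_MAP).
#
#   bert = BertForMaskedLM(BertConfig(vocab_size=259, output_hidden_states=True))
#   clf = DNASeqClassifier(bert, env_dim=len(LAYER_NAMES), num_classes=len(ID_TO_GENUS_MAP))
#   logits = clf(tokenize("AACAATGTA..."), torch.zeros(1, len(LAYER_NAMES)))  # shape: (1, num_classes)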
tokenizer = PreTrainedTokenizerFast.from_pretrained(MODELS["embeddings"])
# We need the hidden states to build DNA embeddings downstream
embeddings_model = BertForMaskedLM.from_pretrained(MODELS["embeddings"], output_hidden_states=True)
classification_model = DNASeqClassifier.from_pretrained(
    MODELS["classification"],
    bert_model=BertForMaskedLM(
        BertConfig(vocab_size=259, output_hidden_states=True),
    ),
)

with open("scaler.pkl", "rb") as f:
    scaler = pickle.load(f)

embeddings_model.eval()
classification_model.eval()
# Load datasets
amazon_ds = load_dataset(DATASETS["amazon"])


def set_default_inputs():
    return (DEFAULT_INPUTS["dna_sequence"],
            DEFAULT_INPUTS["latitude"],
            DEFAULT_INPUTS["longitude"])
def preprocess(dna_sequence: str, latitude: float, longitude: float):
    """Prepares app input for downstream tasks."""
    # Preprocess the DNA sequence, turning it into an embedding
    dna_seq_preprocessed = re.sub(r"[^ACGT]", "N", dna_sequence)
    dna_seq_preprocessed = re.sub(r"N+$", "", dna_seq_preprocessed)
    dna_seq_preprocessed = dna_seq_preprocessed[:660]
    dna_seq_preprocessed = " ".join([
        dna_seq_preprocessed[i:i + 4] for i in range(0, len(dna_seq_preprocessed), 4)
    ])

    dna_embedding: torch.Tensor = embeddings_model(
        **tokenizer(dna_seq_preprocessed, return_tensors="pt")
    ).hidden_states[-1].mean(1).squeeze()

    # Preprocess the location data
    coords = (float(latitude), float(longitude))

    return dna_embedding, coords[0], coords[1]
def tokenize(dna_sequence: str) -> dict[str, torch.Tensor]:
    # Replace non-ACGT characters with N, trim trailing Ns, and cap at 660 bp
    dna_seq_preprocessed = re.sub(r"[^ACGT]", "N", dna_sequence)
    dna_seq_preprocessed = re.sub(r"N+$", "", dna_seq_preprocessed)
    dna_seq_preprocessed = dna_seq_preprocessed[:660]
    # Split the sequence into space-separated 4-mers for the tokenizer
    dna_seq_preprocessed = " ".join([
        dna_seq_preprocessed[i:i + 4] for i in range(0, len(dna_seq_preprocessed), 4)
    ])
    return tokenizer(dna_seq_preprocessed, return_tensors="pt")
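# Illustrative walk-through of the preprocessing above (made-up sequence):
#   "AACRATTTGC" -> "AACNATTTGC" (non-ACGT replaced) -> "AACN ATTT GC" (4-mers)
#   -> token IDs from the 4-mer vocabulary.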
def get_embedding(dna_sequence: str) -> torch.Tensor:
    # Mean-pool the last hidden layer of the DNA language model
    dna_embedding: torch.Tensor = embeddings_model(
        **tokenize(dna_sequence)
    ).hidden_states[-1].mean(1).squeeze()
    return dna_embedding
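# Example usage (mirrors the cosine path in predict_genus below):
#   emb = get_embedding(DEFAULT_INPUTS["dna_sequence"])  # 768-dim DNA embedding
#   res = pc_index.query(namespace="all", vector=emb.tolist(), top_k=10, include_metadata=True)
#   genera = [m["metadata"]["genus"] for m in res["matches"]]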
def predict_genus(method: str, dna_sequence: str, latitude: str, longitude: str):
    coords = (float(latitude), float(longitude))

    if method == "cosine":
        # Nearest-neighbor lookup: compare the query embedding against the
        # precomputed embeddings in Pinecone and count the genera of the hits
        embedding = get_embedding(dna_sequence)
        result = pc_index.query(
            namespace="all",
            vector=embedding.tolist(),
            top_k=10,
            include_metadata=True,
        )
        top_k = [m["metadata"]["genus"] for m in result["matches"]]
        top_k = pd.Series(top_k).value_counts()
        top_k = top_k / top_k.sum()

    if method == "fine_tuned_model":
        bert_inputs = tokenize(dna_sequence)

        env_data = []
        for layer in LAYER_NAMES:
            with rasterio.open(layer) as dataset:
                # Get the corresponding ecological values for the samples
                results = sample_gen(dataset, [coords])
                results = [r for r in results]
                layer_data = np.mean(results[0])
                env_data.append(layer_data)

        env_data = scaler.transform([env_data])
        env_data = torch.from_numpy(env_data).to(torch.float32)

        logits = classification_model(bert_inputs, env_data)
        # Sharpen the distribution before reporting probabilities
        temperature = 0.2
        probs = torch.softmax(logits / temperature, dim=1).squeeze()
        top_k = torch.topk(probs, 10)
        top_k = pd.Series(
            top_k.values.detach().numpy(),
            index=[ID_TO_GENUS_MAP[i] for i in top_k.indices.detach().numpy()],
        )

    fig, ax = plt.subplots()
    ax.bar(top_k.index.astype(str), top_k.values)
    ax.set_ylim(0, 1)
    ax.set_title("Genus Prediction")
    ax.set_xlabel("Genus")
    ax.set_ylabel("Probability")
    ax.set_xticks(range(len(top_k.index)))
    ax.set_xticklabels(top_k.index.astype(str), rotation=90)
    fig.subplots_adjust(bottom=0.3)
    fig.canvas.draw()
    return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
def cluster_dna(top_k: float):
    df = amazon_ds["train"].to_pandas()
    df = df[df["genus"].notna()]

    # Keep only the most common genera
    top_k = int(top_k)
    genus_counts = df["genus"].value_counts()
    top_genuses = genus_counts.head(top_k).index
    df = df[df["genus"].isin(top_genuses)]

    # Project the precomputed DNA embeddings to 2D with t-SNE
    tsne = TSNE(
        n_components=2, perplexity=30, learning_rate=200,
        n_iter=1000, random_state=0,
    )
    X = np.stack(df["embeddings"].tolist())
    y = df["genus"].tolist()
    X_tsne = tsne.fit_transform(X)

    # Color points by genus
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)

    fig, ax = plt.subplots()
    ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_encoded, cmap="viridis", alpha=0.7)
    ax.set_title(f"DNA Embedding Space (of {top_k} most common genera)")
    # Reduce unnecessary whitespace
    ax.set_xlim(X_tsne[:, 0].min() - 0.1, X_tsne[:, 0].max() + 0.1)
    fig.canvas.draw()
    return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
with gr.Blocks() as demo:
    # Header section
    gr.Markdown("# DNA Identifier Tool")
    gr.Markdown((
        "Welcome to Lofi Amazon Beats' DNA Identifier Tool. "
        "Please enter a DNA sequence and the coordinates at which its sample "
        "was taken to get started. Click 'I'm feeling lucky' to use a "
        "random sequence."
    ))

    with gr.Row():
        with gr.Column():
            inp_dna = gr.Textbox(label="DNA", placeholder="e.g. AACAATGTA... (min 200 and max 660 characters)")

        with gr.Column():
            with gr.Row():
                inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. -3.009083")
            with gr.Row():
                inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -58.68281")

    with gr.Row():
        btn_defaults = gr.Button("I'm feeling lucky")
        btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])
with gr.Tab("Genus Prediction"): | |
gr.Markdown(""" | |
## Genus prediction | |
A demo of predicting the genus of a DNA sequence using multiple | |
approaches (method dropdown): | |
- **fine_tuned_model**: using our | |
`LofiAmazon/BarcodeBERT-Finetuned-Amazon` which predicts the genus | |
based on the DNA sequence and environmental data. | |
- **cosine**: computes a cosine similarity between the DNA sequence | |
embedding generated by our model and the embeddings of known samples | |
that we precomputed and stored in a Pinecone index. Thie method | |
DOES NOT examine ecological layer data. | |
""") | |
        # gr.Interface(
        #     fn=predict_genus,
        #     inputs=[
        #         gr.Dropdown(choices=["cosine", "fine_tuned_model"], value="fine_tuned_model"),
        #         inp_dna,
        #         inp_lat,
        #         inp_lng,
        #     ],
        #     outputs=["image"],
        #     allow_flagging="never",
        # )

        method_dropdown = gr.Dropdown(choices=["cosine", "fine_tuned_model"], value="fine_tuned_model")
        predict_button = gr.Button("Predict Genus")
        genus_output = gr.Image()

        predict_button.click(
            fn=predict_genus,
            inputs=[method_dropdown, inp_dna, inp_lat, inp_lng],
            outputs=genus_output,
        )
with gr.Tab("DNA Embedding Space Visualizer"): | |
gr.Markdown(""" | |
## DNA Embedding Space Visualizer | |
We show a 2D t-SNE plot of the DNA embeddings of the five most common | |
genera in our dataset. This shows that the DNA Transformer model is | |
learning to cluster similar DNA sequences together. | |
""") | |
        # gr.Interface(
        #     fn=cluster_dna,
        #     inputs=[
        #         gr.Slider(minimum=1, maximum=10, step=1, value=5,
        #                   label="Number of top genera to visualize")
        #     ],
        #     outputs=["image"],
        #     allow_flagging="never",
        # )

        top_k_slider = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of top genera to visualize")
        visualize_button = gr.Button("Visualize Embedding Space")
        visualize_output = gr.Image()

        visualize_button.click(
            fn=cluster_dna,
            inputs=top_k_slider,
            outputs=visualize_output,
        )

demo.launch()