Spaces:
Runtime error
Runtime error
import re | |
import PIL.Image | |
import pandas as pd | |
import numpy as np | |
import gradio as gr | |
from datasets import load_dataset | |
import infer | |
import matplotlib.pyplot as plt | |
from sklearn.manifold import TSNE | |
from sklearn.preprocessing import LabelEncoder | |
import torch | |
from torch import nn | |
from transformers import BertConfig, BertForMaskedLM, PreTrainedTokenizerFast | |
from huggingface_hub import PyTorchModelHubMixin | |
from config import DEFAULT_INPUTS, MODELS, DATASETS | |
# This is an inference-only app: switch autograd off process-wide so no
# computation graphs are recorded.
torch.set_grad_enabled(False)
# The eco-layer rasters are larger than PIL's default pixel limit, so the
# limit is removed. NOTE(review): this disables PIL's decompression-bomb
# check for every image opened in this process; only safe for trusted files.
PIL.Image.MAX_IMAGE_PIXELS = None
# Load models | |
class DNASeqClassifier(nn.Module, PyTorchModelHubMixin):
    """Classify DNA sequences from a BERT embedding plus environmental features.

    The encoder's last hidden layer is mean-pooled over dim 1. The result is
    concatenated with ``env_data`` and fed to a single linear layer that
    produces the class logits.

    Args:
        bert_model: Hugging Face BERT-style model used as the sequence
            encoder. Its ``config.hidden_size`` sizes the linear head; if the
            model has no config, 768 is assumed.
        env_dim: Number of environmental features appended to the embedding.
        num_classes: Number of output logits.
    """

    def __init__(self, bert_model, env_dim, num_classes):
        super().__init__()
        self.bert = bert_model
        self.env_dim = env_dim
        self.num_classes = num_classes
        # Size the head from the encoder rather than hard-coding 768, so BERT
        # variants with another hidden size also work. 768 is kept as the
        # fallback, and it is the BertConfig default used by this app.
        hidden_size = getattr(getattr(bert_model, "config", None), "hidden_size", 768)
        self.fc = nn.Linear(hidden_size + env_dim, num_classes)

    def forward(self, bert_inputs, env_data):
        """Compute class logits.

        Args:
            bert_inputs: Mapping of keyword arguments for the encoder, such
                as tokenizer output.
            env_data: Environmental feature tensor. It is concatenated to the
                pooled embedding along dim 1, so it is presumably
                (batch, env_dim). TODO: confirm with callers.

        Returns:
            Logits from the linear head, with last dimension ``num_classes``.
        """
        # Ask for hidden states explicitly instead of relying on the encoder's
        # config having output_hidden_states=True. Merging into a new dict
        # avoids a duplicate-keyword error if the caller already set it.
        outputs = self.bert(**{**bert_inputs, "output_hidden_states": True})
        # Mean-pool the final layer's token embeddings.
        dna_embeddings = outputs.hidden_states[-1].mean(1)
        combined = torch.cat((dna_embeddings, env_data), dim=1)
        return self.fc(combined)
# Tokenizer and masked-LM encoder that preprocess() uses to embed user DNA.
# Both come from the same Hub repository.
tokenizer = PreTrainedTokenizerFast.from_pretrained(MODELS["embeddings"])
embeddings_model = BertForMaskedLM.from_pretrained(MODELS["embeddings"])
embeddings_model.eval()

# DNA + environment classifier. A BERT skeleton (259-token vocabulary, hidden
# states exposed for pooling) is passed in as the encoder. from_pretrained
# presumably loads the checkpoint weights over it; verify with the Hub mixin.
classification_model = DNASeqClassifier.from_pretrained(
    MODELS["classification"],
    bert_model=BertForMaskedLM(
        BertConfig(output_hidden_states=True, vocab_size=259),
    ),
)
classification_model.eval()

# Eco-layer dataset (only referenced by commented-out code in this file).
ecolayers_ds = load_dataset(DATASETS["ecolayers"])
def set_default_inputs():
    """Return the configured example inputs as (dna_sequence, latitude, longitude).

    The "I'm feeling lucky" button uses this to fill the three input boxes.
    """
    dna = DEFAULT_INPUTS["dna_sequence"]
    lat = DEFAULT_INPUTS["latitude"]
    lng = DEFAULT_INPUTS["longitude"]
    return dna, lat, lng
def preprocess(dna_sequence: str, latitude: str, longitude: str):
    """
    Prepare the app inputs for downstream tasks.

    The DNA sequence goes through these steps:
      1. Whitespace is removed and the text is upper-cased.
      2. Every character other than A/C/G/T becomes ``N``.
      3. The trailing run of ``N`` is removed.
      4. The sequence is truncated to 660 bases.
      5. It is split into space-separated 4-character chunks.
      6. It is embedded with ``embeddings_model``, mean-pooling the last
         hidden layer over the token axis.

    Args:
        dna_sequence: Raw nucleotide string entered by the user.
        latitude: Latitude as text; must parse with ``float``.
        longitude: Longitude as text; must parse with ``float``.

    Returns:
        A ``(dna_embedding, coords)`` pair:
          - ``dna_embedding``: the pooled embedding tensor with size-1
            dimensions squeezed out.
          - ``coords``: ``(latitude, longitude)`` as floats.

    Raises:
        ValueError: If the sequence contains no A/C/G/T base, or if a
            coordinate cannot be parsed as a float.
    """
    # Remove all whitespace (pasted sequences often contain line breaks) and
    # normalise case before the alphabet check. Otherwise valid bases that are
    # lower-case or line-wrapped would be turned into N.
    dna_seq_preprocessed = "".join(dna_sequence.split()).upper()
    dna_seq_preprocessed = re.sub(r"[^ACGT]", "N", dna_seq_preprocessed)
    # Strip the trailing N run from the *cleaned* sequence. Stripping the raw
    # input here would discard the substitution above.
    dna_seq_preprocessed = re.sub(r"N+$", "", dna_seq_preprocessed)
    if not dna_seq_preprocessed:
        raise ValueError("The DNA sequence must contain at least one A, C, G or T base.")
    # Cap at 660 bases, the maximum advertised in the DNA input placeholder.
    dna_seq_preprocessed = dna_seq_preprocessed[:660]
    # Space-separated 4-character chunks: the text format given to the tokenizer.
    dna_seq_preprocessed = " ".join(
        dna_seq_preprocessed[i:i + 4]
        for i in range(0, len(dna_seq_preprocessed), 4)
    )
    model_inputs = tokenizer(dna_seq_preprocessed, return_tensors="pt")
    # Request hidden states explicitly rather than relying on the checkpoint's
    # config enabling them; otherwise ``hidden_states`` can be None.
    dna_embedding: torch.Tensor = embeddings_model(
        **model_inputs, output_hidden_states=True
    ).hidden_states[-1].mean(1).squeeze()

    # Preprocess the location data
    coords = (float(latitude), float(longitude))
    return dna_embedding, coords
# ecolayer_data = ecolayers_ds # TODO something something... | |
# # format lat and lon into coords | |
# coords = (inp_lat, inp_lng) | |
# # Grab rasters from the tifs | |
# ecoLayers = load_dataset("LofiAmazon/Global-Ecolayers") | |
# temp = pd.DataFrame([coords, embed], columns = ['coord', 'embeddings']) | |
# data = pd.merge(temp, ecoLayers, on='coord', how='left') | |
# return data | |
# def predict_genus(): | |
# data = preprocess() | |
# out = infer.infer_dna(data) | |
# results = [] | |
# genuses = infer.infer() | |
# results.append({ | |
# "sequence": dna_df['nucraw'], | |
# # "predictions": pd.concat([dna_genuses, envdna_genuses], axis=0) | |
# 'predictions': genuses}) | |
# return results | |
# def tsne_DNA(data, genuses): | |
# data["embeddings"] = data["embeddings"].apply(lambda x: np.array(list(map(float, x[1:-1].split())))) | |
# # Pick genuses with most samples | |
# top_k = 5 | |
# genus_counts = df["genus"].value_counts() | |
# top_genuses = genus_counts.head(top_k).index | |
# df = df[df["genus"].isin(top_genuses)] | |
# # Create a t-SNE plot of the embeddings | |
# n_genus = len(df["genus"].unique()) | |
# tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, n_iter=1000, random_state=0) | |
# X = np.stack(df["embeddings"].tolist()) | |
# y = df["genus"].tolist() | |
# X_tsne = tsne.fit_transform(X) | |
# label_encoder = LabelEncoder() | |
# y_encoded = label_encoder.fit_transform(y) | |
# plot = plt.figure(figsize=(6, 5)) | |
# scatter = plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_encoded, cmap="viridis", alpha=0.7) | |
# return plot | |
# Gradio UI: the input form, action buttons, and two result tabs. Several
# outputs are still placeholders (see the commented-out handlers below).
with gr.Blocks() as demo:
    # Header section
    gr.Markdown("# DNA Identifier Tool")
    gr.Markdown((
        "Welcome to Lofi Amazon Beats' DNA Identifier Tool. "
        "Please enter a DNA sequence and the coordinates at which its sample "
        "was taken to get started. Click 'I'm feeling lucky' to use an "
        "example sequence."
    ))

    # Inputs: DNA sequence on the left, coordinates stacked on the right.
    with gr.Row():
        with gr.Column():
            inp_dna = gr.Textbox(label="DNA", placeholder="e.g. AACAATGTA... (min 200 and max 660 characters)")
        with gr.Column():
            with gr.Row():
                inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. -3.009083")
            with gr.Row():
                inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -58.68281")

    # Actions
    with gr.Row():
        btn_run = gr.Button("Predict")
        # NOTE(review): no outputs are wired here, so the (embedding, coords)
        # returned by preprocess is discarded. This is presumably a
        # placeholder until the genus-prediction handler below is enabled.
        btn_run.click(fn=preprocess, inputs=[inp_dna, inp_lat, inp_lng])
        btn_defaults = gr.Button("I'm feeling lucky")
        btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])

    with gr.Tab("Genus Prediction"):
        with gr.Row():
            gr.Markdown("Make plot or table for Top 5 species")
        with gr.Row():
            genus_out = gr.Dataframe(headers=["DNA Only Pred Genus", "DNA Only Prob", "DNA & Env Pred Genus", "DNA & Env Prob"])
        # btn_run.click(fn=predict_genus, inputs=[inp_dna, inp_lat, inp_lng], outputs=genus_out)

    with gr.Tab('DNA Embedding Space Visualizer'):
        gr.Markdown("If the highest genus probability is very low for your DNA sequence, we can still examine the DNA embedding of the sequence in relation to known samples for clues.")
        with gr.Row():
            with gr.Column():
                gr.Markdown("Plot of your DNA sequence among other known species clusters.")
                # plot = gr.Plot("")
                # btn_run.click(fn=tsne_DNA, inputs=[inp_dna, genus_out])
            with gr.Column():
                gr.Markdown("Plot of the five most common species at your sample coordinate.")

demo.launch()