# Hugging Face Space app source (Space status: Running; file size: 2,627 bytes).
import pickle
import gradio as gr
from datasets import load_dataset
from transformers import AutoModel, AutoFeatureExtractor
import wikipedia
# Module-level setup: everything below runs once, when the script starts.

# Nearest-neighbour index over the candidate embeddings.
# NOTE(review): pickle.load can execute arbitrary code from the file. This is
# presumably a trusted artifact bundled with the Space; never load a
# user-supplied pickle here.
with open("index_768_cosine.pickle", "rb") as handle:
    index = pickle.load(handle)

# The feature extractor and the embedding model come from one checkpoint.
_CHECKPOINT = "sasha/autotrain-butterfly-similarity-2490576840"
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModel.from_pretrained(_CHECKPOINT)

# Candidate images: query() selects its matches from the "train" split.
dataset = load_dataset("sasha/butterflies_10k_names_multiple")
ds = dataset["train"]
def _describe_butterfly(name):
    """Return a Markdown description of ``name`` for the description panel.

    The Wikipedia lookup is best-effort. If ``name`` is None, or any error
    occurs while fetching the summary or the page URL, a generic butterfly
    description linking to the "Butterfly" article is returned instead.
    """
    url = "https://en.wikipedia.org/wiki/Butterfly"
    fallback = (
        "### Butterflies are insects in the order Lepidoptera, which also includes moths. Adult butterflies have large, often brightly coloured wings."
        + " You can learn more about butterflies [here](" + str(url) + ")!"
    )
    if name is None:
        return fallback
    try:
        summary = wikipedia.summary(name, sentences=1)
        url = wikipedia.page(name).url
    except Exception:
        # Deliberate best-effort boundary: missing pages, disambiguation
        # errors and network failures all fall back to the generic text.
        # (The previous bare ``except:`` also swallowed KeyboardInterrupt.)
        return fallback
    return (
        "### " + summary
        + " You can learn more about your butterfly [here](" + str(url) + ")!"
    )


def query(image, top_k=4):
    """Find the dataset butterflies most similar to ``image``.

    Args:
        image: The image from the UI, passed directly to ``feature_extractor``.
            ``None`` (nothing uploaded) yields empty results plus a prompt.
        top_k: Number of nearest neighbours to request from ``index``.

    Returns:
        A 3-tuple matching the UI outputs (gallery, label, description):
        - a list of ``(image, name)`` pairs for the retrieved dataset rows,
        - a dict mapping each retrieved name to ``1 - value``, where ``value``
          comes from ``results[1]`` of the index query. This presumably
          converts cosine distance to similarity (TODO confirm). When a name
          occurs more than once, its highest score is kept.
        - a Markdown description of the top match.
    """
    if image is None:
        return [], None, "### Please upload an image of a butterfly first."

    # Local import: torch is needed only to disable autograd (the file already
    # relies on it via return_tensors="pt").
    import torch

    inputs = feature_extractor(image, return_tensors="pt")
    # Inference only: skip building an autograd graph for every request.
    with torch.no_grad():
        embedding = model(**inputs).pooler_output

    # Assumes index.query returns (indices, distances), each batched with
    # one row per query embedding -- TODO confirm against the index library.
    results = index.query(embedding, k=top_k)
    indices = results[0][0].tolist()
    distances = results[1][0].tolist()

    # Select the matching rows once, then read both columns from the subset.
    matches = ds.select(indices)
    images = matches["image"]
    captions = matches["name"]
    images_with_captions = list(zip(images, captions))

    labels_with_probs = {}
    for name, distance in zip(captions, distances):
        similarity = 1 - distance
        # Several neighbours can share a species name; keep the best score
        # rather than letting a later, weaker match overwrite it.
        if similarity > labels_with_probs.get(name, float("-inf")):
            labels_with_probs[name] = similarity

    top_name = captions[0] if captions else None
    return images_with_captions, labels_with_probs, _describe_butterfly(top_name)
# Gradio UI: image input and search button on the left; matching gallery and
# per-species scores on the right; a Markdown description of the top match.
with gr.Blocks() as demo:
    gr.Markdown("# Find my Butterfly 🦋")
    gr.Markdown("## Use this Space to find your butterfly, based on the [iNaturalist butterfly dataset](https://huggingface.co/datasets/huggan/inat_butterflies_top10k)!")
    with gr.Row():
        with gr.Column(min_width=900):
            inputs = gr.Image(shape=(800, 1600))
            btn = gr.Button("Find my butterfly!")
            description = gr.Markdown()
        with gr.Column():
            outputs = gr.Gallery().style(grid=[2], height="auto")
            labels = gr.Label()
    gr.Markdown("### Image Examples")
    gr.Examples(
        examples=["elton.jpg", "ken.jpg", "gaga.jpg", "taylor.jpg"],
        inputs=inputs,
        # query() returns three values (gallery, labels, description), so all
        # three components are listed, matching btn.click below. With only
        # two, the cached/clicked examples never updated the description.
        outputs=[outputs, labels, description],
        fn=query,
        cache_examples=True,
    )
    btn.click(query, inputs, [outputs, labels, description])
demo.launch()
|