# NOTE: web file-viewer residue (size banner, commit-hash gutter, line-number gutter) was removed from the top of this file.
# AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.

# %% auto 0
# Public names of this module, listed in the order they are defined below.
__all__ = [
    # model handle and data tables
    "repo_id",
    "learner",
    "path",
    "countries",
    "categories",
    # interface text
    "title",
    "description",
    "article",
    # gradio components and the assembled interface
    "image",
    "label",
    "country",
    "summary",
    "link",
    "examples",
    "intf",
    # functions
    "get_countries",
    "classify_image",
]

# %% app.ipynb 3
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai
import gradio as gr
import wikipedia
import pandas as pd

# %% app.ipynb 4
# Hugging Face Hub repository that hosts the trained model.
repo_id = "Jimmie/snake-species-identification"

# Fetch the fastai learner from that repository.
learner = from_pretrained_fastai(repo_id)

# %% app.ipynb 5
# Directory whose entries are listed further down as the interface's examples.
path = Path("demo-images/")
# Species-to-country table; the first CSV column becomes the row index that
# get_countries uses to look a species up.
countries = pd.read_csv("species_to_country_mapping.csv", index_col=0)

# %% app.ipynb 9
def get_countries(binomial):
    """Return the countries recorded for a species as one comma-separated string.

    Args:
        binomial: Row label of the species in the module-level ``countries``
            table.

    Returns:
        The labels of every column whose value in that row equals 1, each
        title-cased and joined with ", ". Empty string when no column matches.

    Raises:
        KeyError: If ``binomial`` is not a row label of ``countries``.
    """
    row = countries.loc[binomial]
    # Keep only the columns flagged with 1 for this species.
    flagged = row[row == 1]
    return ", ".join(name.title() for name in flagged.index)

# %% app.ipynb 20
# Class labels from the learner's vocabulary; classify_image pairs them
# positionally with the probabilities returned by learner.predict.
categories = tuple(learner.dls.vocab)

def classify_image(img):
    """Classify a snake image and collect details about the predicted species.

    The Wikipedia enrichment is best-effort: if the ``wikipedia`` package
    reports an error for the predicted name (missing page, disambiguation,
    redirect or timeout), the classification is still returned, with a
    placeholder summary and an empty link.

    Args:
        img: Image passed unchanged to ``learner.predict``.

    Returns:
        A 4-tuple ``(confidences, countries, summary, wiki_link)``:
        ``confidences`` maps every entry of ``categories`` to a float
        probability, ``countries`` is the string from ``get_countries``,
        ``summary`` is the Wikipedia summary text (or a placeholder), and
        ``wiki_link`` is an HTML anchor to the Wikipedia page (or "").
    """
    # Function-scope import so the block does not depend on the file header.
    from html import escape

    pred, _, probs = learner.predict(img)
    confidences = dict(zip(categories, map(float, probs)))
    # Named `found_in` so the module-level `countries` table is not shadowed.
    found_in = get_countries(pred)

    try:
        wiki_summary = wikipedia.summary(pred)
        wiki_url = wikipedia.page(pred).url
    except wikipedia.exceptions.WikipediaException:
        # A failed lookup must not discard the prediction itself.
        wiki_summary = f"No Wikipedia summary could be retrieved for {pred}."
        wiki_link = ""
    else:
        # Quote the href and escape both values so characters such as
        # apostrophes or ampersands cannot break the generated markup.
        wiki_link = (
            f'Learn more: <a href="{escape(wiki_url, quote=True)}" '
            f'target="_blank">{escape(str(pred))}</a>'
        )

    return confidences, found_in, wiki_summary, wiki_link

# %% app.ipynb 22
# Heading shown at the top of the interface.
title = "Snake Species Identification"

# Introductory text for the interface. The empty first and last entries
# reproduce the newline right after the opening and right before the closing
# of the text block.
description = "\n".join(
    [
        "",
        "This demo is an ongoing iteration of the [Snake Species Identification](https://github.com/jimmiemunyi/the-snake-project-cls) project meant to classify snakes up to the species level (binomial name).",
        "",
        "Currently, it can classify snakes into 50 categories but it is continually updated to support more categories (over 200).",
        "",
        "The model can be found here: https://huggingface.co/Jimmie/snake-species-identification.",
        "The model is trained on the following dataset: https://www.aicrowd.com/challenges/snakeclef2021-snake-species-identification-challenge.",
        "",
        "Enjoy!",
        "",
    ]
)

# Footer text passed to the interface as its article.
article = "Blog posts on how the model is being trained: COMING SOON!"


# Input component handed to classify_image.
image = gr.Image(shape=(224, 224))

# Output components, in the same order as the values classify_image returns.
label = gr.Label(num_top_classes=3, label="Binomial")
country = gr.Textbox(label="Countries where the species is found")
summary = gr.Textbox(label="Wikipedia Summary")
link = gr.HTML(label="Learn More:", show_label=True)

# Entries of the demo-images path, offered as ready-made example inputs.
examples = list(path.ls())


# Wire the classifier to its components and start serving the interface.
intf = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=[label, country, summary, link],
    examples=examples,
    title=title,
    description=description,
    article=article,
    cache_examples=False,
)
intf.launch(inline=False)