File size: 8,733 Bytes
a5316e5
 
d563836
 
544f914
a5316e5
5f4434d
8da738a
7145ecb
544f914
 
 
6176ef8
f37b5da
d563836
c8ce48e
d563836
 
 
 
 
 
6176ef8
 
 
2f21a7c
360e1ea
6176ef8
 
 
ccf126e
bb9c09b
ccf126e
 
 
 
 
 
 
 
 
 
 
2f21a7c
acd0907
ccf126e
 
 
b1a0d53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c82104c
b1a0d53
6f59e3c
aa09a05
60270f6
8da738a
47d0212
 
dae2ae1
47d0212
e728893
dae2ae1
47d0212
e728893
47d0212
8da738a
3fcba4f
6176ef8
4e59324
b1a0d53
aa09a05
24390e2
4e59324
 
 
 
 
3218b1a
4e59324
24390e2
4e59324
 
 
 
 
 
 
 
 
 
 
aa09a05
4e59324
 
aa09a05
4e59324
 
 
3218b1a
4e59324
 
 
3218b1a
 
 
a423984
 
 
 
 
3218b1a
4e59324
92d90c6
dae2ae1
47d0212
dae2ae1
47d0212
 
 
621b193
6176ef8
 
5282aca
7145ecb
1e09a50
f30d0ea
5282aca
d3e6e3b
b8df8bd
53e71ae
142304a
47d0212
53e71ae
b5d9907
8e84211
142304a
d3c40d6
763f5b7
53e71ae
5282aca
d3e6e3b
5282aca
15ecc3d
 
8effe15
3cb6c3b
b5d9907
8e84211
142304a
d3c40d6
763f5b7
a5316e5
736419a
4e59324
a5316e5
8954378
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import gradio as gr
from transformers import pipeline
import requests
from bs4 import BeautifulSoup
import pandas as pd

# Local fine-tuned checkpoints; each directory is used for both the model
# weights and the tokenizer.
_CLASSIFICATION_DIR = "models/text_classification_model"
_FILL_MASK_DIR = "models/fill_mask_model"

# Habitat classifier, configured to return the 5 highest-scoring labels.
classification_model = pipeline(
    "text-classification",
    model=_CLASSIFICATION_DIR,
    tokenizer=_CLASSIFICATION_DIR,
    top_k=5,
)

# Fill-mask model used to propose missing species; 100 candidates are kept so
# that species already present in a plot can be skipped.
mask_model = pipeline(
    "fill-mask",
    model=_FILL_MASK_DIR,
    tokenizer=_FILL_MASK_DIR,
    top_k=100,
)

# EUNIS habitat table (read once at start-up), used to map codes to names.
eunis_habitats = pd.read_excel('data/eunis_habitats.xlsx')
    
def return_habitat_image(habitat_label):
    """Fetch a thumbnail for a EUNIS habitat type from FloraVeg.EU.

    The FloraVeg.EU habitat overview page for ``habitat_label`` is scraped and
    the first ``<img>`` whose ``src`` lies in the syntaxa thumbnail folder is
    used. A generic "image not found" picture is shown instead when the page
    cannot be fetched (network error or timeout), does not answer with status
    200, or contains no such image.

    Args:
        habitat_label: Habitat code inserted into the FloraVeg.EU URL
            (e.g. ``"Q51"``).

    Returns:
        A ``gr.Image`` component displaying the selected image URL.
    """
    fallback_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    thumbs_prefix = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"
    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"

    try:
        # Bounded wait: an unresponsive site must not hang the UI request.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # DNS/connection failure or timeout: degrade to the placeholder image.
        return gr.Image(value=fallback_url)
    if response.status_code != 200:
        return gr.Image(value=fallback_url)

    soup = BeautifulSoup(response.text, 'html.parser')
    img_tag = soup.find('img', src=lambda src: src and src.startswith(thumbs_prefix))
    image_url = img_tag['src'] if img_tag else fallback_url
    return gr.Image(value=image_url)

def return_species_image(species):
    """Fetch a photo of a plant species from FloraVeg.EU.

    The name is capitalised (``"eryngium maritimum"`` -> ``"Eryngium maritimum"``)
    to build the FloraVeg.EU taxon overview URL, and the first ``<img>`` whose
    ``src`` lies in the taxa ``large`` image folder is used. A generic
    "image not found" picture is shown instead when the page cannot be
    fetched (network error or timeout), does not answer with status 200, or
    contains no such image.

    Args:
        species: Species name in any letter case.

    Returns:
        A ``gr.Image`` component displaying the selected image URL.
    """
    fallback_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    large_prefix = "https://files.ibot.cas.cz/cevs/images/taxa/large/"
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species.capitalize()}"

    try:
        # Bounded wait: an unresponsive site must not hang the UI request.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # DNS/connection failure or timeout: degrade to the placeholder image.
        return gr.Image(value=fallback_url)
    if response.status_code != 200:
        return gr.Image(value=fallback_url)

    soup = BeautifulSoup(response.text, 'html.parser')
    img_tag = soup.find('img', src=lambda src: src and src.startswith(large_prefix))
    image_url = img_tag['src'] if img_tag else fallback_url
    return gr.Image(value=image_url)

def gbif_normalization(text):
    """Normalise a comma-separated list of plant names with the GBIF backbone.

    Each non-empty name is looked up with the GBIF ``species/match`` API.
    When the JSON reply contains a ``species`` field, that name replaces the
    input; otherwise (no match, network error, timeout or non-JSON reply) the
    stripped input name is kept. Empty entries, e.g. from a trailing or
    doubled comma, are dropped.

    Args:
        text: Species names separated by commas.

    Returns:
        The normalised names joined with ``", "`` and lower-cased.
    """
    match_url = "https://api.gbif.org/v1/species/match"
    names = [name.strip() for name in text.split(',')]

    normalized = []
    for name in names:
        if not name:
            continue
        try:
            # ``params`` lets requests URL-encode names (spaces, '&', accents...).
            response = requests.get(match_url, params={"name": name}, timeout=10)
            payload = response.json()
        except (requests.RequestException, ValueError):
            # Best effort: keep the user's spelling if GBIF is unreachable
            # or answers with something that is not JSON.
            payload = {}
        if isinstance(payload, dict) and 'species' in payload:
            normalized.append(payload['species'])
        else:
            normalized.append(name)

    return ", ".join(normalized).lower()

def classification(text, k):
    """Predict the most likely EUNIS habitat types for a vegetation plot.

    The species list is GBIF-normalised, classified, and the best ``k`` habitat
    codes are reported. The name of the top code is looked up in the EUNIS
    table, and an image of that habitat is fetched.

    Args:
        text: Comma-separated species names.
        k: Number of habitat types to report. It is coerced to ``int`` (at
            least 1) because slider widgets may deliver it as a float.

    Returns:
        ``(message, image)``, where ``message`` lists the predicted habitat
        codes and ``image`` shows the most likely one. The naming sentence is
        omitted when the top code is absent from the EUNIS table.
    """
    k = max(1, int(k))
    text = gbif_normalization(text)
    result = classification_model(text)
    # result[0] holds the label/score dicts for the single input text.
    habitat_labels = [res['label'] for res in result[0][:k]]
    best_label = habitat_labels[0]

    names = eunis_habitats.loc[
        eunis_habitats['EUNIS 2020 code'] == best_label, 'EUNIS-2021 habitat name'
    ]
    habitat_name = names.iloc[0] if not names.empty else None

    # Phrase according to how many labels were actually returned.
    n_labels = len(habitat_labels)
    if n_labels == 1:
        message = f"This vegetation plot probably belongs to the habitat type {best_label}."
        name_subject = "This habitat type"
    else:
        # Two items: "A or B"; three or more: "A, B, or C" (Oxford comma).
        separator = " or " if n_labels == 2 else ", or "
        listed = ', '.join(habitat_labels[:-1]) + separator + habitat_labels[-1]
        message = f"This vegetation plot probably belongs to the habitat type {listed}."
        name_subject = f"The most likely habitat type (i.e., {best_label})"

    if habitat_name is not None:
        message += f"\n{name_subject} is named '{habitat_name}'."
    message += "\nSee an image of this habitat type below."

    image_output = return_habitat_image(best_label)
    return message, image_output

def _first_unseen_prediction(masked_text, excluded):
    """Return the first fill-mask candidate whose token is not in *excluded*, or None."""
    # The model runs once per masked text; candidates are examined in the
    # order the pipeline returns them.
    for prediction in mask_model(masked_text):
        if prediction['token_str'] not in excluded:
            return prediction
    return None


def _join_enumeration(items):
    """Join *items* as 'a', 'a and b', or 'a, b, and c'."""
    words = [str(item) for item in items]
    if len(words) == 1:
        return words[0]
    if len(words) == 2:
        return f"{words[0]} and {words[1]}"
    return f"{', '.join(words[:-1])}, and {words[-1]}"


def masking(text, k):
    """Greedily propose up to ``k`` species missing from a vegetation plot.

    The species list is GBIF-normalised. In each of ``k`` rounds, a ``[MASK]``
    token is tried at every gap of the current list (before, between and
    after the species). At each gap, the model's first candidate that is not
    already in the list is scored. The highest-scoring (species, gap) pair is
    inserted into the list before the next round.

    Args:
        text: Comma-separated species names.
        k: Number of species to find. It is coerced to ``int`` (at least 1)
            because slider widgets may deliver it as a float.

    Returns:
        ``(message, image)``. ``message`` names the predicted species, their
        0-based positions in the completed list and the completed list itself.
        ``image`` shows the first predicted species. If no new candidate can
        be found at all, a short notice and ``None`` are returned.
    """
    k = max(1, int(k))
    text = gbif_normalization(text)
    text_split = text.split(', ')

    best_predictions = []
    best_sentence = None

    for _ in range(k):
        # Everything already in the plot (including earlier picks) is excluded.
        excluded = set(text_split)
        max_score = 0
        best_prediction = None
        best_position = None
        round_sentence = None

        for i in range(len(text_split) + 1):
            masked_text = ', '.join(text_split[:i] + ['[MASK]'] + text_split[i:])
            prediction = _first_unseen_prediction(masked_text, excluded)
            if prediction is not None and prediction['score'] > max_score:
                max_score = prediction['score']
                best_prediction = prediction['token_str']
                best_position = i
                round_sentence = prediction['sequence']

        if best_prediction is None:
            # No new species proposed anywhere: stop early instead of crashing.
            break
        best_predictions.append(best_prediction)
        text_split.insert(best_position, best_prediction)
        best_sentence = round_sentence

    if not best_predictions:
        return "No missing species could be found for this vegetation plot.", None

    best_positions = [text_split.index(prediction) for prediction in best_predictions]

    best_sentence = ", ".join(
        [s.strip().capitalize() for s in best_sentence.split(",")]
    )

    species_text = _join_enumeration([p.capitalize() for p in best_predictions])
    positions_text = _join_enumeration(best_positions)
    if len(best_predictions) == 1:
        message = f"The most likely missing species is {species_text} (position {positions_text})."
    else:
        message = f"The most likely missing species are {species_text} (positions {positions_text})."
    message += f"\nThe completed vegetation plot is thus '{best_sentence}'."
    message += f"\nSee an image of this species (i.e., {best_predictions[0].capitalize()}) below."

    image = return_species_image(best_predictions[0])
    return message, image

# Two-tab Gradio interface: habitat classification of a vegetation plot, and
# retrieval of missing species. Examples are passed with explicit keywords so
# the meaning of each argument (notably ``cache_examples=True``) is visible.
with gr.Blocks() as demo:

    gr.Markdown("""<h1 style="text-align: center;">Pl@ntBERT</h1>""")

    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""<h3 style="text-align: center;">Habitat identification of vegetation plots!</h3>""")
        with gr.Row():
            with gr.Column():
                species_classification = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_classification = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of habitat types to display.")
            with gr.Column():
                text_classification = gr.Textbox(label="Prediction")
                image_classification = gr.Image()
        button_classification = gr.Button("Classify")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[["phragmites australis, lemna minor, typha latifolia", 3]],
            inputs=[species_classification, k_classification],
            outputs=[text_classification, image_classification],
            fn=classification,
            cache_examples=True,
        )

    with gr.Tab("Missing species finding"):
        gr.Markdown("""<h3 style="text-align: center;">Missing vascular plant species retrieval!</h3>""")
        with gr.Row():
            with gr.Column():
                species_masking = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_masking = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of missing species to find.")
            with gr.Column():
                text_masking = gr.Textbox(label="Prediction")
                image_masking = gr.Image()
        button_masking = gr.Button("Find")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[["calamagrostis arenaria, medicago marina, pancratium maritimum, thinopyrum junceum", 1]],
            inputs=[species_masking, k_masking],
            outputs=[text_masking, image_masking],
            fn=masking,
            cache_examples=True,
        )

    button_classification.click(classification, inputs=[species_classification, k_classification], outputs=[text_classification, image_classification])
    button_masking.click(masking, inputs=[species_masking, k_masking], outputs=[text_masking, image_masking])

# Launch only when run as a script, so importing this module (tests, tooling)
# does not start a server.
if __name__ == "__main__":
    demo.launch()