File size: 1,383 Bytes
a5316e5
 
7e0319c
a5316e5
 
7e0319c
a5316e5
 
 
 
7e0319c
a5316e5
1c2f25a
 
 
a5316e5
1c2f25a
a5316e5
 
1c2f25a
 
 
 
a5316e5
 
7e0319c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import gradio as gr
from transformers import pipeline
from datasets import load_dataset

# Both resources are loaded once at import time so every request reuses them.
# The pipeline wraps the text-classification checkpoint used for prediction.
classifier = pipeline(
    task="text-classification",
    model="CesarLeblanc/test_model",
)
# The dataset is consulted for its 'label' feature names when decoding
# the pipeline's predicted label inside text_classification().
dataset = load_dataset(path="CesarLeblanc/text_classification_dataset")

def text_classification(text):
    """Classify a species list and describe the predicted habitat.

    Runs the module-level ``classifier`` on ``text``, takes the first
    prediction it returns, and maps that prediction's label to a name from
    the ``label`` feature of the module-level ``dataset``'s train split.

    Args:
        text: Input passed unchanged to the classifier (the UI describes it
            as a comma-separated list of binomial species names).

    Returns:
        A ``(message, image)`` tuple: a sentence stating the habitat name
        and the score as a percentage with two decimals, and a ``gr.Image``
        component pointing at an image file.
    """
    top_prediction = classifier(text)[0]

    # The label is split on '_' and its second piece is parsed as an integer
    # index into the dataset's label names (e.g. a "<prefix>_<index>" form).
    class_index = int(top_prediction['label'].split('_')[1])
    habitat_name = dataset['train'].features['label'].names[class_index]

    confidence = top_prediction['score']
    message = f"This vegetation plot belongs to the habitat {habitat_name} with the probability {confidence*100:.2f}%"

    # NOTE(review): the image path is fixed and does not depend on the
    # prediction — confirm whether this is intentional.
    habitat_image = gr.Image(value='N13.jpg', type="filepath")
    return message, habitat_image

# Sample inputs passed to the interface as clickable examples.
examples = [
    "sparganium erectum, calystegia sepium, persicaria amphibia",
    "thinopyrum junceum, cakile maritima",
]

# Input component: free-text species list.
_species_box = gr.Textbox(
    lines=2,
    label="List of comma-separated binomial names of species (see examples)",
    placeholder="Enter species here...",
)
# Output components: the formatted sentence plus an image slot, matching the
# two values returned by text_classification().
_result_box = gr.Textbox(lines=2, label="Vegetation Plot Classification Result")

io = gr.Interface(
    fn=text_classification,
    inputs=_species_box,
    outputs=[_result_box, "image"],
    title="Pl@ntBERT",
    description="Vegetation Plot Classification: enter the species found in a vegetation plot and see its EUNIS habitat!",
    examples=examples,
)

# Start the web app as soon as the script is executed.
io.launch()