nsorros commited on
Commit
cd56275
1 Parent(s): ae2a255

Add streamlit app

Browse files
Files changed (2) hide show
  1. .gitignore +2 -0
  2. app.py +35 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
1
+ # Venv
2
+ .env
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoTokenizer
2
+ import streamlit as st
3
+
4
+
5
+ st.header("MeshTagger 🔖")
6
+ threshold = st.sidebar.slider("Threshold", value=0.5, min_value=0.0, max_value=1.0)
7
+ display_probabilities = st.sidebar.checkbox("Display probabilities")
8
+
9
+ if "model" not in st.session_state:
10
+ with st.spinner("Loading model and tokenizer..."):
11
+ st.session_state["tokenizer"] = AutoTokenizer.from_pretrained(
12
+ "Wellcome/WellcomeBertMesh"
13
+ )
14
+ st.session_state["model"] = AutoModel.from_pretrained(
15
+ "Wellcome/WellcomeBertMesh", trust_remote_code=True
16
+ )
17
+
18
+ model = st.session_state["model"]
19
+ tokenizer = st.session_state["tokenizer"]
20
+
21
+ text = st.text_area("", value="This text is about Malaria", height=400)
22
+ inputs = tokenizer([text], padding="max_length")
23
+ outputs = model(**inputs)[0]
24
+
25
+ if display_probabilities:
26
+ data = [
27
+ (model.id2label[label_id], label_prob.item())
28
+ for label_id, label_prob in enumerate(outputs)
29
+ if label_prob > threshold
30
+ ]
31
+ st.table(data)
32
+ else:
33
+ for label_id, label_prob in enumerate(outputs):
34
+ if label_prob > threshold:
35
+ st.button(model.id2label[label_id])