azizbarank committed on
Commit
26ef1d5
1 Parent(s): e149850

Create new file

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import subprocess
import sys

# Runtime dependency bootstrap: install the packages this app imports below.
# pip is invoked as ``<current interpreter> -m pip`` so the packages land in
# the environment that is actually running this script, rather than in
# whichever ``pip`` executable happens to be first on PATH. Arguments are
# passed as a list, so no shell is involved.
_RUNTIME_PACKAGES = ("torch", "transformers", "sentencepiece", "plotly")
for _package in _RUNTIME_PACKAGES:
    # One pip invocation per package, and check=False, so a failure for one
    # package does not stop the remaining ones from being attempted.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", _package],
        check=False,
    )
6
+
7
+
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
9
+ import sentencepiece
10
+ import torch
11
+ import plotly.graph_objects as go
12
+ import streamlit as st
13
+
14
# Built-in French sample texts offered as ready-made classification inputs.
text_1 = (
    "Avec la Ligue 1 qui reprend ses droits à partir de vendredi 5 août, "
    "et un premier match pour ce qui les concerne samedi soir, "
    "à Clermont-Ferrand, l’heure est désormais arrivée pour les Parisiens "
    "d’apporter les preuves que ce changement d’ère est bien une réalité."
)

text_2 = (
    "Créées en 1991 sur un modèle inspiré de la Fête de la musique, "
    "les Nuits des étoiles ont pour thème en 2022 l’exploration spatiale, "
    "en partenariat avec l’Agence spatiale européenne."
)
17
+
18
+ @st.cache(allow_output_mutation=True)
19
def list2text(label_list):
    """Return the labels joined into one comma-separated string.

    Args:
        label_list: Iterable of label strings.

    Returns:
        The labels separated by "," with no trailing comma; the empty
        string when ``label_list`` is empty.
    """
    # str.join replaces the manual "append label + ',' then slice off the
    # last character" loop and yields the same text.
    return ",".join(label_list)
25
+
26
# Preset candidate labels: news topics (list #1) and sentiment (list #2).
label_list_1 = [
    "monde",
    "économie",
    "sciences",
    "culture",
    "santé",
    "politique",
    "sport",
    "technologie",
]
label_list_2 = [
    "positif",
    "négatif",
    "neutre",
]
28
+
29
# Page heading.
st.title("French Zero-Shot Text Classification with CamemBERT and XLM-R")

# Introductory paragraph rendered as Markdown below the heading.
intro_markdown = """

This application makes use of [CamemBERT](https://camembert-model.fr/) and [XLM-R](https://arxiv.org/abs/1911.02116) models that were fine-tuned on the XNLI corpus. While CamemBERT was fine-tuned only on the French part of the corpus by [Baptiste Doyen](https://huggingface.co/BaptisteDoyen), XLM-R was done so on all parts of it by [Joe Davison](https://huggingface.co/joeddav), including French. Therefore, in this app, both of these two models are intended to be used and made comparison of each other for zero-shot classification in French.

"""
st.markdown(intro_markdown)
40
+
41
# Checkpoints the user can choose between in the sidebar.
model_list = [
    "BaptisteDoyen/camembert-base-xnli",
    "joeddav/xlm-roberta-large-xnli",
]

st.sidebar.header("Select Model")
model_checkpoint = st.sidebar.radio("", model_list)

# Reference links, written to the sidebar in display order.
for sidebar_line in (
    "For the full descriptions of the models:",
    "[camembert-base-xnli](https://huggingface.co/BaptisteDoyen/camembert-base-xnli)",
    "[xlm-roberta-large-xnli](https://huggingface.co/joeddav/xlm-roberta-large-xnli)",
    "For the XNLI Dataset:",
    "[XNLI](https://huggingface.co/datasets/xnli)",
):
    st.sidebar.write(sidebar_line)
53
+
54
st.subheader("Select Text and Label List")

# Show the built-in sample texts and preset label lists.
for sample_caption, sample_body in (("Text #1", text_1), ("Text #2", text_2)):
    st.text_area(sample_caption, sample_body, height=128)
st.write(f"Label List #1: {list2text(label_list_1)}")
st.write(f"Label List #2: {list2text(label_list_2)}")

# Let the user pick a preset or supply their own text / labels.
text = st.radio("Select Text", ("Text #1", "Text #2", "New Text"))
labels = st.radio(
    "Select Label List",
    ("Label List #1", "Label List #2", "New Label List"),
)

# Resolve the text choice; the free-form input box is only created when
# "New Text" is selected.
preset_texts = {"Text #1": text_1, "Text #2": text_2}
if text in preset_texts:
    selected_text = preset_texts[text]
elif text == "New Text":
    selected_text = st.text_area("New Text", value="", height=128)

# Resolve the label choice; custom labels are split on commas as typed.
preset_label_lists = {
    "Label List #1": label_list_1,
    "Label List #2": label_list_2,
}
if labels in preset_label_lists:
    selected_labels = preset_label_lists[labels]
elif labels == "New Label List":
    custom_labels = st.text_area(
        "New Label List (Pls Input as comma-separated)",
        value="",
        height=16,
    )
    selected_labels = custom_labels.split(",")
72
+
73
+ @st.cache(allow_output_mutation=True)
74
def setModel(model_checkpoint):
    """Build a zero-shot classification pipeline for ``model_checkpoint``.

    Args:
        model_checkpoint: Checkpoint identifier handed to ``from_pretrained``
            for both the model and the tokenizer.

    Returns:
        The object returned by ``pipeline("zero-shot-classification", ...)``
        wired to the loaded model and tokenizer.
    """
    checkpoint_model = AutoModelForSequenceClassification.from_pretrained(
        model_checkpoint
    )
    checkpoint_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    classifier = pipeline(
        "zero-shot-classification",
        model=checkpoint_model,
        tokenizer=checkpoint_tokenizer,
    )
    return classifier
78
+
79
# Run the selected model on the selected text/labels when the button is
# pressed, and plot the score of each label as a bar chart.
Run_Button = st.button("Run", key=None)
if Run_Button:
    # Custom labels come from a raw comma split, so drop surrounding
    # whitespace and empty entries (e.g. from "a, b," or an empty box).
    candidate_labels = [
        label.strip() for label in selected_labels if label.strip()
    ]

    if not selected_text.strip():
        # The zero-shot pipeline rejects an empty sequence; tell the user
        # instead of surfacing a traceback.
        st.warning("Please enter some text to classify.")
    elif not candidate_labels:
        st.warning("Please provide at least one label.")
    else:
        zstc_pipeline = setModel(model_checkpoint)
        output = zstc_pipeline(
            sequences=selected_text,
            candidate_labels=candidate_labels,
        )
        output_labels = output["labels"]
        output_scores = output["scores"]

        st.header("Result")
        # ``go`` is the plotly.graph_objects module imported at the top of
        # the file.
        fig = go.Figure([go.Bar(x=output_labels, y=output_scores)])
        st.plotly_chart(fig, use_container_width=False, sharing="streamlit")