akdeniz27 commited on
Commit
c8f4a50
1 Parent(s): 348dfbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -18
app.py CHANGED
@@ -1,8 +1,8 @@
1
- # Turkish Zero-Shot Text Classification with XLM-RoBERTa
2
 
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 
4
  import sentencepiece
5
- import torch
6
  import plotly.graph_objects as go
7
  import streamlit as st
8
 
@@ -28,20 +28,13 @@ label_list_2 = ["positive", "negative", "neutral"]
28
  st.title("Turkish Zero-Shot Text Classification \
29
  with Multilingual XLM-RoBERTa and mDeBERTa Models")
30
 
31
- model_list = ['vicgalle/xlm-roberta-large-xnli-anli',
32
- 'joeddav/xlm-roberta-large-xnli',
33
- 'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7']
34
 
35
- st.sidebar.header("Select Model")
36
- model_checkpoint = st.sidebar.radio("", model_list)
37
 
38
- st.sidebar.write("For details of models:")
39
- st.sidebar.write("https://huggingface.co/vicgalle")
40
- st.sidebar.write("https://huggingface.co/joeddav")
41
- st.sidebar.write("https://huggingface.co/MoritzLaurer")
42
-
43
- st.sidebar.write("For XNLI Dataset:")
44
- st.sidebar.write("https://huggingface.co/datasets/xnli")
45
 
46
  st.subheader("Select Text and Label List")
47
  st.text_area("Text #1", text_1, height=128)
@@ -63,9 +56,10 @@ elif labels == "New Label List":
63
  selected_labels = st.text_area("New Label List (Pls Input as comma-separated)", value="", height=16).split(",")
64
 
65
  @st.cache(allow_output_mutation=True)
66
- def setModel(model_checkpoint):
67
- model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
68
- tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
 
69
  return pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)
70
 
71
  Run_Button = st.button("Run", key=None)
 
1
+ # Zero-Shot Text Classification with Multilingual T5 (mT5)
2
 
3
+ from torch.nn.functional import softmax
4
+ from transformers import MT5ForConditionalGeneration, MT5Tokenizer
5
  import sentencepiece
 
6
  import plotly.graph_objects as go
7
  import streamlit as st
8
 
 
28
  st.title("Turkish Zero-Shot Text Classification \
29
  with Multilingual XLM-RoBERTa and mDeBERTa Models")
30
 
31
+ model_name = "alan-turing-institute/mt5-large-finetuned-mnli-xtreme-xnli"
 
 
32
 
33
+ st.sidebar.write("For details of used model:")
34
+ st.sidebar.write("https://huggingface.co/alan-turing-institute/mt5-large-finetuned-mnli-xtreme-xnli")
35
 
36
+ st.sidebar.write("For Xtreme XNLI Dataset:")
37
+ st.sidebar.write("https://www.tensorflow.org/datasets/catalog/xtreme_xnli")
 
 
 
 
 
38
 
39
  st.subheader("Select Text and Label List")
40
  st.text_area("Text #1", text_1, height=128)
 
56
  selected_labels = st.text_area("New Label List (Pls Input as comma-separated)", value="", height=16).split(",")
57
 
58
  @st.cache(allow_output_mutation=True)
59
+ def setModel(model_name):
60
+ tokenizer = MT5Tokenizer.from_pretrained(model_name)
61
+ model = MT5ForConditionalGeneration.from_pretrained(model_name)
62
+ model.eval()
63
  return pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)
64
 
65
  Run_Button = st.button("Run", key=None)