Arash Mari Oriyad committed on
Commit
975f394
1 Parent(s): a78f9e6

add webdemo app using streamlit

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from logging import PlaceHolder
2
+ from re import sub
3
+ import streamlit as st
4
+ import imp, time, random
5
+ import base64
6
+ import io
7
+ import nbformat
8
+ from PIL import Image
9
+ from datasets import load_from_disk, load_dataset
10
+ import os
11
+ from transformers import pipeline
12
+
13
+
14
# Lay the app out across the full browser width.
st.set_page_config(
    layout="wide",
)
15
+
16
def set_submitted_true():
    """Set the ``submitted`` flag in the Streamlit session state to ``True``.

    The demo form reads this flag next to its "Get Answer" button and
    resets it to ``False`` before answering.
    NOTE(review): no caller is visible in this file; presumably meant to be
    wired up as a widget callback -- confirm.
    """
    st.session_state["submitted"] = True
18
+
19
# Render inputs, text areas, headings, the tab bar and alert boxes
# right-to-left for the Persian UI.
#
# `direction: rtl` sets the base direction. `unicode-bidi: embed` keeps the
# Unicode bidirectional algorithm active, so digits and Latin words embedded
# in Persian text (years, model names, ...) keep their natural order. The
# previous `bidi-override` disabled the algorithm and laid every character
# out strictly right-to-left, which displays such runs mirrored.
# All selectors share one rule instead of five identical copies.
st.markdown("""
<style>
input, textarea, h2, div[role=tablist], div[role=alert], .rtl {
    unicode-bidi: embed;
    direction: rtl;
}
</style>
""", unsafe_allow_html=True)
43
+
44
# Empty placeholder slot and a progress bar at 0, created before the title.
# Neither is referenced again in the visible code.
latest_iteration = st.empty()
bar = st.progress(0)

# Page title ("Persian question answering system") followed by an empty
# markdown element.
st.markdown("## سیستم پرسش و پاسخ فارسی")
st.markdown("")

# Two tabs, labelled "Demo" and "Documentation" in Persian.
tab1, tab2 = st.tabs(
    [
        "دمو",
        "مستندات",
    ]
)
52
+
53
+
54
# Display name -> address handed to ``load_dataset`` for each dataset.
# Insertion order is the order of the sidebar radio options.
datasets_names_addresses = {
    "small-persian-QA": "Hamid-reza/small-persian-QA",
    "addsent-small-persian-QA": "Hamid-reza/Adv-small-persian-QA",
    "addany-small-persian-QA": "mohammadhossein/addany-dataset",
    "back-translation-small-persian-QA": "jalalnb/back_translation_hy_on_small_persian_QA",
    "invisible-char-small-persian-QA": "jalalnb/invisible_char_on_small_persian_QA",
}
59
+
60
@st.cache(allow_output_mutation=True)
def load_datasets(datasets_names_addresses):
    """Load the ``"validation"`` split of every dataset in the mapping.

    Wrapped in ``st.cache`` so Streamlit can reuse the result across script
    reruns instead of loading the datasets again.

    Args:
        datasets_names_addresses: Mapping of display name to the address
            passed to ``load_dataset``.

    Returns:
        A dict mapping each display name to the ``"validation"`` split of
        the loaded dataset, in the same order as the input mapping.
    """
    validation_splits = {}
    for name, address in datasets_names_addresses.items():
        validation_splits[name] = load_dataset(address)["validation"]
    return validation_splits
64
+
65
# Load every dataset once, then let the user pick one in the sidebar.
datasets_names_content = load_datasets(datasets_names_addresses)

# Label reads "Select the dataset you want:".
selected_dataset_name = st.sidebar.radio(
    ":دیتاست مورد نظر خود را انتخاب نمایید",
    list(datasets_names_addresses),
)
selected_dataset = datasets_names_content[selected_dataset_name]
71
+
72
+
73
# Display name -> (model address, tokenizer address) for each QA model.
# Insertion order is the order of the sidebar radio options.
models_names_addresses = {
    "mbert": (
        "arashmarioriyad/mbert_v3",
        "arashmarioriyad/mbert_tokenizer_v3",
    ),
    "parsbert": (
        "arashmarioriyad/parsbert_v1",
        "arashmarioriyad/parsbert_tokenizer_v1",
    ),
    "addsent-mbert": (
        "arashmarioriyad/addsent_mbert_v1",
        "arashmarioriyad/addsent_mbert_tokenizer_v1",
    ),
    "addsent-parsbert": (
        "arashmarioriyad/addsent_parsbert_v1",
        "arashmarioriyad/addsent_parsbert_tokenizer_v1",
    ),
    "addany-mbert": (
        "arashmarioriyad/addany_mbert_v1",
        "arashmarioriyad/addany_mbert_tokenizer_v1",
    ),
    "addany-parsbert": (
        "arashmarioriyad/addany_parsbert_v1",
        "arashmarioriyad/addany_parsbert_tokenizer_v1",
    ),
    "back-translation-mbert": (
        "arashmarioriyad/bt_hy_mbert_v1",
        "arashmarioriyad/bt_hy_mbert_tokenizer_v1",
    ),
    "back-translation-parsbert": (
        "arashmarioriyad/bt_hy_parsbert_v1",
        "arashmarioriyad/bt_hy_parsbert_tokenizer_v1",
    ),
    "invisible-char-mbert": (
        "arashmarioriyad/ic_mbert_v1",
        "arashmarioriyad/ic_mbert_tokenizer_v1",
    ),
    "invisible-char-parsbert": (
        "arashmarioriyad/ic_parsbert_v1",
        "arashmarioriyad/ic_parsbert_tokenizer_v1",
    ),
}
83
+
84
@st.cache(allow_output_mutation=True)
def load_models(models_names_addresses):
    """Build one ``"question-answering"`` pipeline per configured model.

    Wrapped in ``st.cache`` so Streamlit can reuse the result across script
    reruns instead of constructing the pipelines again.

    Args:
        models_names_addresses: Mapping of display name to a sequence whose
            item 0 is the model address and item 1 the tokenizer address.

    Returns:
        A dict mapping each display name to its pipeline, in the same order
        as the input mapping.
    """
    qa_pipelines = {}
    for name, addresses in models_names_addresses.items():
        qa_pipelines[name] = pipeline(
            "question-answering",
            model=addresses[0],
            tokenizer=addresses[1],
        )
    return qa_pipelines
90
+
91
# Build every pipeline once, then let the user pick one in the sidebar.
models_names_contents = load_models(models_names_addresses)

# Label reads "Select the model you want:".
selected_model_name = st.sidebar.radio(
    ":مدل مورد نظر خود را انتخاب نمایید",
    list(models_names_addresses),
)
selected_model = models_names_contents[selected_model_name]

# Sidebar note linking to the project's GitHub page.
st.sidebar.info(
    "تمامی دادگان، کد ها و نتایج ارزیابی مدل ها در [صفحه گیت هاب پروژه](https://github.com/NLP-Final-Projects/Adversarial-QA/) قابل دسترسی است",
    icon="ℹ️",
)
100
+
101
+
102
+
103
# Demo form: optionally pre-fill a random dataset sample, then answer the
# question with the selected model.
# NOTE(review): the original indentation was lost in the flattened source;
# the answer box is assumed to be rendered inside the form -- confirm.
with tab1.form("my_form", clear_on_submit=False):
    # Only the first of three columns is used, for the random-sample button.
    col1, _, _ = st.columns(3)
    with col1:
        generate_random_data = st.form_submit_button("تولید داده‌ی تصادفی")
    if generate_random_data:
        # Fetch the row once and store both fields under the keys that the
        # keyed widgets below read.
        sample = selected_dataset[random.randrange(len(selected_dataset))]
        st.session_state.context = sample["context"]
        st.session_state.question = sample["question"]

    if st.session_state.get("context") is not None:
        # A sample was loaded earlier in this session: bind the widgets to
        # the session-state keys so they show (and keep) that content.
        context = st.text_area(label="Context", key="context", height=300,
                               value=st.session_state.context)
        question = st.text_input(label="Question", key="question",
                                 value=st.session_state.question)
    else:
        context = st.text_area(label="Context", height=300,
                               placeholder="متن مورد نظر را اینجا وارد کنید ...")
        question = st.text_input(label="Question",
                                 placeholder="سوال خود از متن را اینجا بپرسید ...")

    submitted = st.form_submit_button("Get Answer")
    if submitted or st.session_state.get("submitted"):
        st.session_state.submitted = False
        if not context.strip() or not question.strip():
            # The question-answering pipeline raises ValueError on an empty
            # question or context; show a warning instead of a traceback.
            # Message: "Please enter the text and the question."
            st.warning("لطفاً متن و سوال را وارد کنید.")
        else:
            selected_prediction = selected_model(question=question, context=context)["answer"]
            # Fall back to "No answer" when the model returns an empty string.
            st.text_area(label=f"Answer ({selected_model_name}):",
                         value=selected_prediction or "بدون پاسخ")