# NOTE: page-scrape residue, not part of the program. The hosting page header
# read: "Spaces: Runtime error".
import base64
import imp  # NOTE(review): `imp` was removed in Python 3.12; unused in the visible code.
import io
import os
import random
import time
from logging import PlaceHolder
from re import sub

import nbformat
import streamlit as st
from datasets import load_dataset, load_from_disk
from PIL import Image
from transformers import pipeline
# Use the full browser width for the page layout.
st.set_page_config(layout="wide")
def set_submitted_true():
    """Set the ``submitted`` flag in Streamlit session state to ``True``.

    The demo form treats a truthy ``st.session_state.submitted`` the same as a
    click on its submit button.
    NOTE(review): nothing in the visible code calls this function; presumably
    it was meant to be used as a widget callback -- confirm before removing.
    """
    st.session_state.submitted = True
# Inject CSS that forces right-to-left rendering for inputs, text areas,
# second-level headings, the tab list and alert boxes (the UI text is Persian).
st.markdown("""
<style>
input, .rtl {
    unicode-bidi:bidi-override;
    direction: RTL;
}
textarea, .rtl {
    unicode-bidi:bidi-override;
    direction: RTL;
}
h2, .rtl {
    unicode-bidi:bidi-override;
    direction: RTL;
}
div[role=tablist], .rtl {
    unicode-bidi:bidi-override;
    direction: RTL;
}
div[role=alert], .rtl {
    unicode-bidi:bidi-override;
    direction: RTL;
}
</style>
""", unsafe_allow_html=True)

# Placeholder element and a progress bar initialised to 0.
# NOTE(review): neither is updated anywhere in the visible code.
latest_iteration = st.empty()
bar = st.progress(0)

# Page title ("Persian question answering system") followed by an empty
# markdown element.
st.markdown("## سیستم پرسش و پاسخ فارسی")
st.markdown("")

# Two tabs: "Demo" and "Documentation".
# NOTE(review): only tab1 is populated in the visible code.
tab1, tab2 = st.tabs(["دمو", "مستندات"])
# Display name -> dataset identifier passed to `load_dataset`.
datasets_names_addresses = {
    "small-persian-QA": "Hamid-reza/small-persian-QA",
    "addsent-small-persian-QA": "Hamid-reza/Adv-small-persian-QA",
    "addany-small-persian-QA": "mohammadhossein/addany-dataset",
    "back-translation-small-persian-QA": "jalalnb/back_translation_hy_on_small_persian_QA",
    "invisible-char-small-persian-QA": "jalalnb/invisible_char_on_small_persian_QA",
}
def load_datasets(datasets_names_addresses: dict) -> dict:
    """Load the ``"validation"`` split of every dataset in the mapping.

    Args:
        datasets_names_addresses: Mapping of display name to the dataset
            identifier handed to ``load_dataset``.

    Returns:
        A dict with the same keys, each mapped to the ``"validation"`` split
        of the loaded dataset. An empty mapping yields an empty dict.
    """
    return {
        dataset_name: load_dataset(dataset_address)["validation"]
        for dataset_name, dataset_address in datasets_names_addresses.items()
    }
# Load every dataset up front, then let the user choose one in the sidebar.
# NOTE(review): this runs at module top level, so it executes on every script
# run; consider wrapping `load_datasets` in Streamlit's caching -- confirm the
# deployed Streamlit version first.
datasets_names_content = load_datasets(datasets_names_addresses)
selected_dataset_name = st.sidebar.radio(
    ':دیتاست مورد نظر خود را انتخاب نمایید',
    list(datasets_names_addresses),
)
selected_dataset = datasets_names_content[selected_dataset_name]
# Display name -> (model identifier, tokenizer identifier) passed to `pipeline`.
models_names_addresses = {
    "mbert": ("arashmarioriyad/mbert_v3", "arashmarioriyad/mbert_tokenizer_v3"),
    "parsbert": ("arashmarioriyad/parsbert_v1", "arashmarioriyad/parsbert_tokenizer_v1"),
    "addsent-mbert": ("arashmarioriyad/addsent_mbert_v1", "arashmarioriyad/addsent_mbert_tokenizer_v1"),
    "addsent-parsbert": ("arashmarioriyad/addsent_parsbert_v1", "arashmarioriyad/addsent_parsbert_tokenizer_v1"),
    "addany-mbert": ("arashmarioriyad/addany_mbert_v1", "arashmarioriyad/addany_mbert_tokenizer_v1"),
    "addany-parsbert": ("arashmarioriyad/addany_parsbert_v1", "arashmarioriyad/addany_parsbert_tokenizer_v1"),
    "back-translation-mbert": ("arashmarioriyad/bt_hy_mbert_v1", "arashmarioriyad/bt_hy_mbert_tokenizer_v1"),
    "back-translation-parsbert": ("arashmarioriyad/bt_hy_parsbert_v1", "arashmarioriyad/bt_hy_parsbert_tokenizer_v1"),
    "invisible-char-mbert": ("arashmarioriyad/ic_mbert_v1", "arashmarioriyad/ic_mbert_tokenizer_v1"),
    "invisible-char-parsbert": ("arashmarioriyad/ic_parsbert_v1", "arashmarioriyad/ic_parsbert_tokenizer_v1"),
}
def load_models(models_names_addresses: dict) -> dict:
    """Build a question-answering pipeline for every entry in the mapping.

    Args:
        models_names_addresses: Mapping of display name to a
            ``(model, tokenizer)`` pair of identifiers handed to ``pipeline``.

    Returns:
        A dict with the same keys, each mapped to a ``"question-answering"``
        pipeline built from that pair. An empty mapping yields an empty dict.
    """
    return {
        model_name: pipeline(
            "question-answering",
            model=model_address,
            tokenizer=tokenizer_address,
        )
        for model_name, (model_address, tokenizer_address)
        in models_names_addresses.items()
    }
# Build every pipeline up front, then let the user choose one in the sidebar.
# NOTE(review): this runs at module top level, so all pipelines are rebuilt on
# every script run; consider wrapping `load_models` in Streamlit's caching --
# confirm the deployed Streamlit version first.
models_names_contents = load_models(models_names_addresses)
selected_model_name = st.sidebar.radio(
    ':مدل مورد نظر خود را انتخاب نمایید',
    list(models_names_addresses),
)
selected_model = models_names_contents[selected_model_name]

# Sidebar note linking to the project's GitHub page.
st.sidebar.info("تمامی دادگان، کد ها و نتایج ارزیابی مدل ها در [صفحه گیت هاب پروژه](https://github.com/NLP-Final-Projects/Adversarial-QA/) قابل دسترسی است", icon="ℹ️")
# Demo form: optionally pre-fill context/question with a random sample from the
# selected dataset, then run the selected model on whatever is in the fields.
# NOTE(review): the original indentation was lost; the nesting below (only the
# random-sample button inside col1, everything else directly in the form) is
# reconstructed -- confirm against the deployed layout.
with tab1.form("my_form", clear_on_submit=False):
    col1, col2, col3 = st.columns(3)
    with col1:
        generate_random_data = st.form_submit_button("تولید دادهی تصادفی")

    if generate_random_data:
        # Fetch the row once and seed the keyed widgets below through
        # session state (must happen before those widgets are created).
        sample = selected_dataset[random.randrange(len(selected_dataset))]
        st.session_state.context = sample["context"]
        st.session_state.question = sample["question"]

    if 'context' in st.session_state and st.session_state.context is not None:
        # The widgets take their value from session state via `key`; passing
        # `value=` as well makes Streamlit emit a default-value warning.
        context = st.text_area(label="Context", key="context", height=300)
        question = st.text_input(label="Question", key="question")
    else:
        context = st.text_area(label="Context", height=300, placeholder="متن مورد نظر را اینجا وارد کنید ...")
        question = st.text_input(label="Question", placeholder="سوال خود از متن را اینجا بپرسید ...")

    submitted = st.form_submit_button("Get Answer")
    if submitted or ('submitted' in st.session_state and st.session_state.submitted):
        st.session_state.submitted = False
        if not (context and context.strip()) or not (question and question.strip()):
            # The question-answering pipeline rejects empty inputs with a
            # ValueError; tell the user instead of crashing the app.
            st.warning("لطفاً متن و سوال را وارد کنید.")
        else:
            selected_prediction = selected_model(question=question, context=context)["answer"]
            # Fall back to a "no answer" label when the model returns "".
            st.text_area(label=f"Answer ({selected_model_name}):", value=selected_prediction or "بدون پاسخ")