import streamlit as st
from transformers import set_seed
from transformers import AutoTokenizer
from transformers import GPT2LMHeadModel
from mtranslate import translate

import meta
from normalizer import normalize
from utils import (
    remote_css,
    local_css,
    load_json
)
EXAMPLES = load_json("examples.json")

# Persian markers used in prompts and display:
# CK = "متن" (text/context), QK = "پرسش" (question), AK = "پاسخ" (answer).
CK = "متن"
QK = "پرسش"
AK = "پاسخ"
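# Prompt layout (as assembled in main() below):
#
#     <normalized context> پرسش: <question> پاسخ:
#
# The model is expected to continue the text after the trailing "پاسخ:" marker;
# TextGeneration.generate() extracts that continuation as the answer.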
class TextGeneration:
    def __init__(self):
        self.debug = False
        # Canned output for debug mode; Persian for "a mix of Italian and English".
        self.dummy_output = "مخلوطی از ایتالیایی و انگلیسی"
        self.tokenizer = None
        self.model = None
        self.model_name_or_path = "m3hrdadfi/gpt2-persian-qa"
        # Generate at most this many tokens beyond the prompt length.
        self.length_margin = 100
        set_seed(42)

    def load(self):
        if not self.debug:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            self.model = GPT2LMHeadModel.from_pretrained(self.model_name_or_path)

    def generate(self, prompt, generation_kwargs):
        if not self.debug:
            input_ids = self.tokenizer([prompt], return_tensors="pt")["input_ids"]
            max_length = len(input_ids[0]) + self.length_margin
            generation_kwargs["max_length"] = max_length
            generated = self.model.generate(
                input_ids,
                **generation_kwargs,
            )[0]
            answer = self.tokenizer.decode(generated, skip_special_tokens=True)
            # str.find returns -1 when the marker is absent (and 0 is a valid
            # position), so test against -1 rather than truthiness.
            found = answer.find(f"{AK}: ")
            if found == -1:
                return ""
            # Take the first non-empty segment after the answer marker.
            answer = [a.strip() for a in answer[found:].split(f"{AK}: ") if a.strip()]
            answer = answer[0] if len(answer) > 0 else ""
            return answer

        return self.dummy_output
def load_text_generator():
    # Build the generator and load the tokenizer/model once at startup.
    generator = TextGeneration()
    generator.load()
    return generator
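# A minimal usage sketch outside Streamlit (hypothetical values, not part of
# the app): the prompt is assembled the same way main() does before calling
# generate().
#
#     generator = load_text_generator()
#     prompt = f"{normalize(context)} {QK}: {normalize(question)} {AK}:"
#     answer = generator.generate(prompt, {"num_beams": 5, "early_stopping": True})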
def main():
    st.set_page_config(
        page_title="GPT2 QA - Persian",
        page_icon="⁉️",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    remote_css("https://cdn.jsdelivr.net/gh/rastikerdar/vazir-font/dist/font-face.css")
    local_css("assets/rtl.css")

    generator = load_text_generator()

    st.sidebar.markdown(meta.SIDEBAR_INFO)
    num_beams = st.sidebar.slider(
        label='Number of Beams',
        help="Number of beams for beam search",
        min_value=4,
        max_value=15,
        value=5,
        step=1
    )
    repetition_penalty = st.sidebar.slider(
        label='Repetition Penalty',
        help="The parameter for repetition penalty",
        min_value=1.0,
        max_value=10.0,
        value=1.0,
        step=0.1
    )
    length_penalty = st.sidebar.slider(
        label='Length Penalty',
        help="Exponential penalty applied to the sequence length",
        min_value=1.0,
        max_value=10.0,
        value=1.0,
        step=0.1
    )
    early_stopping = st.sidebar.selectbox(
        label='Early Stopping?',
        options=(True, False),
        help="Whether to stop beam search once at least num_beams finished sentences exist per batch",
    )
    translated = st.sidebar.selectbox(
        label='Translation?',
        options=(True, False),
        help="Translate the result into English",
    )
    generation_kwargs = {
        "num_beams": num_beams,
        "early_stopping": early_stopping,
        "repetition_penalty": repetition_penalty,
        "length_penalty": length_penalty,
    }
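    # These kwargs are passed straight through to GPT2LMHeadModel.generate()
    # (Hugging Face GenerationMixin); max_length is filled in per prompt
    # inside TextGeneration.generate().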
    st.markdown(meta.HEADER_INFO)
    prompts = [e["title"] for e in EXAMPLES] + ["Custom"]
    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)

    if prompt == "Custom":
        prompt_box = {
            "context": meta.C_PROMPT_BOX,
            "question": meta.Q_PROMPT_BOX,
            "answer": meta.A_PROMPT_BOX,
        }
    else:
        prompt_box = next(e for e in EXAMPLES if e["title"] == prompt)

    context = st.text_area("Enter context", prompt_box["context"], height=250)
    question = st.text_area("Enter question", prompt_box["question"], height=100)
    # "پاسخ درست" is Persian for "correct answer" (the example's gold answer).
    answer = "پاسخ درست: " + prompt_box["answer"]
    st.markdown(
        f'<p class="rtl rtl-box">'
        f'{answer}'
        f'</p>',
        unsafe_allow_html=True
    )
    if translated:
        # mtranslate signature: translate(text, to_language, from_language),
        # so this renders the Persian answer in English.
        translated_answer = translate(answer, "en", "fa")
        st.markdown(
            f'<p class="ltr">'
            f'{translated_answer}'
            f'</p>',
            unsafe_allow_html=True
        )
    generation_kwargs_ph = st.empty()

    if st.button("Find the answer 🔎 "):
        with st.spinner(text="Searching ..."):
            generation_kwargs_ph.markdown(", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
            context = normalize(context)
            question = normalize(question)

            if context and question:
                text = f"{context} {QK}: {question} {AK}:"
                generated_answer = generator.generate(text, generation_kwargs)
                generated_answer = f"{AK}: {generated_answer}".strip()
                context = f"{CK}: {context}".strip()
                question = f"{QK}: {question}".strip()
                st.markdown(
                    f'<p class="rtl rtl-box">'
                    f'<span class="result-text">{context}</span><br/><br/>'
                    f'<span class="result-text">{question}</span><br/><br/>'
                    f'<span class="result-text generated-text">{generated_answer}</span>'
                    f'</p>',
                    unsafe_allow_html=True
                )

                if translated:
                    translated_context = translate(context, "en", "fa")
                    translated_question = translate(question, "en", "fa")
                    translated_generated_answer = translate(generated_answer, "en", "fa")
                    st.markdown(
                        f'<p class="ltr ltr-box">'
                        f'<span class="result-text">{translated_context}</span><br/><br/>'
                        f'<span class="result-text">{translated_question}</span><br/><br/>'
                        f'<span class="result-text generated-text">{translated_generated_answer}</span>'
                        f'</p>',
                        unsafe_allow_html=True
                    )
if __name__ == '__main__':
    main()
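# To run locally (assuming this file is saved as app.py): streamlit run app.py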