# gpt2-persian-qa / app.py
import streamlit as st
import torch
from transformers import set_seed
from transformers import AutoTokenizer
from transformers import GPT2LMHeadModel
from mtranslate import translate

import meta
from normalizer import normalize
from utils import (
    remote_css,
    local_css,
    load_json,
)

EXAMPLES = load_json("examples.json")

# Persian markers used to build and parse the model prompt.
CK = "متن"  # "context"
QK = "پرسش"  # "question"
AK = "پاسخ"  # "answer"

class TextGeneration:
    def __init__(self):
        self.debug = False
        # Placeholder answer returned in debug mode ("a mix of Italian and English").
        self.dummy_output = "مخلوطی از ایتالیایی و انگلیسی"
        self.tokenizer = None
        self.model = None
        self.model_name_or_path = "m3hrdadfi/gpt2-persian-qa"
        self.length_margin = 100
        set_seed(42)

    def load(self):
        if not self.debug:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            self.model = GPT2LMHeadModel.from_pretrained(self.model_name_or_path)

    def generate(self, prompt, generation_kwargs):
        if self.debug:
            return self.dummy_output

        input_ids = self.tokenizer([prompt], return_tensors="pt")["input_ids"]
        # Let the model generate up to `length_margin` tokens beyond the prompt.
        generation_kwargs["max_length"] = len(input_ids[0]) + self.length_margin
        generated = self.model.generate(
            input_ids,
            **generation_kwargs,
        )[0]
        answer = self.tokenizer.decode(generated, skip_special_tokens=True)

        # str.find returns -1 when the marker is missing (0 is a valid position).
        found = answer.find(f"{AK}: ")
        if found == -1:
            return ""

        parts = [a.strip() for a in answer[found:].split(f"{AK}: ") if a.strip()]
        return parts[0] if parts else ""
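
# Usage sketch outside Streamlit (a minimal example, assuming the model weights
# download successfully; `some_context` and `some_question` are hypothetical raw
# strings, and the prompt format mirrors what main() builds below):
#
#     generator = TextGeneration()
#     generator.load()
#     prompt = f"{normalize(some_context)} {QK}: {normalize(some_question)} {AK}:"
#     print(generator.generate(prompt, {"num_beams": 5, "early_stopping": True}))
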
@st.cache(allow_output_mutation=True)
def load_text_generator():
    generator = TextGeneration()
    generator.load()
    return generator

def main():
    st.set_page_config(
        page_title="GPT2 QA - Persian",
        page_icon="⁉️",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    remote_css("https://cdn.jsdelivr.net/gh/rastikerdar/vazir-font/dist/font-face.css")
    local_css("assets/rtl.css")
    generator = load_text_generator()
    st.sidebar.markdown(meta.SIDEBAR_INFO)
    num_beams = st.sidebar.slider(
        label='Number of Beams',
        help="Number of beams for beam search",
        min_value=4,
        max_value=15,
        value=5,
        step=1,
    )
    repetition_penalty = st.sidebar.slider(
        label='Repetition Penalty',
        help="Penalty applied to repeated tokens (1.0 means no penalty)",
        min_value=1.0,
        max_value=10.0,
        value=1.0,
        step=0.1,
    )
    length_penalty = st.sidebar.slider(
        label='Length Penalty',
        help="Exponential penalty applied to the sequence length in beam search",
        min_value=1.0,
        max_value=10.0,
        value=1.0,
        step=0.1,
    )
    early_stopping = st.sidebar.selectbox(
        label='Early Stopping?',
        options=(True, False),
        help="Whether to stop beam search once at least num_beams finished candidates exist per batch",
    )
    translated = st.sidebar.selectbox(
        label='Translation?',
        options=(True, False),
        help="Whether to translate the result into English",
    )
    generation_kwargs = {
        "num_beams": num_beams,
        "early_stopping": early_stopping,
        "repetition_penalty": repetition_penalty,
        "length_penalty": length_penalty,
    }
    st.markdown(meta.HEADER_INFO)

    prompts = [e["title"] for e in EXAMPLES] + ["Custom"]
    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)
    if prompt == "Custom":
        prompt_box = {
            "context": meta.C_PROMPT_BOX,
            "question": meta.Q_PROMPT_BOX,
            "answer": meta.A_PROMPT_BOX,
        }
    else:
        prompt_box = next(e for e in EXAMPLES if e["title"] == prompt)

    context = st.text_area("Enter context", prompt_box["context"], height=250)
    question = st.text_area("Enter question", prompt_box["question"], height=100)
    answer = "پاسخ درست: " + prompt_box["answer"]  # "Correct answer: ..."
    st.markdown(
        f'<p class="rtl rtl-box">'
        f'{answer}'
        f'</p>',
        unsafe_allow_html=True,
    )
    if translated:
        translated_answer = translate(answer, "en", "fa")
        st.markdown(
            f'<p class="ltr">'
            f'{translated_answer}'
            f'</p>',
            unsafe_allow_html=True,
        )

    generation_kwargs_ph = st.empty()
    if st.button("Find the answer 🔎 "):
        with st.spinner(text="Searching ..."):
            generation_kwargs_ph.markdown(", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
            context = normalize(context)
            question = normalize(question)
            if context and question:
                text = f"{context} {QK}: {question} {AK}:"
                generated_answer = generator.generate(text, generation_kwargs)
                generated_answer = f"{AK}: {generated_answer}".strip()
                context = f"{CK}: {context}".strip()
                question = f"{QK}: {question}".strip()
                st.markdown(
                    f'<p class="rtl rtl-box">'
                    f'<span class="result-text">{context}</span><br/><br/>'
                    f'<span class="result-text">{question}</span><br/><br/>'
                    f'<span class="result-text generated-text">{generated_answer}</span>'
                    f'</p>',
                    unsafe_allow_html=True,
                )
                if translated:
                    translated_context = translate(context, "en", "fa")
                    translated_question = translate(question, "en", "fa")
                    translated_generated_answer = translate(generated_answer, "en", "fa")
                    st.markdown(
                        f'<p class="ltr ltr-box">'
                        f'<span class="result-text">{translated_context}</span><br/><br/>'
                        f'<span class="result-text">{translated_question}</span><br/><br/>'
                        f'<span class="result-text generated-text">{translated_generated_answer}</span>'
                        f'</p>',
                        unsafe_allow_html=True,
                    )

if __name__ == '__main__':
    main()
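
# Launch locally with the standard Streamlit CLI:
#     streamlit run app.py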