import html

import streamlit as st
from transformers import AutoTokenizer
from transformers import GPT2LMHeadModel
from transformers import set_seed

import meta
from normalizer import normalize
from utils import load_json
from utils import local_css

EXAMPLES = load_json("examples.json")
# Label prefixes used both to build the model prompt and to render the result.
CK = ""
QK = "Q:"
AK = "A:"


class TextGeneration:
    """Wrap the GPT-2 QA tokenizer/model pair and produce answers from prompts."""

    def __init__(self):
        # When True, no model is loaded and `generate` returns `dummy_output`.
        self.debug = False
        self.dummy_output = "Destiny's Child"
        self.tokenizer = None
        self.model = None
        self.model_name_or_path = "m3hrdadfi/gpt2-QA"
        # Number of new tokens allowed beyond the prompt length.
        self.length_margin = 100
        set_seed(42)

    def load(self):
        """Load the tokenizer and model from `model_name_or_path` (skipped in debug mode)."""
        if not self.debug:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            self.model = GPT2LMHeadModel.from_pretrained(self.model_name_or_path)

    def generate(self, prompt, generation_kwargs):
        """Generate an answer for ``prompt``.

        Args:
            prompt: Text ending in the answer key, e.g. ``"<context> Q: <question> A:"``.
            generation_kwargs: Keyword arguments forwarded to ``model.generate``.
                The dict is not modified; ``max_length`` is computed here and
                overrides any value supplied by the caller.

        Returns:
            The first non-empty text segment the model produced after the
            prompt (cut at the next ``AK`` marker), or ``""`` if nothing usable
            was generated. In debug mode, ``dummy_output`` is returned.
        """
        if self.debug:
            return self.dummy_output

        input_ids = self.tokenizer([prompt], return_tensors="pt")["input_ids"]

        # Never exceed the model's position-embedding table (1024 for GPT-2).
        context_window = getattr(self.model.config, "n_positions", 1024)
        # Keep room for `length_margin` new tokens. Overly long prompts are
        # truncated from the left so the trailing "Q: ... A:" part survives.
        max_prompt_length = max(context_window - self.length_margin, 1)
        if input_ids.shape[-1] > max_prompt_length:
            input_ids = input_ids[:, -max_prompt_length:]
        prompt_length = input_ids.shape[-1]
        max_length = min(prompt_length + self.length_margin, context_window)

        # Copy instead of mutating the caller's dict.
        kwargs = {**generation_kwargs, "max_length": max_length}
        generated = self.model.generate(input_ids, **kwargs)[0]

        # Decode only the continuation. Searching the full decoded text would
        # match an "A:" that appears inside the user's context or question.
        continuation = self.tokenizer.decode(
            generated[prompt_length:], skip_special_tokens=True
        )
        segments = [s.strip() for s in continuation.split(AK) if s.strip()]
        return segments[0] if segments else ""


# Prefer `st.cache_resource` (the replacement for
# `st.cache(allow_output_mutation=True)`); fall back on older Streamlit.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_resource
def load_text_generator():
    """Build and load a `TextGeneration` once per Streamlit server process."""
    generator = TextGeneration()
    generator.load()
    return generator


def main():
    """Render the GPT2 QA page: sidebar generation settings, example picker,
    context/question inputs, and the generated answer on button press."""
    st.set_page_config(
        page_title="GPT2 QA",
        page_icon="⁉️",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    local_css("assets/style.css")
    generator = load_text_generator()

    st.sidebar.markdown(meta.SIDEBAR_INFO)

    num_beams = st.sidebar.slider(
        label='Number of Beam',
        help="Number of beams for beam search",
        min_value=4,
        max_value=15,
        value=5,
        step=1,
    )
    repetition_penalty = st.sidebar.slider(
        label='Repetition Penalty',
        help="The parameter for repetition penalty",
        min_value=1.0,
        max_value=3.0,
        value=1.0,
        step=0.1,
    )
    length_penalty = st.sidebar.slider(
        label='Length Penalty',
        help="Exponential penalty to the length",
        min_value=0.0,
        max_value=2.0,
        value=1.0,
        step=0.1,
    )
    early_stopping = st.sidebar.selectbox(
        label='Early Stopping ?',
        options=(True, False),
        help="Whether to stop the beam search when at least num_beams sentences are finished per batch or not",
    )
    generation_kwargs = {
        "num_beams": num_beams,
        "early_stopping": early_stopping,
        "repetition_penalty": repetition_penalty,
        "length_penalty": length_penalty,
    }

    st.markdown(meta.HEADER_INFO)
    prompts = [e["title"] for e in EXAMPLES] + ["Custom"]
    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)

    if prompt == "Custom":
        prompt_box = {
            "context": meta.C_PROMPT_BOX,
            "question": meta.Q_PROMPT_BOX,
            "answers": [meta.A_PROMPT_BOX],
        }
    else:
        prompt_box = next(e for e in EXAMPLES if e["title"] == prompt)

    context = st.text_area("Enter context", prompt_box["context"], height=200)
    question = st.text_area("Enter question", prompt_box["question"], height=100)

    # Separate multiple answers; a bare "" join ran them together.
    answer = "Ground Truth Answers: " + ", ".join(
        f"{ground_truth}" for ground_truth in prompt_box["answers"]
    )
    # NOTE(review): the HTML wrapper markup around these outputs appears to
    # have been lost from this file (only blank-line separators survived);
    # restore the elements/classes expected by assets/style.css.
    st.markdown(
        "\n\n"
        f"{answer}"
        "\n\n",
        unsafe_allow_html=True,
    )

    generation_kwargs_ph = st.empty()

    if st.button("Find the answer 🔎 "):
        with st.spinner(text="Searching ..."):
            generation_kwargs_ph.markdown(
                ", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()])
            )
            context = normalize(context)
            question = normalize(question)

            if context and question:
                text = f"{context} {QK} {question} {AK}"
                generated_answer = generator.generate(text, generation_kwargs)

                # User-typed text and model output are rendered with
                # unsafe_allow_html=True, so escape them first.
                generated_answer = f"{AK} {html.escape(generated_answer)}".strip()
                context = f"{CK} {html.escape(context)}".strip()
                question = f"{QK} {html.escape(question)}".strip()

                st.markdown(
                    "\n\n"
                    f"{context}\n\n"
                    f"{question}\n\n"
                    f"{generated_answer} "
                    "\n\n",
                    unsafe_allow_html=True,
                )


if __name__ == '__main__':
    main()