from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import streamlit as st


@st.cache(allow_output_mutation=True, max_entries=1)
def get_model():
    # Load the tokenizer and model once and cache them across reruns
    # (legacy st.cache API; Streamlit re-executes the script on every interaction).
    tokenizer = AutoTokenizer.from_pretrained("SoLID/sgd-response-generator")
    model = AutoModelForSeq2SeqLM.from_pretrained("SoLID/sgd-response-generator")
    return model, tokenizer


def lexicalize_plan(model, tokenizer, output_plan, temperature=1.0, num_beams=1):
    # Turn a structured output plan into a natural-language response,
    # sampling with nucleus filtering on top of beam search.
    input_ids = tokenizer(output_plan, return_tensors="pt").input_ids
    output = model.generate(
        input_ids,
        max_length=512,
        do_sample=True,
        top_p=0.95,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        early_stopping=True,
        temperature=temperature,
        num_beams=int(num_beams),
    )
    # Decode the best beam, dropping special tokens such as <pad> and </s>.
    return tokenizer.decode(output[0], skip_special_tokens=True).strip()


def run():
    st.set_page_config(page_title="Schema Guided Dialogue Response Generation")

    # Sidebar: branding and generation controls.
    st.sidebar.title("SGD Response Generator Demo")
    st.sidebar.image(
        "logo.png",
        caption="UNCC & RPI Logos",
    )
    st.sidebar.markdown("### Controls:")
    temperature = st.sidebar.slider(
        "Temperature",
        min_value=0.5,
        max_value=1.5,
        value=0.8,
        step=0.1,
    )
    num_beams = st.sidebar.slider(
        "Num beams",
        min_value=1,
        max_value=4,
        step=1,
        value=2,
    )

    # Main body: plan input and response generation.
    model, tokenizer = get_model()
    output_plan = st.text_area(
        "Output Plan: ",
        value="[AC:Request [IN:FindRestaurants [SL:location] ] ] [AC:Request [IN:FindRestaurants [SL:category] ] ]",
        help="Type in the output plan used by the system to generate a response in English.",
    )
    submit_button = st.button("Generate Response")

    if submit_button:
        # Show a transient status message while the model generates.
        status = st.text("Generating Response...")
        response = lexicalize_plan(model, tokenizer, output_plan, temperature, num_beams)
        status.empty()
        st.write("Generated Response: " + str(response))


if __name__ == "__main__":
    run()