taisazero
added logo
bb5609c
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import streamlit as st
def _resource_cache(func):
    """Cache the single result of *func* across Streamlit reruns.

    Uses ``st.cache_resource`` (the replacement for the deprecated
    ``st.cache(allow_output_mutation=True)``) when the installed Streamlit
    provides it, and falls back to ``st.cache`` on releases that predate it.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(max_entries=1)(func)
    return st.cache(allow_output_mutation=True, max_entries=1)(func)


@_resource_cache
def get_model(model_name="SoLID/sgd-response-generator"):
    """Load the response-generation model and its tokenizer.

    Args:
        model_name: Hugging Face Hub id (or local path) of the seq2seq
            checkpoint; the tokenizer is loaded from the same location.

    Returns:
        A ``(model, tokenizer)`` tuple. The result is cached so the
        checkpoint is not reloaded on every Streamlit rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return (model, tokenizer)
def lexicalize_plan(
    model,
    tokenizer,
    output_plan,
    temperature=1.0,
    num_beams=1,
    max_length=512,
    top_p=0.95,
):
    """Generate a natural-language response for a structured output plan.

    Args:
        model: Seq2seq model exposing ``generate``.
        tokenizer: Tokenizer paired with ``model``. It must be callable with
            ``return_tensors="pt"`` and expose ``decode``, ``pad_token_id``
            and ``eos_token_id``.
        output_plan: The output-plan string to verbalize.
        temperature: Sampling temperature forwarded to ``generate``.
        num_beams: Number of beams; coerced to ``int`` before use.
        max_length: Maximum length of the generated sequence.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The first generated sequence, decoded without special tokens and
        stripped of surrounding whitespace.
    """
    input_ids = tokenizer(output_plan, return_tensors="pt").input_ids
    num_beams = int(num_beams)
    generate_kwargs = dict(
        max_length=max_length,
        do_sample=True,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        temperature=temperature,
        num_beams=num_beams,
    )
    # `early_stopping` only affects beam-based decoding; passing it with a
    # single beam has no effect and makes transformers warn, so set it only
    # when beam search is actually used.
    if num_beams > 1:
        generate_kwargs["early_stopping"] = True
    output = model.generate(input_ids, **generate_kwargs)
    return tokenizer.decode(output[0], skip_special_tokens=True).strip()
def _render_sidebar():
    """Draw the sidebar (title, logo, generation controls).

    Returns:
        A ``(temperature, num_beams)`` tuple read from the sidebar sliders.
    """
    st.sidebar.title("SGD Response Generator Demo")
    st.sidebar.image(
        "logo.png",
        caption="UNCC & RPI Logos",
    )
    st.sidebar.markdown("### Controls:")
    temperature = st.sidebar.slider(
        "Temperature",
        min_value=0.5,
        max_value=1.5,
        value=0.8,
        step=0.1,
    )
    num_beams = st.sidebar.slider(
        "Num beams",
        min_value=1,
        max_value=4,
        step=1,
        value=2,
    )
    return temperature, num_beams


def run():
    """Build the Streamlit page for the SGD response-generation demo."""
    # Configure the page before drawing anything else on it.
    st.set_page_config(page_title="Schema Guided Dialogue Response Generation")
    temperature, num_beams = _render_sidebar()

    # main body
    model, tokenizer = get_model()
    output_plan = st.text_area(
        "Output Plan: ",
        value="[AC:Request [IN:FindRestaurants [SL:location] ] ] [AC:Request [IN:FindRestaurants [SL:category] ] ]",
        help="Type in the output plan used by the system to generate a response in English.",
    )
    if st.button("Generate Response"):
        # An empty plan gives the model nothing to verbalize; tell the user
        # instead of running generation on it.
        if not output_plan.strip():
            st.warning("Please enter an output plan before generating a response.")
            return
        # st.spinner removes its message even if generation raises, unlike a
        # manual st.text(...) placeholder cleared with .empty() afterwards.
        with st.spinner("Generating Response..."):
            response = lexicalize_plan(
                model, tokenizer, output_plan, temperature, num_beams
            )
        st.write(f"Generated Response: {response}")
# Script entry point: build and render the demo page.
if __name__ == "__main__":
    run()