File size: 2,112 Bytes
9390326
7743799
9390326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb5609c
 
9390326
 
 
 
 
 
 
 
 
 
 
 
 
 
c767df5
9390326
 
 
87e16bf
9390326
 
 
 
 
87e16bf
9390326
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import streamlit as st  


MODEL_NAME = "SoLID/sgd-response-generator"


def _resource_cache(**cache_kwargs):
    """Return a Streamlit decorator that caches a shared resource.

    Prefers ``st.cache_resource`` (the replacement for the deprecated
    ``st.cache``) and falls back to ``st.cache`` on Streamlit versions that
    predate it, so the app keeps working on both.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(**cache_kwargs)
    # Legacy API: allow_output_mutation stops Streamlit from hashing the
    # returned model on every rerun to detect mutation.
    return st.cache(allow_output_mutation=True, **cache_kwargs)


@_resource_cache(max_entries=1)
def get_model(model_name=MODEL_NAME):
    """Load the response-generation model and its tokenizer, cached across reruns.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to ``MODEL_NAME``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return (model, tokenizer)


def lexicalize_plan(
    model,
    tokenizer,
    output_plan,
    temperature=1.0,
    num_beams=1,
    top_p=0.95,
    max_length=512,
):
    """Generate an English response that realizes a dialogue output plan.

    Args:
        model: Model exposing ``generate`` (e.g. the one returned by ``get_model``).
        tokenizer: Tokenizer paired with ``model``; must be callable and
            provide ``decode``, ``pad_token_id`` and ``eos_token_id``.
        output_plan: Linearized plan text, e.g.
            ``"[AC:Request [IN:FindRestaurants [SL:location] ] ]"``.
        temperature: Sampling temperature forwarded to ``generate``.
        num_beams: Number of beams; converted with ``int()`` before use.
        top_p: Nucleus-sampling probability threshold.
        max_length: Maximum length of the generated sequence.

    Returns:
        The first generated sequence, decoded without special tokens and
        stripped of surrounding whitespace.

    Raises:
        ValueError: If ``num_beams`` is less than 1.
    """
    beams = int(num_beams)
    if beams < 1:
        raise ValueError(f"num_beams must be >= 1, got {num_beams!r}")

    encoded = tokenizer(output_plan, return_tensors="pt")
    output = model.generate(
        encoded.input_ids,
        # Pass the tokenizer's mask explicitly instead of letting generate() infer it.
        attention_mask=getattr(encoded, "attention_mask", None),
        max_length=max_length,
        do_sample=True,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        # early_stopping only applies to beam search; with a single beam it has
        # no effect, so only enable it when more than one beam is requested.
        early_stopping=beams > 1,
        temperature=temperature,
        num_beams=beams,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True).strip()

def run():
    """Render the Streamlit demo page for the SGD response generator.

    The sidebar holds sampling controls (temperature and beam count). The main
    body takes an output plan and, when the button is pressed, displays the
    response produced by ``lexicalize_plan``.
    """
    # Streamlit has historically required set_page_config to precede any other
    # st.* command, so keep it first.
    st.set_page_config(page_title="Schema Guided Dialogue Response Generation")

    # sidebar
    st.sidebar.title("SGD Response Generator Demo")
    st.sidebar.image(
        "logo.png",
        caption="UNCC & RPI Logos",
    )
    st.sidebar.markdown("### Controls:")
    temperature = st.sidebar.slider(
        "Temperature",
        min_value=0.5,
        max_value=1.5,
        value=0.8,
        step=0.1,
    )
    num_beams = st.sidebar.slider(
        "Num beams",
        min_value=1,
        max_value=4,
        step=1,
        value=2,
    )

    # main body
    model, tokenizer = get_model()
    output_plan = st.text_area(
        "Output Plan: ",
        value=(
            "[AC:Request [IN:FindRestaurants [SL:location] ] ] "
            "[AC:Request [IN:FindRestaurants [SL:category] ] ]"
        ),
        help="Type in the output plan used by the system to generate a response in English.",
    )
    if not st.button("Generate Response"):
        return
    if not output_plan.strip():
        st.warning("Please enter an output plan before generating a response.")
        return
    # st.spinner removes its message when the block exits, even if generation
    # raises; a manual st.text placeholder would be left on screen.
    with st.spinner("Generating Response..."):
        response = lexicalize_plan(model, tokenizer, output_plan, temperature, num_beams)
    st.write("Generated Response: " + response)

if __name__ == "__main__":
    # Build the page only when this file is executed as the main script,
    # not when it is imported as a module.
    run()