taisazero commited on
Commit
9390326
1 Parent(s): e464f34

added webapp

Browse files
Files changed (2) hide show
  1. app.py +64 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+
3
+
4
+
5
+ @st.cache(allow_output_mutation=True, max_entries=1)
6
+ def get_model():
7
+ tokenizer = AutoTokenizer.from_pretrained("SoLID/sgd-response-generator")
8
+ model = AutoModelForSeq2SeqLM.from_pretrained("SoLID/sgd-response-generator")
9
+ return (model, tokenizer)
10
+
11
+
12
+ def lexicalize_plan(
13
+ model, tokenizer, output_plan, temperature=1.0, num_beams=1
14
+ ):
15
+
16
+ input_ids = tokenizer(output_plan, return_tensors="pt").input_ids
17
+ output = model.generate(
18
+ input_ids,
19
+ max_length=512,
20
+ do_sample=True,
21
+ top_p=0.95,
22
+ pad_token_id=tokenizer.pad_token_id,
23
+ eos_token_id=tokenizer.eos_token_id,
24
+ early_stopping=True,
25
+ temperature=temperature,
26
+ num_beams=int(num_beams),
27
+ )
28
+ output_str = tokenizer.decode(output[0], skip_special_tokens=True).strip()
29
+ return output_str
30
+
31
+ def run():
32
+ st.set_page_config(page_title="Schema Guided Dialogue Response Generation")
33
+ # sidebar
34
+ st.sidebar.title("SGD Response Generator Demo")
35
+ st.sidebar.image(
36
+ "https://aeiljuispo.cloudimg.io/v7/https://s3.amazonaws.com/moonup/production/uploads/1628568174585-6049d8edbaa99e90d94ee67c.png",
37
+ caption="SoLID at UNCC Logo",
38
+ )
39
+ st.sidebar.markdown("### Controls:")
40
+ temperature = st.sidebar.slider(
41
+ "Temperature",
42
+ min_value=0.5,
43
+ max_value=1.5,
44
+ value=0.8,
45
+ step=0.1,
46
+ )
47
+ num_beams = st.sidebar.slider(
48
+ "Num beams",
49
+ min_value=1,
50
+ max_value=4,
51
+ step=1,
52
+ )
53
+ # main body
54
+ model, tokenizer = get_model()
55
+ output_plan = st.text_input("Output Plan: ", value = "[AC:Request [IN:FindRestaurants [SL:location] ] ] [AC:Request [IN:FindRestaurants [SL:category] ] ]", help ="Type in the output plan used by the system to generate a response in English.")
56
+ submit_button = st.button("Generate Response")
57
+ if submit_button:
58
+ text = st.text("Generating Response...")
59
+ response = lexicalize_plan (model, tokenizer, output_plan, temperature, num_beams)
60
+ text.empty()
61
+ st.text("Generated Response: " + str(response))
62
+
63
+ if __name__ == "__main__":
64
+ run()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
1
+ torch
2
+ transformers