armandstrickernlp commited on
Commit
c2c676c
β€’
1 Parent(s): 1448a3a

first commit

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. README.md +4 -4
  3. app.py +143 -0
  4. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .DS_Store
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Gpt2-TOD App
3
- emoji: 🐒
4
- colorFrom: purple
5
- colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 3.0.24
8
  app_file: app.py
 
1
  ---
2
+ title: GPT2 multi-TOD
3
+ emoji: 🌎
4
+ colorFrom: pink
5
+ colorTo: blue
6
  sdk: gradio
7
  sdk_version: 3.0.24
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Tutorial followed for the chat UI: https://gradio.app/creating_a_chatbot/

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re

# GPT-2 checkpoint fine-tuned end-to-end on the Schema-Guided Dialogues
# dataset; loaded once at import time so every predict() call reuses it.
ckpt = 'armandnlp/gpt2-TOD_finetuned_SGD'
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt)
13
def format_resp(system_resp):
    """Replace the model's Belief/Action/Response tags with readable labels."""
    tag_labels = {
        '<|belief|>': '*Belief State: ',
        '<|action|>': '*Actions: ',
        '<|response|>': '*System Response: ',
    }
    for tag, label in tag_labels.items():
        system_resp = system_resp.replace(tag, label)
    return system_resp
19
+
20
+
21
+
22
def predict(input, history=[]):
    """Generate the next system turn of the task-oriented dialogue.

    Args:
        input: the new user utterance (plain text).
        history: state from the previous call — a list containing one list of
            token ids for the whole conversation so far (user inputs plus the
            model's full outputs, belief/action sequences included). May be
            empty, or None on the very first turn when coming from gradio state.

    Returns:
        (resps, history): resps is a list of (user, system) string pairs for
        the chatbot widget, with each system turn made readable by
        format_resp(); history is the updated token-id state to feed back in.
    """
    # Compiled once and reused; raw string avoids invalid-escape warnings.
    turn_splitter = re.compile(r'<\|system\|>|<\|user\|>')

    # 2-D empty long tensor so torch.cat below always sees matching ranks.
    empty_ids = torch.zeros((1, 0), dtype=torch.long)

    if history:
        # The model's context must contain only user and system responses —
        # no belief or action sequences — so clean up the history first.
        # history is a list of token ids representing all previous states in
        # the conversation, i.e. tokenized user inputs + tokenized model outputs.
        history_str = tokenizer.decode(history[0])
        turns = turn_splitter.split(history_str)[1:]
        for i in range(0, len(turns) - 1, 2):
            turns[i] = '<|user|>' + turns[i]
            # keep only the response part of each system output in the history
            # (drop belief and action)
            turns[i + 1] = '<|system|>' + turns[i + 1].split('<|response|>')[1]
        history4input = tokenizer.encode(''.join(turns), return_tensors='pt')
    else:
        # First turn (history is [] or None).
        history4input = empty_ids

    # Model input: <|context|> + cleaned history + new input + <|endofcontext|>
    new_user_input_ids = tokenizer.encode(' <|user|> ' + input, return_tensors='pt')
    context = tokenizer.encode('<|context|>', return_tensors='pt')
    endofcontext = tokenizer.encode(' <|endofcontext|>', return_tensors='pt')
    model_input = torch.cat([context, history4input, new_user_input_ids, endofcontext], dim=-1)

    # Generate output. 50262 is presumably the id of this checkpoint's
    # end-of-sequence special token — TODO confirm against the tokenizer.
    out = model.generate(model_input, max_length=1024, eos_token_id=50262).tolist()[0]

    # Keep everything generated after the context (belief + action + response),
    # leaving out the <|endof...|> markers.
    string_out = tokenizer.decode(out)
    system_out = (string_out.split('<|endofcontext|>')[1]
                  .replace('<|endofbelief|>', '')
                  .replace('<|endofaction|>', '')
                  .replace('<|endofresponse|>', ''))
    resp_tokenized = tokenizer.encode(' <|system|> ' + system_out, return_tensors='pt')
    prev_ids = torch.LongTensor(history) if history else empty_ids
    history = torch.cat([prev_ids, new_user_input_ids, resp_tokenized], dim=-1).tolist()
    # history = history + last user input + <|system|> <|belief|> ... <|action|> ... <|response|>...

    # Format responses to print out. All turns are rebuilt from the *full*
    # history — which is why it must keep belief + action info even though we
    # strip them from the model input above.
    turns = turn_splitter.split(tokenizer.decode(history[0]))[1:]
    # List of (user, system) tuples, one per exchange; system turns are made
    # readable by format_resp().
    resps = [(turns[i], format_resp(turns[i + 1])) for i in range(0, len(turns) - 1, 2)]

    return resps, history
68
+
69
+
70
+
71
# Starter prompts covering the SGD domains; gradio expects one list per example.
_example_prompts = (
    "I want to book a restaurant for 2 people on Saturday.",
    "What's the weather in Cambridge today ?",
    "I need to find a bus to Boston.",
    "I want to add an event to my calendar.",
    "I would like to book a plane ticket to New York.",
    "I want to find a concert around LA.",
    "Hi, I'd like to find an apartment in London please.",
    "Can you find me a hotel room near Seattle please ?",
    "I want to watch a film online, a comedy would be nice",
    "I want to transfer some money please.",
    "I want to reserve a movie ticket for tomorrow evening",
    "Can you play the song Learning to Fly by Tom Petty ?",
    "I need to rent a small car.",
)
examples = [[prompt] for prompt in _example_prompts]
85
+
86
# Short blurb shown under the title in the gradio UI.
description = """
This is an interactive window to chat with GPT-2 fine-tuned on the Schema-Guided Dialogues dataset,
in which we find domains such as travel, weather, media, calendar, banking,
restaurant booking...
"""

# Markdown rendered below the demo (gradio's `article` argument).
article = """
### Model Outputs
This task-oriented dialogue system is trained end-to-end, following the method detailed in
[SimpleTOD](https://arxiv.org/pdf/2005.00796.pdf), where GPT-2 is trained by casting task-oriented
dialogue as a seq2seq task.

From the dialogue history, composed of the previous user and system responses, the model is trained
to output the belief state, the action decisions and the system response as a sequence. We show all
three outputs in this demo : the belief state tracks the user goal (restaurant cuisine : Indian or media
genre : comedy for ex.), the action decisions show how the system should proceed (restaurants request city
or media offer title for ex.) and the natural language response provides an output the user can interpret.

The model responses are *de-lexicalized* : database values in the training set have been replaced with their
slot names to make the learning process database agnostic. These slots are meant to later be replaced by actual
results from a database, using the belief state to issue calls.

The model is capable of dealing with multiple domains : a list of possible inputs is provided to get the
conversation going.

### Dataset
The SGD dataset ([blogpost](https://ai.googleblog.com/2019/10/introducing-schema-guided-dialogue.html) and
[article](https://arxiv.org/pdf/1909.05855.pdf)) contains multiple task domains... Here is a list of some
of the services and their descriptions from the dataset:
* **Restaurants** : A leading provider for restaurant search and reservations
* **Weather** : Check the weather for any place and any date
* **Buses** : Find a bus to take you to the city you want
* **Calendar** : Calendar service to manage personal events and reservations
* **Flights** : Find your next flight
* **Events** : Get tickets for the coolest concerts and sports in your area
* **Homes** : A widely used service for finding apartments and scheduling visits
* **Hotels** : A popular service for searching and reserving rooms in hotels
* **Media** : A leading provider of movies for searching and watching on-demand
* **Banks** : Manage bank accounts and transfer money
* **Movies** : A go-to provider for finding movies, searching for show times and booking tickets
* **Music** : A popular provider of a wide range of music content for searching and listening
* **RentalCars** : Car rental service with extensive coverage of locations and cars
"""
129
+
130
+
131
import gradio as gr

# Wire predict() into a gradio chat UI; "state" carries the token-id history
# between turns and "chatbot" renders the (user, system) pairs.
demo = gr.Interface(
    fn=predict,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="Chatting with multi task-oriented GPT2",
    examples=examples,
    description=description,
    article=article,
)
demo.launch()
141
+
142
+
143
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch