af1tang committed
Commit d6697da
1 Parent(s): b5dcaaa

first private commit
README.md ADDED

---
tags:
- conversational
license: gpl-3.0
---
## A conversational agent with many personalities (PersonaGPT)
PersonaGPT is an open-domain conversational agent capable of decoding _personalized_ and _controlled_ responses based on user input.
It builds on the [DialoGPT-medium](https://huggingface.co/microsoft/DialoGPT-medium) pretrained model, which is based on the [GPT-2](https://github.com/openai/gpt-2) architecture.
This model is trained on the [Persona-Chat](https://arxiv.org/pdf/1801.07243) dataset, with added special tokens to better distinguish between conversational history and personality traits for dyadic conversations. Furthermore, some active learning was used to train the model to do _controlled_ decoding based on certain "action codes" (e.g., "talk about work", "ask about favorite music").

## Full Repo

Preprocessing, training, and implementation details can be found in the [personaGPT repo](https://github.com/af1tang/personaGPT).

### How to Use

1. Load the model and define some helper functions.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("af1tang/personaGPT")
model = AutoModelForCausalLM.from_pretrained("af1tang/personaGPT")
if torch.cuda.is_available():
    model = model.cuda()

## utility functions ##
flatten = lambda l: [item for sublist in l for item in sublist]

def to_data(x):
    if torch.cuda.is_available():
        x = x.cpu()
    return x.data.numpy()

def to_var(x):
    if not torch.is_tensor(x):
        x = torch.Tensor(x)
    if torch.cuda.is_available():
        x = x.cuda()
    return x

def display_dialog_history(dialog_hx):
    for j, line in enumerate(dialog_hx):
        msg = tokenizer.decode(line)
        if j % 2 == 0:
            print(">> User: " + msg)
        else:
            print("Bot: " + msg)
            print()

def generate_next(bot_input_ids, do_sample=True, top_k=10, top_p=.92,
                  max_length=1000, pad_token=tokenizer.eos_token_id):
    full_msg = model.generate(bot_input_ids, do_sample=do_sample,
                              top_k=top_k, top_p=top_p,
                              max_length=max_length, pad_token_id=pad_token)
    # keep only the newly generated tokens (everything after the input prompt)
    msg = to_data(full_msg.detach()[0])[bot_input_ids.shape[-1]:]
    return msg
```
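
A quick smoke test, assuming the model and helpers above loaded successfully (the prompt here is illustrative and not part of the model card):

```python
# hypothetical single-turn check: encode one user turn, sample a reply
test_input = tokenizer.encode("hi, how are you doing today?" + tokenizer.eos_token)
reply = generate_next(to_var([test_input]).long())
print(tokenizer.decode(reply, skip_special_tokens=True))
```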

2. Give your chatbot partner a set of personalities.

```python
# get personality facts for conversation
personas = []
for i in range(3):
    response = input(">> Fact %d: " % (i + 1)) + tokenizer.eos_token
    personas.append(response)
personas = tokenizer.encode(''.join(['<|p2|>'] + personas + ['<|sep|>'] + ['<|start|>']))
```
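
To confirm the prefix layout, you can decode `personas` back to text; given the encoding above, it should read as `<|p2|>` followed by the three facts (each ending in `<|endoftext|>`), then `<|sep|><|start|>`:

```python
# inspect the persona prefix the bot will be conditioned on
print(tokenizer.decode(personas))
```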

3. The first use of PersonaGPT is to do _personalized_ dialog generation. Use the following loop to interact with the model.

```python
# converse for 8 turns
dialog_hx = []
for step in range(8):
    # encode the user input
    user_inp = tokenizer.encode(input(">> User: ") + tokenizer.eos_token)
    # append to the chat history
    dialog_hx.append(user_inp)

    # generate a response while limiting the total chat history to 1000 tokens
    bot_input_ids = to_var([personas + flatten(dialog_hx)]).long()
    msg = generate_next(bot_input_ids)
    dialog_hx.append(msg)
    print("Bot: {}".format(tokenizer.decode(msg, skip_special_tokens=True)))
```
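
Note that the `max_length=1000` default in `generate_next` caps the combined length of the prompt and the sampled response, and the model's context window (`n_positions` in `config.json`) is 1024 tokens. For longer chats you may want to trim old turns explicitly; a minimal sketch, with a truncation policy that is an assumption rather than part of the original loop:

```python
# hypothetical trimming step: drop the oldest user/bot exchange until
# the persona prefix plus flattened history fits the token budget
MAX_HISTORY_TOKENS = 1000
while len(personas) + len(flatten(dialog_hx)) > MAX_HISTORY_TOKENS:
    dialog_hx = dialog_hx[2:]
```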

Example of personalized decoding:

| | Persona Facts |
|---|---|
| 1 | i am 64 years old but feel quite young. |
| 2 | i try to eat healthy but limit mcdonalds to once a week. |
| 3 | i regret working as a doctor for the last 20 years. |
| 4 | my secret hobby is making self-help youtube videos. |
| 5 | i've been to spain many times and wish i could go again. |

| Role | Response |
|---|---|
| Human | yo whats good fam? |
| Bot | hello, how are you doing? |
| Human | not bad what you doing today? |
| Bot | eating a lot healthier now that i'm in my 60s |
| Human | i'm more of a fast food guy |
| Bot | me too but i'm very healthy so i try to avoid it |
| Human | what do you do for fun? |
| Bot | i make videos for fun and eat very healthy |
| Human | nice what kind of videos? |
| Bot | i love making videos about self help stuff |
| Human | i think i need some help |
| Bot | what kind of help would you like to get |
| Human | i'm 21 and feel lost in life |
| Bot | i can relate, i'm only 65 |
119
+
120
+ 4. The second use of PersonaGPT is to do _controlled response generation_. Use the following program to interact with the model.
121
+
122
+ ```python
123
+ ## available actions ##
124
+ action_space = [ 'ask about kids.', "ask about pets.", 'talk about work.',
125
+ 'ask about marital status.', 'talk about travel.', 'ask about age and gender.',
126
+ 'ask about hobbies.', 'ask about favorite food.', 'talk about movies.',
127
+ 'talk about music.', 'talk about politics.']
128
+ # converse for 8 turns
129
+ dialog_hx = []
130
+ for step in range(8):
131
+ # choose an action
132
+ act = None
133
+ while act not in action_space:
134
+ display_dialog_history(dialog_hx)
135
+ print()
136
+ print(" actions: ")
137
+ for k,v in enumerate(action_space): print(k,v)
138
+ try:
139
+ act = action_space[int(input(" input [0-10]: " ))]
140
+ except:
141
+ act = None
142
+ print()
143
+ # format into prefix code
144
+ action_prefix = tokenizer.encode(''.join(['<|act|> '] + [act] + ['<|p1|>'] + [] + ['<|sep|>'] + ['<|start|>']))
145
+ bot_input_ids = to_var([action_prefix + flatten(dialog_hx)]).long()
146
+
147
+ # generate query conditioned on action
148
+ msg = generate_next(bot_input_ids, top_k=top_k, top_p=top_p, max_length=max_length)
149
+ dialog_hx.append(msg)
150
+
151
+ # generate bot response
152
+ bot_input_ids = to_var([personas+ flatten(dialog_hx)]).long()
153
+ msg = generate_next(bot_input_ids, top_k=top_k, top_p=top_p, max_length=max_length)
154
+ dialog_hx.append(msg)
155
+ display_dialog_history(dialog_hx)
156
+ ```
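
The loop above alternates between two prefix formats: an action prefix (`<|act|> ... <|p1|><|sep|><|start|>`) to generate the controlled query, and the persona prefix from step 2 (`<|p2|> ... <|sep|><|start|>`) to generate the response. A small helper, hypothetical rather than part of the original card, makes the two cases explicit:

```python
def build_prefix(act=None, persona_facts=None):
    # hypothetical helper: assemble either prefix format used above
    if act is not None:
        # controlled decoding: condition the next query on an action code
        parts = ['<|act|> ', act, '<|p1|>', '<|sep|>', '<|start|>']
    else:
        # personalized decoding: condition the response on persona facts
        parts = ['<|p2|>'] + (persona_facts or []) + ['<|sep|>', '<|start|>']
    return tokenizer.encode(''.join(parts))
```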
157
+
158
+
159
+ Example of controlled response generation:
160
+
161
+
162
+ | | Bot 2's Persona Facts |
163
+ |---|---|
164
+ | 1 | i am a quiet engineer. |
165
+ | 2 | i'm single and am looking for love. |
166
+ | 3 | sadly, i don't have any relatable hobbies. |
167
+ | 4 | luckily, however, i am tall and athletic. |
168
+ | 5 | on friday nights, i watch re-runs of the simpsons alone. |
169
+
170
+ | Role | Action | Response |
171
+ |---|---|---|
172
+ | Bot 1 | talk about movies | hello do you like movies? |
173
+ | Bot 2 | - | i love the simpsons. how about you? |
174
+ | Bot 1 | talk about music | i like the walking dead. do you play an instrument? |
175
+ | Bot 2 | - | yes i play the violin. |
176
+ | Bot 1 | ask about marital status | are you married or single? |
177
+ | Bot 2 | - | i am single but i am looking for love. |
178
+ | Bot 1 | talk about work | my job is working in a factory. what's your occupation? |
179
+ | Bot 2 | - | engineer. i'm very quiet so no one hears me. |
180
+ | Bot 1 | talk about hobbies | do you have any hobbies? |
181
+ | Bot 2 | - | i watch reruns of the simpsons. |
182
+ | Bot 1 | ask about favorite food | what's your favorite food? |
183
+ | Bot 2 | - | i love pizza. how about yourself? |
184
+ | Bot 1 | ask about pets | i also love pizza. do you like animals? |
185
+ | Bot 2 | - | i have two dogs. what is your occupation? |
186
+ | Bot 1 | talk about work | i'm a factory worker. what's your dream job? |
187
+ | Bot 2 | - | i'd love to be a writer one day. |

added_tokens.json ADDED

{"<|sep|>": 50257, "<|cls|>": 50258, "<|start|>": 50259, "<|p1|>": 50260, "<|p2|>": 50261}

config.json ADDED

{
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1024,
  "n_head": 16,
  "n_layer": 24,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "vocab_size": 50262
}
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
pytorch_model.bin ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:d7582d23d2d19e37339c15685b171dce5e6ea98787a82b40ff3c4ad8011a87a6
size 1444551925

special_tokens_map.json ADDED

{"bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "sep_token": "<|sep|>", "pad_token": "<|endoftext|>", "cls_token": "<|cls|>", "additional_special_tokens": ["<|start|>", "<|p1|>", "<|p2|>"]}

tokenizer_config.json ADDED

{"pad_token": "<|endoftext|>", "cls_token": "<|cls|>", "sep_token": "<|sep|>", "special_tokens_map_file": null, "full_tokenizer_file": null}

vocab.json ADDED
The diff for this file is too large to render. See raw diff