j.gilyazev commited on
Commit
deb7fd3
1 Parent(s): c1c5bd9

add personalized-chat-bot

Browse files
personalized-chat-bot/bot_example.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+
3
+ import argparse
4
+ import json
5
+
6
+ from petals.client.remote_model import DistributedBloomForCausalLM
7
+
8
+ from personalized_chat_bot import PersonalizedChatBot, PersonalityManager
9
+ from models.personality_clustering import PersonalityClustering
10
+
11
+ def load_config(path):
12
+ with open(path, 'r') as f:
13
+ config = json.load(f)
14
+ return argparse.Namespace(**config)
15
+
16
+
17
+ def main():
18
+ greating = 'Describe the person you want to talk:'
19
+ print(greating)
20
+ persona_description = input()
21
+ print('Cool! wait a few seconds...')
22
+ personality_clustering = PersonalityClustering()
23
+ personality_clustering.load('./data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl')
24
+
25
+ hook = lambda dct: {int(k): v for k, v in dct.items()}
26
+ with open('prompt_paths.json', 'r') as f:
27
+ prompt_paths = json.load(f, object_hook=hook)
28
+
29
+ pm = PersonalityManager(prompt_paths, personality_clustering)
30
+ prompt_path, closest_persona = pm.get_prompt(persona_description)
31
+ print(f'The closest personality is: {closest_persona}')
32
+ print('Wait a little longer...')
33
+ config = load_config('./scripts/config_176b.json')
34
+
35
+ model = DistributedBloomForCausalLM.from_pretrained(
36
+ config.MODEL_NAME,
37
+ pre_seq_len=config.NUM_PREFIX_TOKENS,
38
+ tuning_mode=config.TUNING_MODE
39
+ ).to(config.DEVICE)
40
+
41
+ generation_config = load_config('generation_config.json')
42
+
43
+ tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
44
+ tokenizer.padding_side = 'right'
45
+ tokenizer.model_max_length = config.MODEL_MAX_LENGTH
46
+
47
+ chatbot = PersonalizedChatBot(model, tokenizer, generation_config=generation_config)
48
+ chatbot.load_prompt(prompt_path)
49
+ print('Done! You can start a dialogue.')
50
+ try:
51
+ while True:
52
+ text = input('You: ')
53
+ answer = chatbot.answer(text)
54
+ print(f'Bloom: {answer}')
55
+ except KeyboardInterrupt:
56
+ print('Thank you for the conversation!')
57
+
58
+
59
+ if __name__ == '__main__':
60
+ main()
personalized-chat-bot/personalized_chat_bot.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import torch
4
+ from sklearn.neighbors import KDTree
5
+
6
+
7
+ class PersonalityManager:
8
+ def __init__(self, prompt_paths, personality_clustering):
9
+ self.prompt_paths = prompt_paths
10
+ self.personality_clustering = personality_clustering
11
+
12
+ self.persona_ids = list(prompt_paths.keys())
13
+ self.personalities = [personality_clustering._cluster_centers[i]
14
+ for i in self.persona_ids]
15
+
16
+ self.embeddings = personality_clustering.sentence_transformer.encode(self.personalities)
17
+ self._nearest_neighbours = KDTree(self.embeddings, metric='euclidean')
18
+
19
+ def get_prompt(self, description):
20
+ embedding = self.personality_clustering.sentence_transformer.encode([description])
21
+ dist, ind = self._nearest_neighbours.query(embedding, k=1)
22
+ persona_id = self.persona_ids[ind[0][0]]
23
+ prompt_path = self.prompt_paths[persona_id]
24
+ cluster_center = self.personality_clustering._cluster_centers[persona_id]
25
+ return prompt_path, cluster_center
26
+
27
+
28
+ class PersonalizedChatBot:
29
+ def __init__(self, model, tokenizer, prompt_path=None, generation_config=None):
30
+ self.model = model
31
+ if prompt_path is not None:
32
+ self.load_prompt(prompt_path)
33
+ self.tokenizer = tokenizer
34
+ self.separator = '\n'
35
+ self.dialog = ''
36
+ self.generation_config = generation_config
37
+
38
+ def load_prompt(self, path):
39
+ self.model.transformer.prompt_embeddings.load_state_dict(torch.load(path))
40
+
41
+ def load_config(self, path):
42
+ with open(path, 'r') as f:
43
+ config = json.load(f)
44
+ self.generation_config = argparse.Namespace(**config)
45
+
46
+ def reset_dialog(self, ):
47
+ self.dialog = ''
48
+
49
+ def answer(self, phrase):
50
+ if len(phrase) == 0:
51
+ return
52
+ self.dialog += f"{phrase}{self.separator}"
53
+ inputs = self.tokenizer([self.dialog], return_tensors='pt')['input_ids']
54
+ outputs = self.model.generate(
55
+ inputs,
56
+ temperature=self.generation_config.TEMPERATURE,
57
+ do_sample=True,
58
+ top_k=self.generation_config.TOP_K,
59
+ eos_token_id=self.tokenizer.eos_token_id,
60
+ max_new_tokens=self.generation_config.MAX_TOKENS,
61
+ )
62
+ bloom_answer = self.tokenizer.batch_decode(outputs)[0]
63
+ bloom_answer = bloom_answer[len(self.dialog):].split("\n")[0]
64
+ self.dialog += f"{bloom_answer}{self.separator}"
65
+ return bloom_answer