gosha6037 committed on
Commit
62851f3
2 Parent(s): b649ec8 8ef0aaa

Merge branch 'bloom-personachat' of https://huggingface.co/spaces/hivemind-personalized-chat/chat-gradio

Browse files
personalized-chat-bot/bot_example.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+
3
+ import argparse
4
+ import json
5
+
6
+ from petals.client.remote_model import DistributedBloomForCausalLM
7
+
8
+ from personalized_chat_bot import PersonalizedChatBot, PersonalityManager
9
+ from models.personality_clustering import PersonalityClustering
10
+
11
def load_config(path):
    """Read a JSON file and expose its top-level keys as attributes.

    Args:
        path: Path to a JSON file whose top level is an object.

    Returns:
        An ``argparse.Namespace`` with one attribute per top-level key.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(path, 'r', encoding='utf-8') as config_file:
        return argparse.Namespace(**json.load(config_file))
15
+
16
+
17
def main():
    """Run an interactive console chat with a persona-conditioned Bloom model.

    Asks for a persona description, picks the closest trained persona prompt,
    loads the distributed model and tokenizer, then loops over user input
    until the user interrupts (Ctrl+C) or stdin is closed.
    """
    greeting = 'Describe the person you want to talk:'
    print(greeting)
    persona_description = input()
    print('Cool! wait a few seconds...')
    personality_clustering = PersonalityClustering()
    personality_clustering.load('./data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl')

    # JSON object keys are always strings; persona ids are used as integers.
    with open('prompt_paths.json', 'r', encoding='utf-8') as f:
        prompt_paths = {int(k): v for k, v in json.load(f).items()}

    pm = PersonalityManager(prompt_paths, personality_clustering)
    prompt_path, closest_persona = pm.get_prompt(persona_description)
    print(f'The closest personality is: {closest_persona}')
    print('Wait a little longer...')
    config = load_config('./scripts/config_176b.json')

    model_kwargs = {
        'pre_seq_len': config.NUM_PREFIX_TOKENS,
        'tuning_mode': config.TUNING_MODE,
    }
    # Honour explicitly configured peers, as the training script does.
    initial_peers = getattr(config, 'INITIAL_PEERS', None)
    if initial_peers:
        model_kwargs['initial_peers'] = initial_peers
    model = DistributedBloomForCausalLM.from_pretrained(
        config.MODEL_NAME,
        **model_kwargs
    ).to(config.DEVICE)

    generation_config = load_config('generation_config.json')

    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = 'right'
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH

    chatbot = PersonalizedChatBot(model, tokenizer, generation_config=generation_config)
    chatbot.load_prompt(prompt_path)
    print('Done! You can start a dialogue.')
    try:
        while True:
            text = input('You: ')
            # The bot returns None for an empty phrase; just prompt again
            # instead of printing "Bloom: None".
            if not text.strip():
                continue
            answer = chatbot.answer(text)
            print(f'Bloom: {answer}')
    except (KeyboardInterrupt, EOFError):
        # EOFError: stdin was closed (Ctrl+D or piped input exhausted).
        print('Thank you for the conversation!')
57
+
58
+
59
+ if __name__ == '__main__':
60
+ main()
personalized-chat-bot/data.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d73016d5eccc0eeb641f623789e6a80c601572aee825603bdfacf84c9e8f705
3
+ size 12635714
personalized-chat-bot/generation_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"MAX_TOKENS": 16, "TOP_K": 100, "TEMPERATURE": 0.8}
personalized-chat-bot/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # coding=utf-8
personalized-chat-bot/models/personality_clustering.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sentence_transformers import SentenceTransformer
3
+ from sklearn.cluster import KMeans
4
+ import pickle
5
+
6
+
7
class PersonalityClustering:
    """Cluster personality sentences by their sentence-transformer embeddings.

    After ``fit`` (or ``load``), ``_cluster_centers[i]`` holds the personality
    text that represents cluster ``i``: the training sentence whose embedding
    is closest to the i-th KMeans centroid.
    """

    DEFAULT_SENTENCE_TRANSFORMER = 'paraphrase-MiniLM-L6-v2'

    @property
    def sentence_transformer(self):
        """Lazily construct and return the sentence encoder."""
        # Compare with None explicitly: a model object may define __len__ and
        # must not be rebuilt just because it evaluates as falsy.
        if self.__sentence_transformer is None:
            self.__sentence_transformer = SentenceTransformer(self.model_name, device=self.device)
        return self.__sentence_transformer

    @property
    def clustering(self):
        """Lazily construct and return the KMeans clustering model."""
        if self.__clustering is None:
            self.__clustering = KMeans(n_clusters=self.n_clusters)
        return self.__clustering

    def __init__(self, n_clusters=None, device='cpu', model_name=None):
        """Store the configuration; no model is built until first use.

        Args:
            n_clusters: Number of clusters passed to KMeans when it is built.
            device: Device passed to the sentence transformer.
            model_name: Sentence-transformer name; falls back to
                ``DEFAULT_SENTENCE_TRANSFORMER`` when None.
        """
        if model_name is None:
            model_name = self.DEFAULT_SENTENCE_TRANSFORMER
        self.model_name = model_name
        self.device = device
        self.n_clusters = n_clusters
        self._cluster_centers = None
        self.__clustering = None
        self.__sentence_transformer = None

    def load(self, path):
        """Restore the clustering model and cluster-center texts from ``path``."""
        # NOTE: pickle can execute arbitrary code while loading; only load
        # files that come from a trusted source.
        with open(path, "rb") as f:
            self.__clustering, self._cluster_centers = pickle.load(f)

    def save(self, path):
        """Pickle the clustering model and cluster-center texts to ``path``."""
        with open(path, "wb") as f:
            pickle.dump((self.__clustering, self._cluster_centers), f)

    def fit(self, personalities):
        """Fit the clustering and pick one representative text per cluster.

        Args:
            personalities: Iterable of personality sentences.

        Returns:
            self, to allow call chaining.
        """
        personalities = np.array(list(personalities))
        train_embeddings = self.sentence_transformer.encode(personalities)
        clusters = self.clustering.fit_predict(train_embeddings)
        persona_cluster_centers = []
        for clust, center in enumerate(self.clustering.cluster_centers_):
            members = clusters == clust
            if not members.any():
                # No sentence was assigned to this cluster: keep a placeholder
                # so that list positions still match cluster ids.
                persona_cluster_centers.append(None)
                continue
            distances = np.linalg.norm(train_embeddings[members] - center, axis=1)
            # argmin returns the first minimum, i.e. the earliest closest member.
            persona_cluster_centers.append(personalities[members][np.argmin(distances)])
        self._cluster_centers = np.array(persona_cluster_centers)
        return self

    def predict(self, personalities):
        """Return the cluster id predicted for every given personality sentence."""
        personalities = np.array(list(personalities))
        embeddings = self.sentence_transformer.encode(personalities)
        return self.clustering.predict(embeddings)

    def predict_nearest_personality(self, personalities):
        """Return, for every input sentence, the representative text of its cluster."""
        clusters = self.predict(personalities)
        return np.array([self._cluster_centers[clust] for clust in clusters])

    def fit_predict(self, personalities):
        """Fit on ``personalities`` and return their cluster ids."""
        self.fit(personalities)
        return self.predict(personalities)
personalized-chat-bot/personalized_chat_bot.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import torch
4
+ from sklearn.neighbors import KDTree
5
+
6
+
7
class PersonalityManager:
    """Map a free-text persona description to the closest available prompt."""

    def __init__(self, prompt_paths, personality_clustering):
        """Index the personas that have a prompt file.

        Args:
            prompt_paths: Mapping from persona id to prompt file path.
            personality_clustering: Object providing ``_cluster_centers`` and
                a ``sentence_transformer`` with an ``encode`` method.
        """
        self.prompt_paths = prompt_paths
        self.personality_clustering = personality_clustering

        self.persona_ids = list(prompt_paths.keys())
        centers = personality_clustering._cluster_centers
        self.personalities = [centers[persona_id] for persona_id in self.persona_ids]

        encoder = personality_clustering.sentence_transformer
        self.embeddings = encoder.encode(self.personalities)
        self._nearest_neighbours = KDTree(self.embeddings, metric='euclidean')

    def get_prompt(self, description):
        """Return ``(prompt_path, cluster_center)`` for the nearest persona."""
        encoder = self.personality_clustering.sentence_transformer
        query_embedding = encoder.encode([description])
        _, neighbour_index = self._nearest_neighbours.query(query_embedding, k=1)
        # Single query with k=1, so the only hit sits at position [0][0].
        persona_id = self.persona_ids[neighbour_index[0][0]]
        return (
            self.prompt_paths[persona_id],
            self.personality_clustering._cluster_centers[persona_id],
        )
26
+
27
+
28
class PersonalizedChatBot:
    """Line-based chat wrapper around a prompt-tuned causal language model.

    The whole conversation is kept in ``self.dialog`` as separator-terminated
    lines and is fed to the model in full on every turn.
    """

    def __init__(self, model, tokenizer, prompt_path=None, generation_config=None):
        """Store the collaborators and optionally load a persona prompt.

        Args:
            model: Model exposing ``generate`` and
                ``transformer.prompt_embeddings``.
            tokenizer: Tokenizer matching ``model``.
            prompt_path: Optional path to a saved prompt-embedding state dict.
            generation_config: Object with ``TEMPERATURE``, ``TOP_K`` and
                ``MAX_TOKENS`` attributes; may also be set via ``load_config``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.separator = '\n'
        self.dialog = ''
        self.generation_config = generation_config
        if prompt_path is not None:
            self.load_prompt(prompt_path)

    def load_prompt(self, path):
        """Load persona prompt embeddings from ``path`` into the model."""
        # map_location lets a prompt saved from a GPU run load on a CPU-only
        # host; load_state_dict then copies into the existing parameters.
        state_dict = torch.load(path, map_location='cpu')
        self.model.transformer.prompt_embeddings.load_state_dict(state_dict)

    def load_config(self, path):
        """Read generation settings from a JSON file at ``path``."""
        with open(path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        self.generation_config = argparse.Namespace(**config)

    def reset_dialog(self):
        """Forget the conversation so far."""
        self.dialog = ''

    def answer(self, phrase):
        """Append ``phrase`` to the dialog and return the model's reply.

        Args:
            phrase: The user's utterance.

        Returns:
            The first generated line, or None when ``phrase`` is empty.

        Raises:
            ValueError: If no generation config has been set.
        """
        if not phrase:
            return None
        if self.generation_config is None:
            raise ValueError('generation_config is not set; pass it to the '
                             'constructor or call load_config() first')
        self.dialog += f"{phrase}{self.separator}"
        inputs = self.tokenizer([self.dialog], return_tensors='pt')['input_ids']
        outputs = self.model.generate(
            inputs,
            temperature=self.generation_config.TEMPERATURE,
            do_sample=True,
            top_k=self.generation_config.TOP_K,
            eos_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=self.generation_config.MAX_TOKENS,
        )
        # Cut the prompt off by token count rather than by character length of
        # the decoded text: decoding is not guaranteed to reproduce the prompt
        # string exactly, which would shift a character-based slice.
        # Assumes generate() returns the prompt tokens followed by the new ones.
        new_tokens = outputs[:, inputs.shape[1]:]
        generated = self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
        # Only the first generated line is the reply.
        bloom_answer = generated.split(self.separator)[0]
        self.dialog += f"{bloom_answer}{self.separator}"
        return bloom_answer
personalized-chat-bot/prompt_paths.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "113": "./data/models/176b/113_persona_prompt_embedding.pt",
3
+ "54": "./data/models/176b/54_persona_prompt_embedding.pt",
4
+ "169": "./data/models/176b/169_persona_prompt_embedding.pt",
5
+ "364": "./data/models/176b/364_persona_prompt_embedding.pt",
6
+ "214": "./data/models/176b/214_persona_prompt_embedding.pt",
7
+ "125": "./data/models/176b/125_persona_prompt_embedding.pt",
8
+ "103": "./data/models/176b/103_persona_prompt_embedding.pt",
9
+ "200": "./data/models/176b/200_persona_prompt_embedding.pt",
10
+ "296": "./data/models/176b/296_persona_prompt_embedding.pt",
11
+ "20": "./data/models/176b/20_persona_prompt_embedding.pt",
12
+ "384": "./data/models/176b/384_persona_prompt_embedding.pt",
13
+ "365": "./data/models/176b/365_persona_prompt_embedding.pt",
14
+ "451": "./data/models/176b/451_persona_prompt_embedding.pt",
15
+ "80": "./data/models/176b/80_persona_prompt_embedding.pt"
16
+ }
personalized-chat-bot/scripts/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # coding=utf-8
personalized-chat-bot/scripts/config_176b.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "PERSONACHAT_DATASET_NAME": "bavard/personachat_truecased",
3
+ "MODEL_NAME": "bigscience/bloom-petals",
4
+ "INITIAL_PEERS": [],
5
+ "NUM_PREFIX_TOKENS": 16,
6
+ "DEVICE": "cpu",
7
+ "BATCH_SIZE": 4,
8
+ "LR": 0.01,
9
+ "WEIGHT_DECAY": 0.0,
10
+ "NUM_SAMPLES": 1000,
11
+ "SEED": 42,
12
+ "MODEL_MAX_LENGTH": 256,
13
+ "TUNING_MODE": "ptune",
14
+ "N_EPOCH": 10,
15
+ "PADDING_SIDE": "right"
16
+ }
personalized-chat-bot/scripts/config_6b.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "PERSONACHAT_DATASET_NAME": "bavard/personachat_truecased",
3
+ "MODEL_NAME": "bigscience/test-bloomd-6b3",
4
+ "INITIAL_PEERS":["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"],
5
+ "NUM_PREFIX_TOKENS": 16,
6
+ "DEVICE": "cpu",
7
+ "BATCH_SIZE": 4,
8
+ "LR": 0.01,
9
+ "WEIGHT_DECAY": 0.0,
10
+ "NUM_SAMPLES": 1000,
11
+ "SEED": 42,
12
+ "MODEL_MAX_LENGTH": 256,
13
+ "TUNING_MODE": "ptune",
14
+ "N_EPOCH": 1,
15
+ "PADDING_SIDE": "right"
16
+ }
personalized-chat-bot/scripts/fit_personality_clustering.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from datasets import load_dataset
3
+ from models.personality_clustering import PersonalityClustering
4
+ import os
5
+
6
+ """Пример запуска
7
+ python -m scripts.fit_personality_clustering --clustering-path data/models --n-clusters 500
8
+ """
9
+
10
+ PERSONACHAT_DATASET = "bavard/personachat_truecased"
11
+
12
+
13
def load_persona_chat_personalities(personachat_dataset):
    """Collect the unique personality sentences of a PersonaChat dataset.

    Args:
        personachat_dataset: Dataset name passed to ``load_dataset``.

    Returns:
        Sorted list of distinct personality sentences taken from the
        ``train`` and ``validation`` splits (whichever of them exist).
    """
    dataset = load_dataset(personachat_dataset)
    personalities = set()
    # The held-out sentences must come from the validation split; reading
    # 'train' a second time would add nothing to the union.
    for split_name in ('train', 'validation'):
        if split_name not in dataset:
            continue
        for persona in dataset[split_name]['personality']:
            personalities.update(persona)
    # Sort so the result does not depend on set iteration order.
    return sorted(personalities)
21
+
22
+
23
def parse_args(args=None):
    """Parse the command line of the clustering fit script.

    Args:
        args: Optional list of argument strings; ``sys.argv`` is used when None.

    Returns:
        Namespace with ``clustering_path``, ``n_clusters`` and ``model_name``.
    """
    parser = argparse.ArgumentParser(add_help=True, description="Class for personality clustering.")

    # Required: the fitted model is written here, so a missing value should be
    # rejected up front rather than after the fit has finished.
    parser.add_argument('-clustering-path', '--clustering-path', type=str, required=True,
                        help='Path to clustering data.')
    parser.add_argument('-n-clusters', '--n-clusters', type=int, default=500,
                        help='The number of clusters to form.')
    parser.add_argument('-model-name', '--model-name', type=str, default=None, required=False,
                        help='File name for the saved model; generated when omitted.')
    return parser.parse_args(args)
33
+
34
+
35
def main():
    """Fit personality clustering on PersonaChat and save it to disk."""
    args = parse_args()
    # Validate and prepare the output location before the expensive fit so a
    # bad path cannot waste the whole run.
    output_dir = args.clustering_path
    if output_dir is None:
        raise ValueError('--clustering-path must be provided')
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    personalities = load_persona_chat_personalities(PERSONACHAT_DATASET)
    print('Data loaded')
    model = PersonalityClustering(n_clusters=args.n_clusters)
    print('Model fitting')
    model.fit(personalities)
    print('Model fitted')
    if args.model_name is None:
        model_name = f'personality_clustering_{model.n_clusters}_{model.model_name}_k-means.pkl'
    else:
        model_name = args.model_name
    model.save(os.path.join(output_dir, model_name))
    print(f'{model_name} saved')
49
+
50
+
51
+ if __name__ == '__main__':
52
+ main()
personalized-chat-bot/scripts/train_all.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ #python -m scripts.train_bloom_personachat --persona-ids 113 54 169 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
4
+ #python -m scripts.train_bloom_personachat --persona-ids 364 214 125 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
5
+ #python -m scripts.train_bloom_personachat --persona-ids 103 200 296 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
6
+ #python -m scripts.train_bloom_personachat --persona-ids 20 384 365 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
7
+ #python -m scripts.train_bloom_personachat --persona-ids 208 43 99 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
8
+ #python -m scripts.train_bloom_personachat --persona-ids 426 477 470 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
9
+ python -m scripts.train_bloom_personachat --persona-ids 470 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
10
+
11
+ python -m scripts.train_bloom_personachat --persona-ids 329 402 382 --config scripts/config_176b.json --prompt-path data/models/176b/ --wandb-project bloom_personachat_176b
personalized-chat-bot/scripts/train_bloom_personachat.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch.cuda
4
+ from datasets import load_dataset
5
+ import json
6
+ import os
7
+ import transformers
8
+ from torch.utils.data import Subset
9
+ import wandb
10
+ import numpy as np
11
+ import gc
12
+
13
+ from models.personality_clustering import PersonalityClustering
14
+ from util.bloom_trainer import BloomTrainer
15
+ from util.data import PersonaChatDataset
16
+ from util.metrics import perplexity
17
+
18
+ from petals.client.remote_model import DistributedBloomForCausalLM
19
+
20
+ """Пример запуска
21
+ python -m scripts.train_bloom_personachat --persona-ids 6 --config scripts/config.json --prompt-path data/models/
22
+ """
23
+
24
+ DEFAULT_CLUSTERING_MODEL = './data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl'
25
+ MAX_VAL_DATA_SIZE = 4
26
+
27
+
28
def load_config(path):
    """Read a JSON training config and expose its keys as attributes.

    Args:
        path: Path to a JSON file whose top level is an object.

    Returns:
        An ``argparse.Namespace`` with one attribute per top-level key.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(path, 'r', encoding='utf-8') as config_file:
        return argparse.Namespace(**json.load(config_file))
32
+
33
+
34
def main():
    """Train and save one soft prompt per requested persona id.

    For every id a fresh model is created, trained with ``BloomTrainer``,
    evaluated with perplexity, saved to
    ``<prompt-path>/<id>_persona_prompt_embedding.pt`` and logged to W&B.
    """
    args = parse_args()
    persona_clustering = PersonalityClustering()
    persona_clustering.load(args.clustering_model_path)

    config = load_config(args.config)

    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = config.PADDING_SIDE
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH

    dataset = load_dataset(config.PERSONACHAT_DATASET_NAME)
    personachat_train_dataset = PersonaChatDataset(persona_clustering,
                                                   dataset['train'],
                                                   tokenizer)
    personachat_val_dataset = PersonaChatDataset(persona_clustering,
                                                 dataset['validation'],
                                                 tokenizer)

    # Build the model arguments once; initial peers are passed only when the
    # config actually lists some.
    model_kwargs = {
        'pre_seq_len': config.NUM_PREFIX_TOKENS,
        'tuning_mode': config.TUNING_MODE,
    }
    if len(config.INITIAL_PEERS) > 0:
        model_kwargs['initial_peers'] = config.INITIAL_PEERS

    for persona_id in args.persona_ids:
        prompt_path = os.path.join(args.prompt_path, f'{persona_id}_persona_prompt_embedding.pt')
        train_dataset = personachat_train_dataset[persona_id]
        val_dataset = personachat_val_dataset[persona_id]
        honest_validation = True
        if len(val_dataset) < 4:
            # Too few held-out dialogues: validate on the training data and
            # record that the reported metric is not a real held-out one.
            val_dataset = personachat_train_dataset[persona_id]
            honest_validation = False
        # To speed things up, cap the validation set at a fixed size.
        if len(val_dataset) > MAX_VAL_DATA_SIZE:
            subset_indexes = np.random.choice(len(val_dataset), MAX_VAL_DATA_SIZE, replace=False)
            val_dataset = Subset(val_dataset, subset_indexes)

        wandb_run = wandb.init(
            project=args.wandb_project,
            config={
                'lr': config.LR,
                'batch_size': config.BATCH_SIZE,
                'persona_id': persona_id,
                'device': config.DEVICE,
                'model_name': config.MODEL_NAME,
                'n_epoch': config.N_EPOCH,
                'honest_validation': honest_validation
            },
            name=f'id{persona_id}',
            reinit=True
        )
        try:
            model = DistributedBloomForCausalLM.from_pretrained(
                config.MODEL_NAME,
                **model_kwargs
            ).to(config.DEVICE)

            trainer = BloomTrainer(model, config, train_dataset, val_dataset, wandb_run, prompt_path)
            trainer.train()
            eval_perplexity = trainer.evaluate(perplexity)
            trainer.save_model(prompt_path)
            wandb_run.log({'perplexity': eval_perplexity, 'model_path': prompt_path})
        finally:
            # Close the run even when training fails so it is not left open.
            wandb_run.finish()

        # The trainer also references the model; drop both so the memory can
        # be reclaimed before the next persona is trained.
        del trainer, model
        gc.collect()
        torch.cuda.empty_cache()
104
+
105
+
106
def parse_args(args=None):
    """Parse the command line of the Bloom prompt-training script.

    Args:
        args: Optional list of argument strings; ``sys.argv`` is used when None.

    Returns:
        Namespace with ``persona_ids``, ``clustering_model_path``, ``config``,
        ``prompt_path`` and ``wandb_project``.
    """
    parser = argparse.ArgumentParser(add_help=True,
                                     description="bloom training script")
    # The three options below have no usable default: the script iterates the
    # ids, opens the config and writes into the prompt directory. Requiring
    # them reports a missing one at parse time instead of later in the run.
    parser.add_argument('--persona-ids', type=int, nargs='+', required=True,
                        help='Ids of persona')
    parser.add_argument('-clustering-model-path', '--clustering-model-path', type=str,
                        default=DEFAULT_CLUSTERING_MODEL,
                        help='Path to clustering model')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to training config file')
    parser.add_argument('--prompt-path', type=str, required=True,
                        help='Path to dir with trained soft prompts')
    parser.add_argument('--wandb-project', type=str, default='test_bloom_personachat_176b_v3')
    return parser.parse_args(args)
120
+
121
+
122
+ if __name__ == '__main__':
123
+ main()
personalized-chat-bot/util/__init__.py ADDED
File without changes
personalized-chat-bot/util/bloom_trainer.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import numpy as np
4
+ from torch.utils.data import DataLoader
5
+ from torch.optim import AdamW
6
+ from transformers import get_scheduler
7
+ import torch
8
+
9
+
10
+ from util.metrics import perplexity
11
+
12
+
13
class BloomTrainer:
    """Training and evaluation loop for a prompt-tuned causal language model."""

    DEFAULT_VAL_FREQ = 5    # validate every N optimizer steps by default
    ITERATION_LIMIT = 150   # hard cap on optimizer steps per train() call

    def __init__(self, model, config, train_dataset, val_dataset, wandb_run=None, prompt_path=None, val_freq=None):
        """Build data loaders, optimizer and LR scheduler.

        Args:
            model: Model to train; called as ``model(input_ids=..., labels=...)``.
            config: Object with ``BATCH_SIZE``, ``LR``, ``WEIGHT_DECAY``,
                ``N_EPOCH`` and ``DEVICE`` attributes.
            train_dataset: Dataset yielding dicts with ``input_ids`` and ``labels``.
            val_dataset: Dataset of the same form used by ``evaluate``.
            wandb_run: Optional run object with a ``log`` method.
            prompt_path: Optional path for checkpoints written during training.
            val_freq: Steps between validations; ``DEFAULT_VAL_FREQ`` when None.
        """
        self.model = model
        self.config = config
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.wandb_run = wandb_run
        self.val_freq = self.DEFAULT_VAL_FREQ if val_freq is None else val_freq
        self.prompt_path = prompt_path

        self.best_loss = np.inf

        self.train_loader = DataLoader(self.train_dataset,
                                       shuffle=True,
                                       batch_size=config.BATCH_SIZE,
                                       drop_last=True)
        self.val_loader = DataLoader(self.val_dataset,
                                     shuffle=True,
                                     batch_size=config.BATCH_SIZE,
                                     drop_last=False)

        self.optimizer = AdamW(self.model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)

        self.lr_scheduler = get_scheduler(
            name="linear",
            optimizer=self.optimizer,
            num_warmup_steps=0,
            num_training_steps=len(self.train_loader) * self.config.N_EPOCH
        )

    def _prepare_batch(self, batch):
        """Stack the per-position tensors of a batch and move them to the device."""
        # Stacking gives one row per position; transpose so that each row is
        # one example.
        return {
            key: torch.stack(batch[key]).T.to(self.config.DEVICE)
            for key in ('input_ids', 'labels')
        }

    def _log(self, payload):
        """Forward ``payload`` to the W&B run, if one was provided."""
        if self.wandb_run is not None:
            self.wandb_run.log(payload)

    def train(self):
        """Run the training loop with periodic validation and checkpointing."""
        self.model.train()
        iter_counter = 0
        for _ in range(self.config.N_EPOCH):
            for batch in self.train_loader:
                batch = self._prepare_batch(batch)
                outputs = self.model(**batch)
                loss = outputs.loss
                loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                loss_value = loss.item()
                self._log({'loss': loss_value})
                iter_counter += 1
                # The counter is already incremented, so test it directly.
                if iter_counter % self.val_freq == 0:
                    eval_perplexity = self.evaluate(perplexity)
                    self._log({'perplexity': eval_perplexity})
                    if loss_value < self.best_loss:
                        self.best_loss = loss_value
                        if self.prompt_path is not None:
                            self.save_model(self.prompt_path)
                            print('Model saved')
                if iter_counter >= self.ITERATION_LIMIT:
                    return

    def evaluate(self, eval_fn):
        """Run the model over the validation set and score it with ``eval_fn``.

        Args:
            eval_fn: Callable ``eval_fn(logits, labels)`` receiving two lists
                with one CPU tensor per validation example.

        Returns:
            Whatever ``eval_fn`` returns.
        """
        logits = []
        labels = []
        # Remember the mode so that validation in the middle of training does
        # not leave the model in eval mode for the remaining steps.
        was_training = getattr(self.model, 'training', False)
        self.model.eval()
        try:
            with torch.no_grad():
                for batch in self.val_loader:
                    batch = self._prepare_batch(batch)
                    outputs = self.model(**batch)
                    # A causal LM's logits at position t score the token at
                    # t + 1, so pair logits[:-1] with the tokens [1:].
                    labels.extend(batch['input_ids'][:, 1:].detach().cpu())
                    logits.extend(outputs.logits[:, :-1].detach().cpu())
        finally:
            if was_training:
                self.model.train()
        return eval_fn(logits, labels)

    def save_model(self, path):
        """Save the prompt-embedding state dict to ``path``."""
        torch.save(self.model.transformer.prompt_embeddings.state_dict(), path)

    def load_model(self, path):
        """Load a prompt-embedding state dict from ``path`` into the model."""
        # map_location lets a checkpoint saved from a GPU run load on a
        # CPU-only host; load_state_dict copies into the existing parameters.
        state_dict = torch.load(path, map_location='cpu')
        self.model.transformer.prompt_embeddings.load_state_dict(state_dict)
personalized-chat-bot/util/data.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ from joblib import Parallel, delayed
5
+
6
+
7
class OnePersonaDataset(Dataset):
    """Tokenized dialogue examples belonging to a single persona.

    Every example is a dialogue history followed by one candidate reply.
    """

    def __init__(self, data, tokenizer, transforms=None, positive_candidates=True, n_jobs=8):
        """Build the examples and tokenize them.

        Args:
            data: Sequence of rows with ``history`` and ``candidates`` lists.
            tokenizer: Callable returning a mapping with ``input_ids``.
            transforms: Optional callable turning one list of utterances into
                a string; utterances are newline-joined when None.
            positive_candidates: If True, use only the last candidate of every
                row (class 1). Otherwise use every candidate; the last one of
                each row gets class 1 and the others class 0.
            n_jobs: Number of parallel jobs used to apply ``transforms``.
        """
        super().__init__()

        self.data = data
        if len(data) == 0:
            self.input_ids = []
            self.history = []
            self.labels = []
            return

        if positive_candidates:
            self.history = [row['history'] + [row['candidates'][-1]] for row in data]
            self.labels = np.ones(len(self.history), dtype=int)
        else:
            self.history = [row['history'] + [candidate] for row in data
                            for candidate in row['candidates']]
            # Materialize the chain first: np.array cannot convert a lazy
            # iterator into an integer array.
            flat_labels = itertools.chain.from_iterable(
                [0] * (len(row['candidates']) - 1) + [1] for row in data)
            self.labels = np.array(list(flat_labels), dtype=int)

        if transforms is None:
            self.history = ["\n".join(item) for item in self.history]
        else:
            self.history = Parallel(n_jobs=n_jobs)(delayed(transforms)(item) for item in self.history)
        self.input_ids = tokenizer(self.history, padding='max_length', truncation=True)["input_ids"]

    def __getitem__(self, idx):
        """Return the example at ``idx``; the labels equal the input ids."""
        return {'input_ids': self.input_ids[idx],
                'labels': self.input_ids[idx],
                'example': self.history[idx],
                'class': self.labels[idx]}

    def __len__(self):
        """Return the number of examples."""
        # Count the examples, not the source rows: with
        # positive_candidates=False one row yields several examples.
        return len(self.history)
42
+
43
+
44
class PersonaChatDataset(Dataset):
    """PersonaChat dialogues grouped by the cluster of their personality sentences.

    Indexing with a persona (cluster) id returns the ``OnePersonaDataset``
    built from the dialogues whose personality sentences fall in that cluster.
    """

    DEFAULT_DATASET_NAME = "bavard/personachat_truecased"

    def __init__(self, clustering, dataset, tokenizer):
        """Assign every dialogue to the clusters of its personality sentences.

        Args:
            clustering: Object with ``predict`` and ``_cluster_centers``.
            dataset: Iterable of rows, each with a ``personality`` list.
            tokenizer: Tokenizer handed to every ``OnePersonaDataset``.
        """
        super().__init__()

        self.dataset = dataset
        self.clustering = clustering

        all_personalities = list({sent for item in self.dataset
                                  for sent in item['personality']})
        predicted_centers = self.clustering.predict(all_personalities)
        self.all_personalities_to_id = dict(zip(all_personalities, predicted_centers))
        self.personalities = self.clustering._cluster_centers

        subdataset_data_by_personality = [[] for _ in range(len(self.personalities))]

        for item in self.dataset:
            # A dialogue is added once per personality sentence, so it appears
            # twice in a cluster when two of its sentences map there.
            for persona in item['personality']:
                persona_id = self.all_personalities_to_id[persona]
                subdataset_data_by_personality[persona_id].append(item)

        self.subdatasets = [OnePersonaDataset(cur_data, tokenizer)
                            for cur_data in subdataset_data_by_personality]

    def __getitem__(self, persona_id):
        """Return the sub-dataset of the given persona (cluster) id."""
        return self.subdatasets[persona_id]

    def __len__(self):
        """Return the number of per-persona sub-datasets."""
        return len(self.subdatasets)
personalized-chat-bot/util/dialogue_manager.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DistilBertForSequenceClassification
2
+ from torch import nn
3
+
4
class DialogueManagerModel(nn.Module):
    """DistilBERT sequence classifier with a frozen body and a new linear head."""

    DEFAULT_MODEL = "distilbert-base-uncased"

    def __init__(self, n_classes, model_name=None, device='cpu'):
        """Load the default backbone, freeze it and attach a fresh classifier.

        Args:
            n_classes: Output size of the new classification layer.
            model_name: Must be None; any other value is not supported.
            device: Device for the backbone and the new head.

        Raises:
            NotImplementedError: If ``model_name`` is not None.
        """
        super().__init__()
        if model_name is not None:
            raise NotImplementedError()
        self.model = DistilBertForSequenceClassification.from_pretrained(self.DEFAULT_MODEL)
        self.model.to(device)
        self.n_classes = n_classes
        self.freeze_layers()

        head_inputs = self.model.classifier.in_features
        self.model.classifier = nn.Linear(head_inputs, self.n_classes, device=device)

        # Only the replacement head is left trainable.
        for head_param in self.model.classifier.parameters():
            head_param.requires_grad = True

    def freeze_layers(self):
        """Disable gradients for every parameter of the wrapped model."""
        for weight in self.model.parameters():
            weight.requires_grad = False

    def forward(self, X):
        """Return the wrapped model's output for ``X``."""
        return self.model(X)
personalized-chat-bot/util/metrics.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy
3
+ import torch
4
+
5
+
6
+ def _perplexity(logits, labels, pad_token=3):
7
+ for i in range(len(labels)-1, -1, -1):
8
+ if labels[i] != pad_token:
9
+ last_not_pad_id = i
10
+ break
11
+ logits = logits[:last_not_pad_id + 1]
12
+ labels = labels[:last_not_pad_id + 1]
13
+ log_probas = scipy.special.log_softmax(logits, axis=1).astype(np.float32)
14
+ log_probas = [log_probas[i][labels[i]] for i in range(len(labels))]
15
+ l = np.mean(log_probas)
16
+ return 2 ** (-l)
17
+
18
+
19
def perplexity(logits, labels, pad_token=3):
    """Return the mean per-sequence perplexity over a batch.

    Args:
        logits: Tensor or sequence holding one logits array per example.
        labels: Tensor or sequence holding one label array per example.
        pad_token: Padding token id forwarded to ``_perplexity``.

    Returns:
        Mean of the individual sequence perplexities.
    """
    if isinstance(logits, torch.Tensor):
        logits = logits.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    per_sequence = [
        _perplexity(np.array(seq_logits), np.array(seq_labels).astype(int), pad_token)
        for seq_logits, seq_labels in zip(logits, labels)
    ]
    return np.mean(per_sequence)