File size: 2,501 Bytes
deb7fd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import json
import torch
from sklearn.neighbors import KDTree


class PersonalityManager:
    """Pick the stored persona prompt closest to a free-text description.

    Every persona id in ``prompt_paths`` is looked up in
    ``personality_clustering._cluster_centers``; the resulting entries are
    encoded with ``personality_clustering.sentence_transformer`` and indexed
    in a Euclidean KD-tree so a new description can be matched to its
    nearest persona.
    """

    def __init__(self, prompt_paths, personality_clustering):
        """Encode the known personas and build the nearest-neighbour index.

        Args:
            prompt_paths: Mapping from persona id to the path of that
                persona's prompt.
            personality_clustering: Object exposing ``_cluster_centers``
                (indexable by persona id) and ``sentence_transformer``
                (with an ``encode`` method).
        """
        self.prompt_paths = prompt_paths
        self.personality_clustering = personality_clustering

        # Fix the persona order once: KD-tree row i corresponds to persona_ids[i].
        self.persona_ids = list(prompt_paths.keys())
        centers = personality_clustering._cluster_centers
        self.personalities = [centers[persona_id] for persona_id in self.persona_ids]

        encoder = personality_clustering.sentence_transformer
        self.embeddings = encoder.encode(self.personalities)
        self._nearest_neighbours = KDTree(self.embeddings, metric='euclidean')

    def get_prompt(self, description):
        """Return ``(prompt_path, cluster_center)`` for the nearest persona.

        Args:
            description: Text to match against the known personas; it is
                encoded with the same sentence transformer as the index.

        Returns:
            A tuple of the prompt path and the cluster-center entry that
            belong to the single nearest persona.
        """
        clustering = self.personality_clustering
        query = clustering.sentence_transformer.encode([description])
        # Only the index of the single nearest neighbour matters; the distance is ignored.
        _, indices = self._nearest_neighbours.query(query, k=1)
        nearest_id = self.persona_ids[indices[0][0]]
        return self.prompt_paths[nearest_id], clustering._cluster_centers[nearest_id]


class PersonalizedChatBot:
    """Line-based chat wrapper around a generative model and its tokenizer.

    The running conversation is kept in ``self.dialog`` as plain text, one
    turn per ``self.separator``-terminated line. Each call to ``answer``
    appends the user's phrase, asks the model to continue the text and
    keeps the first generated line as the reply.
    """

    def __init__(self, model, tokenizer, prompt_path=None, generation_config=None):
        """Store the collaborators and optionally load prompt embeddings.

        Args:
            model: Model exposing ``generate`` and, for ``load_prompt``,
                ``transformer.prompt_embeddings``.
            tokenizer: Tokenizer that is callable on a list of strings and
                exposes ``batch_decode`` and ``eos_token_id``.
            prompt_path: Optional path passed to ``load_prompt``.
            generation_config: Optional object with ``TEMPERATURE``,
                ``TOP_K`` and ``MAX_TOKENS`` attributes; may also be set
                later through ``load_config``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.separator = '\n'
        self.dialog = ''
        self.generation_config = generation_config
        if prompt_path is not None:
            self.load_prompt(prompt_path)

    def load_prompt(self, path):
        """Load a state dict from ``path`` into ``model.transformer.prompt_embeddings``."""
        # NOTE(security): torch.load unpickles the file, so only load prompt
        # checkpoints that come from a trusted source.
        state_dict = torch.load(path)
        self.model.transformer.prompt_embeddings.load_state_dict(state_dict)

    def load_config(self, path):
        """Read a JSON object from ``path`` and expose its keys as attributes.

        The result replaces ``self.generation_config``; ``answer`` reads the
        ``TEMPERATURE``, ``TOP_K`` and ``MAX_TOKENS`` attributes from it.
        """
        with open(path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        self.generation_config = argparse.Namespace(**config)

    def reset_dialog(self):
        """Forget the conversation so far."""
        self.dialog = ''

    def answer(self, phrase):
        """Generate the bot's reply to ``phrase`` and record both turns.

        Args:
            phrase: The user's message. An empty (or ``None``) phrase is
                ignored.

        Returns:
            The first line of the generated continuation, or ``None`` when
            ``phrase`` is empty.

        Raises:
            RuntimeError: If no generation config has been provided.
        """
        if not phrase:
            return None
        if self.generation_config is None:
            raise RuntimeError(
                "generation_config is not set; pass it to the constructor "
                "or call load_config() before answer()"
            )

        # Build the prompt locally and commit it to self.dialog only after
        # generation succeeds, so a failure does not leave a dangling turn.
        prompt = f"{self.dialog}{phrase}{self.separator}"
        inputs = self.tokenizer([prompt], return_tensors='pt')['input_ids']
        outputs = self.model.generate(
            inputs,
            temperature=self.generation_config.TEMPERATURE,
            do_sample=True,
            top_k=self.generation_config.TOP_K,
            eos_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=self.generation_config.MAX_TOKENS,
        )
        # Decode only the tokens after the prompt. Slicing the decoded text by
        # len(prompt) characters is wrong whenever decode(encode(prompt)) is not
        # character-identical to the prompt (e.g. added special tokens).
        # Assumes generate() returns the prompt ids followed by the new ids,
        # as decoder-only models do.
        new_tokens = outputs[:, inputs.shape[-1]:]
        completion = self.tokenizer.batch_decode(
            new_tokens, skip_special_tokens=True
        )[0]
        # The reply is the first generated line; anything after it is discarded.
        bloom_answer = completion.split(self.separator)[0]
        self.dialog = f"{prompt}{bloom_answer}{self.separator}"
        return bloom_answer