chat-gradio / personalized-chat-bot /personalized_chat_bot.py
j.gilyazev
add personalized-chat-bot
deb7fd3
raw
history blame
2.5 kB
import argparse
import json
import torch
from sklearn.neighbors import KDTree
class PersonalityManager:
def __init__(self, prompt_paths, personality_clustering):
self.prompt_paths = prompt_paths
self.personality_clustering = personality_clustering
self.persona_ids = list(prompt_paths.keys())
self.personalities = [personality_clustering._cluster_centers[i]
for i in self.persona_ids]
self.embeddings = personality_clustering.sentence_transformer.encode(self.personalities)
self._nearest_neighbours = KDTree(self.embeddings, metric='euclidean')
def get_prompt(self, description):
embedding = self.personality_clustering.sentence_transformer.encode([description])
dist, ind = self._nearest_neighbours.query(embedding, k=1)
persona_id = self.persona_ids[ind[0][0]]
prompt_path = self.prompt_paths[persona_id]
cluster_center = self.personality_clustering._cluster_centers[persona_id]
return prompt_path, cluster_center
class PersonalizedChatBot:
def __init__(self, model, tokenizer, prompt_path=None, generation_config=None):
self.model = model
if prompt_path is not None:
self.load_prompt(prompt_path)
self.tokenizer = tokenizer
self.separator = '\n'
self.dialog = ''
self.generation_config = generation_config
def load_prompt(self, path):
self.model.transformer.prompt_embeddings.load_state_dict(torch.load(path))
def load_config(self, path):
with open(path, 'r') as f:
config = json.load(f)
self.generation_config = argparse.Namespace(**config)
def reset_dialog(self, ):
self.dialog = ''
def answer(self, phrase):
if len(phrase) == 0:
return
self.dialog += f"{phrase}{self.separator}"
inputs = self.tokenizer([self.dialog], return_tensors='pt')['input_ids']
outputs = self.model.generate(
inputs,
temperature=self.generation_config.TEMPERATURE,
do_sample=True,
top_k=self.generation_config.TOP_K,
eos_token_id=self.tokenizer.eos_token_id,
max_new_tokens=self.generation_config.MAX_TOKENS,
)
bloom_answer = self.tokenizer.batch_decode(outputs)[0]
bloom_answer = bloom_answer[len(self.dialog):].split("\n")[0]
self.dialog += f"{bloom_answer}{self.separator}"
return bloom_answer