gosha6037 commited on
Commit
91ffd2c
1 Parent(s): dd71c9a

Added Cluster Bloom

Browse files
Files changed (1) hide show
  1. app.py +48 -28
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import sys
2
  import json
 
3
 
4
  import gradio as gr
5
  import torch
@@ -11,6 +12,7 @@ sys.path.insert(0, './personalized-chat-bot/')
11
 
12
 
13
  from petals.client.remote_model import DistributedBloomForCausalLM
 
14
  from models.personality_clustering import PersonalityClustering
15
 
16
  MODEL_NAME = "bigscience/bloom-petals"
@@ -67,41 +69,59 @@ def predict_common_bloom(model, tokenizer, input_text, history, person_descripti
67
  return response_new, history_new
68
 
69
 
 
 
 
 
 
 
70
  def predict_cluster_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
71
- new_user_input_ids = tokenizer.encode(input_text + '\n', return_tensors='pt')
72
- print('Started predict_common_bloom')
73
- print(f'history: {history}')
74
- if history != []:
75
- bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
76
- else:
77
- bot_input_ids = new_user_input_ids
78
- print(f'bot_input_ids: {bot_input_ids}')
79
 
80
- history = model.generate(
81
- bot_input_ids,
82
- max_new_tokens=number_of_new_tokens,
83
- pad_token_id=tokenizer.eos_token_id
84
- ).tolist()
85
- print(f'history: {history}')
86
 
87
- decode_all = tokenizer.decode(history[0][:len(bot_input_ids[0])])
88
- all_responses = tokenizer.decode(history[0][len(bot_input_ids[0]):]).split('\n')
89
- if all_responses[0]:
90
- decode_all += all_responses[0] + '\n'
91
- else:
92
- decode_all += all_responses[1] + '\n'
93
- print(f'decode_all: {decode_all}')
94
 
95
- history_new = tokenizer.encode(decode_all, return_tensors='pt')
96
- print(f'history_new: {history_new}')
97
 
98
- decode_all_split = decode_all.split('\n')
99
- print(f'decode_all_split: {decode_all_split}')
 
 
 
 
100
 
101
- response_new = [(decode_all_split[i], decode_all_split[i + 1]) for i in range(0, len(decode_all_split) - 1, 2)]
102
- print(f'response_new: {response_new}')
 
103
 
104
- return response_new, history_new
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
 
107
  def predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
 
1
  import sys
2
  import json
3
+ import argparse
4
 
5
  import gradio as gr
6
  import torch
 
12
 
13
 
14
  from petals.client.remote_model import DistributedBloomForCausalLM
15
+ from personalized_chat_bot import PersonalizedChatBot, PersonalityManager
16
  from models.personality_clustering import PersonalityClustering
17
 
18
  MODEL_NAME = "bigscience/bloom-petals"
 
69
  return response_new, history_new
70
 
71
 
72
+ def load_config(path):
73
+ with open(path, 'r') as f:
74
+ config = json.load(f)
75
+ return argparse.Namespace(**config)
76
+
77
+
78
  def predict_cluster_bloom(model, tokenizer, input_text, history, person_description, number_of_new_tokens):
79
+ personality_clustering = PersonalityClustering()
80
+ personality_clustering.load('personalized-chat-bot/data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl')
 
 
 
 
 
 
81
 
82
+ hook = lambda dct: {int(k): v for k, v in dct.items()}
83
+ with open('personalized-chat-bot/prompt_paths.json', 'r') as f:
84
+ prompt_paths = json.load(f, object_hook=hook)
 
 
 
85
 
86
+ pm = PersonalityManager(prompt_paths, personality_clustering)
87
+ prompt_path, closest_persona = pm.get_prompt(person_description)
88
+ print(f'The closest personality is: {closest_persona}')
89
+ print('Wait a little longer...')
90
+ config = load_config('personalized-chat-bot/scripts/config_176b.json')
 
 
91
 
 
 
92
 
93
+ model = DistributedBloomForCausalLM.from_pretrained(
94
+ config.MODEL_NAME,
95
+ pre_seq_len=config.NUM_PREFIX_TOKENS,
96
+ tuning_mode=config.TUNING_MODE,
97
+ # max_new_tokens=number_of_new_tokens,
98
+ ).to(config.DEVICE)
99
 
100
+ generation_config = load_config('personalized-chat-bot/generation_config.json')
101
+ generation_config.max_new_tokens=number_of_new_tokens
102
+ print(f'generation_config: {generation_config}')
103
 
104
+ tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
105
+ tokenizer.padding_side = 'right'
106
+ tokenizer.model_max_length = config.MODEL_MAX_LENGTH
107
+
108
+ chatbot = PersonalizedChatBot(model, tokenizer, generation_config=generation_config)
109
+ chatbot.load_prompt('personalized-chat-bot/' + prompt_path)
110
+ if history != []:
111
+ input_text = tokenizer.decode(history[0]) + '\n' + input_text
112
+ print(f'INPUT: {input_text}')
113
+ output = chatbot.answer(input_text)
114
+ all_text = input_text + '\n' + output
115
+ print(f'all_text: {all_text}')
116
+
117
+ history = tokenizer.encode(all_text, return_tensors='pt')
118
+ print(f'history: {history}')
119
+
120
+ response = tokenizer.decode(history[0]).split("\n")
121
+ response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
122
+ print(f'response: {response}')
123
+
124
+ return response, history
125
 
126
 
127
  def predict_dialo_gpt(model, tokenizer, input_text, history, person_description, number_of_new_tokens):