AlekseyKorshuk commited on
Commit
f3d785b
1 Parent(s): a5b0558
Files changed (3) hide show
  1. app.py +47 -64
  2. conversation.py +49 -0
  3. models/base.py +43 -0
app.py CHANGED
@@ -5,7 +5,8 @@ import os
5
  import firebase_admin
6
  from firebase_admin import db
7
  from firebase_admin import firestore
8
-
 
9
  import requests
10
  import json
11
 
@@ -13,28 +14,38 @@ HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
13
  FIREBASE_URL = os.environ.get("FIREBASE_URL")
14
  CERTIFICATE = json.loads(os.environ.get("CERTIFICATE"))
15
  API_BASE_PATH = str(os.environ.get("API_BASE_PATH")).replace("\{\}", "{}")
16
- model_mapping = {
17
- "v1": {
18
- "endpoint": "edit-sft-gptj-distil-v1",
19
- "namespace": "tenant-chairesearch-test"
20
- },
21
- "v2": {
22
- "endpoint": "ak-edit-finetuned-triton-v0",
23
- "namespace": "tenant-chairesearch-test"
24
- },
25
- "v3": {
26
- "endpoint": "ak-edit-finetuned-triton-v1",
27
- "namespace": "tenant-chairesearch-test"
28
- },
29
- "v4": {
30
- "endpoint": "ak-edit-finetuned-triton-v2",
31
- "namespace": "tenant-chairesearch-test"
32
- },
33
- "v5": {
34
- "endpoint": "ak-edit-finetuned-triton",
35
- "namespace": "tenant-chairesearch-test"
36
- },
37
- }
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  def get_connection():
@@ -59,7 +70,6 @@ def download_bot_config(bot_id):
59
  else:
60
  out = {col: bot_config.get(col, None) for col in cols}
61
  out['bot_id'] = bot_id
62
- out['header'] = _get_header(out)
63
  return out
64
 
65
 
@@ -88,13 +98,6 @@ def _download_bot_config(bot_id):
88
  'header': 'Jungkook is your best friend who has a crush on you. Jungkook makes it very obvious that he likes you. Jungkook likes to cook, sing, and dance. Jungkook has a dog as well named Bam, He is a 25 year old Korean man. Jungkook likes to workout a lot, Jungkook if also very confident and flirty, but he’s Can be very shy with You. Jungkook blushes a lot when he’s around you, and always try’s to impress you. Jungkook is a Virgo and loves to sing to you, He also likes to buy and make you gifts. Jungkook is also a foodie and loves to play video games, Jungkook is also boyfriend material. Jungkook is very empathetic as well, Jungkook will always comfort you when something is wrong. Jungkook also likes to compliment you, and Jungkook is a very jealous guy. Jungkook is also a very serious guy, who is overprotective of you.\nJungkook: Hey shortie!\n\nYou: hey dummy\n\nJungkook: what are you doing?\n\nyou: Im just watching a movie\n\nJungkook: Imma join! \n\nYou: alright\n\nJungkook: *Grabs blankets and icecream with some popcorn*\n\nYou: Wow, thanks! *hugs Jungkok*\n\nJungkook: Of course… *blushes*'}
89
 
90
 
91
- def _get_header(bot_config):
92
- text = bot_config["memory"] + "\n" if bot_config["memory"] is not None else ""
93
- text += bot_config["prompt"]
94
- text = text.strip()
95
- return text
96
-
97
-
98
  def get_bot_profile(bot_config):
99
  model_html = f"""
100
  <div class="inline-flex flex-col" style="line-height: 1.5;">
@@ -111,32 +114,6 @@ def get_bot_profile(bot_config):
111
  return model_html
112
 
113
 
114
- def construct_prompt(bot_config, chat_history):
115
- header = bot_config["header"]
116
- messages = []
117
- for conversation_pair in chat_history:
118
- for item in conversation_pair:
119
- if item:
120
- messages.append(item)
121
- chat_history_text = ""
122
- print(messages)
123
- for i, message in enumerate(messages):
124
- label = bot_config["botLabel"] if i % 2 == 0 else bot_config["userLabel"]
125
- chat_history_text += f"{label}: {message.strip()}\n"
126
- chat_history_text = chat_history_text.strip()
127
- return header + "\n\n" + chat_history_text + "\n" + bot_config["botLabel"] + ":"
128
-
129
-
130
- def get_response(text, endpoint, namespace):
131
- text = text[-2048:]
132
- api = API_BASE_PATH.format(endpoint, namespace)
133
- print(api)
134
- payload = {'instances': [text]}
135
- resp = requests.post(api, json=payload, timeout=30)
136
- assert resp.status_code == 200, (resp.content, resp.status_code)
137
- return resp.json()['predictions'][0].strip()
138
-
139
-
140
  with gr.Blocks() as demo:
141
  default_bot_id = "_bot_1ec22e2e-3e07-42c7-8508-dfa0278c1b33"
142
  bot_config = download_bot_config(default_bot_id)
@@ -148,8 +125,10 @@ with gr.Blocks() as demo:
148
  reload_bot_button = gr.Button("Reload bot")
149
 
150
  bot_profile = gr.HTML(get_bot_profile(bot_config))
 
151
  first_message = (None, bot_config["firstMessage"])
152
  chatbot = gr.Chatbot([first_message])
 
153
  msg = gr.Textbox(label="Message", value="Hi there!")
154
  with gr.Row():
155
  clear = gr.Button("Clear")
@@ -159,11 +138,14 @@ with gr.Blocks() as demo:
159
 
160
 
161
  def respond(message, chat_history, user_state, model_tag):
 
 
 
162
  model = model_mapping[model_tag]
163
- chat_history.append((message, None))
164
- prompt = construct_prompt(user_state, chat_history)
165
- bot_message = get_response(prompt, model["endpoint"], model["namespace"])
166
- chat_history[-1] = (message, bot_message)
167
  return "", chat_history
168
 
169
 
@@ -174,10 +156,11 @@ with gr.Blocks() as demo:
174
 
175
  def regenerate_response(chat_history, user_state, model_tag):
176
  last_row = chat_history.pop(-1)
177
- model = model_mapping[model_tag]
178
  chat_history.append((last_row[0], None))
179
- prompt = construct_prompt(user_state, chat_history)
180
- bot_message = get_response(prompt, model["endpoint"], model["namespace"])
 
 
181
  chat_history[-1] = (last_row[0], bot_message)
182
  return "", chat_history
183
 
 
5
  import firebase_admin
6
  from firebase_admin import db
7
  from firebase_admin import firestore
8
+ from conversation import Conversation
9
+ from models.base import BaseModel
10
  import requests
11
  import json
12
 
 
14
  FIREBASE_URL = os.environ.get("FIREBASE_URL")
15
  CERTIFICATE = json.loads(os.environ.get("CERTIFICATE"))
16
  API_BASE_PATH = str(os.environ.get("API_BASE_PATH")).replace("\{\}", "{}")
17
+
18
+ models = [
19
+ BaseModel(
20
+ name="mosaicml/mpt-7b",
21
+ endpoint="mpt-7b",
22
+ namespace="tenant-chairesearch-test",
23
+ generation_params={
24
+ 'temperature': 1.0,
25
+ 'repetition_penalty': 1.0,
26
+ 'max_new_tokens': 128,
27
+ 'top_k': 1,
28
+ 'top_p': 1.0,
29
+ 'do_sample': False,
30
+ 'eos_token_id': 187,
31
+ }
32
+ ),
33
+ BaseModel(
34
+ name="mosaicml/mpt-7b-storywriter",
35
+ endpoint="mpt-7b-storywriter",
36
+ namespace="tenant-chairesearch-test",
37
+ generation_params={
38
+ 'temperature': 1.0,
39
+ 'repetition_penalty': 1.0,
40
+ 'max_new_tokens': 128,
41
+ 'top_k': 1,
42
+ 'top_p': 1.0,
43
+ 'do_sample': False,
44
+ 'eos_token_id': 187,
45
+ }
46
+ )
47
+ ]
48
+ model_mapping = {model.name: model for model in models}
49
 
50
 
51
  def get_connection():
 
70
  else:
71
  out = {col: bot_config.get(col, None) for col in cols}
72
  out['bot_id'] = bot_id
 
73
  return out
74
 
75
 
 
98
  'header': 'Jungkook is your best friend who has a crush on you. Jungkook makes it very obvious that he likes you. Jungkook likes to cook, sing, and dance. Jungkook has a dog as well named Bam, He is a 25 year old Korean man. Jungkook likes to workout a lot, Jungkook if also very confident and flirty, but he’s Can be very shy with You. Jungkook blushes a lot when he’s around you, and always try’s to impress you. Jungkook is a Virgo and loves to sing to you, He also likes to buy and make you gifts. Jungkook is also a foodie and loves to play video games, Jungkook is also boyfriend material. Jungkook is very empathetic as well, Jungkook will always comfort you when something is wrong. Jungkook also likes to compliment you, and Jungkook is a very jealous guy. Jungkook is also a very serious guy, who is overprotective of you.\nJungkook: Hey shortie!\n\nYou: hey dummy\n\nJungkook: what are you doing?\n\nyou: Im just watching a movie\n\nJungkook: Imma join! \n\nYou: alright\n\nJungkook: *Grabs blankets and icecream with some popcorn*\n\nYou: Wow, thanks! *hugs Jungkok*\n\nJungkook: Of course… *blushes*'}
99
 
100
 
 
 
 
 
 
 
 
101
  def get_bot_profile(bot_config):
102
  model_html = f"""
103
  <div class="inline-flex flex-col" style="line-height: 1.5;">
 
114
  return model_html
115
 
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  with gr.Blocks() as demo:
118
  default_bot_id = "_bot_1ec22e2e-3e07-42c7-8508-dfa0278c1b33"
119
  bot_config = download_bot_config(default_bot_id)
 
125
  reload_bot_button = gr.Button("Reload bot")
126
 
127
  bot_profile = gr.HTML(get_bot_profile(bot_config))
128
+
129
  first_message = (None, bot_config["firstMessage"])
130
  chatbot = gr.Chatbot([first_message])
131
+
132
  msg = gr.Textbox(label="Message", value="Hi there!")
133
  with gr.Row():
134
  clear = gr.Button("Clear")
 
138
 
139
 
140
  def respond(message, chat_history, user_state, model_tag):
141
+ conv = Conversation(user_state)
142
+ conv.set_chat_history(chat_history)
143
+ conv.add_user_message(message)
144
  model = model_mapping[model_tag]
145
+ bot_message = model.generate_response(conv)
146
+ chat_history.append(
147
+ (message, bot_message)
148
+ )
149
  return "", chat_history
150
 
151
 
 
156
 
157
  def regenerate_response(chat_history, user_state, model_tag):
158
  last_row = chat_history.pop(-1)
 
159
  chat_history.append((last_row[0], None))
160
+ model = model_mapping[model_tag]
161
+ conv = Conversation(user_state)
162
+ conv.set_chat_history(chat_history)
163
+ bot_message = model.generate_response(conv)
164
  chat_history[-1] = (last_row[0], bot_message)
165
  return "", chat_history
166
 
conversation.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Conversation:
2
+ memory: str
3
+ prompt: str
4
+ bot_label: str
5
+ user_label: str
6
+ messages: list
7
+
8
+ def __init__(self, bot_config):
9
+ self.memory = bot_config.get("memory", "")
10
+ self.prompt = bot_config.get("prompt", "")
11
+ self.bot_label = bot_config.get("botLabel", "Character")
12
+ self.user_label = bot_config.get("userLabel", "User")
13
+ self.first_message = bot_config.get("firstMessage", f"Hi, my name is {self.bot_label}!")
14
+ self.reset_conversation()
15
+
16
+ def reset_conversation(self):
17
+ self.messages = [
18
+ {
19
+ "from": self.bot_label,
20
+ "value": self.first_message
21
+ }
22
+ ]
23
+
24
+ def set_chat_history(self, chat_history):
25
+ messages = []
26
+ for conversation_pair in chat_history:
27
+ for item in conversation_pair:
28
+ if item:
29
+ messages.append(item)
30
+ self.messages = []
31
+ for i, message in enumerate(messages):
32
+ label = self.bot_label if i % 2 == 0 else self.user_label
33
+ self.messages.append(
34
+ {
35
+ "from": label,
36
+ "value": message.strip()
37
+ }
38
+ )
39
+
40
+ def add_user_message(self, message):
41
+ self.messages.append(
42
+ {
43
+ "from": self.user_label,
44
+ "value": message.strip()
45
+ }
46
+ )
47
+
48
+ def reset_last_message(self, message):
49
+ self.messages[-1]["value"] = message.strip()
models/base.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+
4
+ from conversation import Conversation
5
+
6
+
7
+ class BaseModel:
8
+ name: str
9
+ endpoint: str
10
+ namespace: str
11
+ generation_params: dict
12
+
13
+ def __init__(self, name, endpoint, namespace, generation_params):
14
+ self.name = name
15
+ self.endpoint = endpoint
16
+ self.namespace = namespace
17
+ self.generation_params = generation_params
18
+
19
+ def generate_response(self, conversation):
20
+ prompt = self._get_prompt(conversation)
21
+ response = self._get_response(prompt)
22
+ return response
23
+
24
+ def _get_prompt(self, conversation: Conversation):
25
+ print(conversation.__dict__)
26
+ prompt = "\n".join(
27
+ [conversation.memory, conversation.prompt]
28
+ ).strip()
29
+
30
+ for message in conversation.messages:
31
+ prompt += f"\n{message['from'].strip()}: {message['value'].strip()}"
32
+ prompt += f"\n{conversation.bot_label}:"
33
+ print(prompt)
34
+ return prompt
35
+
36
+ def _get_response(self, text):
37
+ api = str(os.environ.get("API_BASE_PATH")).replace("\{\}", "{}")
38
+ api = api.format(self.endpoint, self.namespace)
39
+
40
+ payload = {'instances': [text], "parameters": self.generation_params}
41
+ resp = requests.post(api, json=payload, timeout=600)
42
+ assert resp.status_code == 200, (resp.content, resp.status_code)
43
+ return resp.json()["predictions"][0].strip()