Spaces:
Runtime error
Runtime error
AlekseyKorshuk
commited on
Commit
•
f3d785b
1
Parent(s):
a5b0558
updates
Browse files- app.py +47 -64
- conversation.py +49 -0
- models/base.py +43 -0
app.py
CHANGED
@@ -5,7 +5,8 @@ import os
|
|
5 |
import firebase_admin
|
6 |
from firebase_admin import db
|
7 |
from firebase_admin import firestore
|
8 |
-
|
|
|
9 |
import requests
|
10 |
import json
|
11 |
|
@@ -13,28 +14,38 @@ HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
|
|
13 |
FIREBASE_URL = os.environ.get("FIREBASE_URL")
|
14 |
CERTIFICATE = json.loads(os.environ.get("CERTIFICATE"))
|
15 |
API_BASE_PATH = str(os.environ.get("API_BASE_PATH")).replace("\{\}", "{}")
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
"
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
|
40 |
def get_connection():
|
@@ -59,7 +70,6 @@ def download_bot_config(bot_id):
|
|
59 |
else:
|
60 |
out = {col: bot_config.get(col, None) for col in cols}
|
61 |
out['bot_id'] = bot_id
|
62 |
-
out['header'] = _get_header(out)
|
63 |
return out
|
64 |
|
65 |
|
@@ -88,13 +98,6 @@ def _download_bot_config(bot_id):
|
|
88 |
'header': 'Jungkook is your best friend who has a crush on you. Jungkook makes it very obvious that he likes you. Jungkook likes to cook, sing, and dance. Jungkook has a dog as well named Bam, He is a 25 year old Korean man. Jungkook likes to workout a lot, Jungkook if also very confident and flirty, but he’s Can be very shy with You. Jungkook blushes a lot when he’s around you, and always try’s to impress you. Jungkook is a Virgo and loves to sing to you, He also likes to buy and make you gifts. Jungkook is also a foodie and loves to play video games, Jungkook is also boyfriend material. Jungkook is very empathetic as well, Jungkook will always comfort you when something is wrong. Jungkook also likes to compliment you, and Jungkook is a very jealous guy. Jungkook is also a very serious guy, who is overprotective of you.\nJungkook: Hey shortie!\n\nYou: hey dummy\n\nJungkook: what are you doing?\n\nyou: Im just watching a movie\n\nJungkook: Imma join! \n\nYou: alright\n\nJungkook: *Grabs blankets and icecream with some popcorn*\n\nYou: Wow, thanks! *hugs Jungkok*\n\nJungkook: Of course… *blushes*'}
|
89 |
|
90 |
|
91 |
-
def _get_header(bot_config):
|
92 |
-
text = bot_config["memory"] + "\n" if bot_config["memory"] is not None else ""
|
93 |
-
text += bot_config["prompt"]
|
94 |
-
text = text.strip()
|
95 |
-
return text
|
96 |
-
|
97 |
-
|
98 |
def get_bot_profile(bot_config):
|
99 |
model_html = f"""
|
100 |
<div class="inline-flex flex-col" style="line-height: 1.5;">
|
@@ -111,32 +114,6 @@ def get_bot_profile(bot_config):
|
|
111 |
return model_html
|
112 |
|
113 |
|
114 |
-
def construct_prompt(bot_config, chat_history):
|
115 |
-
header = bot_config["header"]
|
116 |
-
messages = []
|
117 |
-
for conversation_pair in chat_history:
|
118 |
-
for item in conversation_pair:
|
119 |
-
if item:
|
120 |
-
messages.append(item)
|
121 |
-
chat_history_text = ""
|
122 |
-
print(messages)
|
123 |
-
for i, message in enumerate(messages):
|
124 |
-
label = bot_config["botLabel"] if i % 2 == 0 else bot_config["userLabel"]
|
125 |
-
chat_history_text += f"{label}: {message.strip()}\n"
|
126 |
-
chat_history_text = chat_history_text.strip()
|
127 |
-
return header + "\n\n" + chat_history_text + "\n" + bot_config["botLabel"] + ":"
|
128 |
-
|
129 |
-
|
130 |
-
def get_response(text, endpoint, namespace):
|
131 |
-
text = text[-2048:]
|
132 |
-
api = API_BASE_PATH.format(endpoint, namespace)
|
133 |
-
print(api)
|
134 |
-
payload = {'instances': [text]}
|
135 |
-
resp = requests.post(api, json=payload, timeout=30)
|
136 |
-
assert resp.status_code == 200, (resp.content, resp.status_code)
|
137 |
-
return resp.json()['predictions'][0].strip()
|
138 |
-
|
139 |
-
|
140 |
with gr.Blocks() as demo:
|
141 |
default_bot_id = "_bot_1ec22e2e-3e07-42c7-8508-dfa0278c1b33"
|
142 |
bot_config = download_bot_config(default_bot_id)
|
@@ -148,8 +125,10 @@ with gr.Blocks() as demo:
|
|
148 |
reload_bot_button = gr.Button("Reload bot")
|
149 |
|
150 |
bot_profile = gr.HTML(get_bot_profile(bot_config))
|
|
|
151 |
first_message = (None, bot_config["firstMessage"])
|
152 |
chatbot = gr.Chatbot([first_message])
|
|
|
153 |
msg = gr.Textbox(label="Message", value="Hi there!")
|
154 |
with gr.Row():
|
155 |
clear = gr.Button("Clear")
|
@@ -159,11 +138,14 @@ with gr.Blocks() as demo:
|
|
159 |
|
160 |
|
161 |
def respond(message, chat_history, user_state, model_tag):
|
|
|
|
|
|
|
162 |
model = model_mapping[model_tag]
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
return "", chat_history
|
168 |
|
169 |
|
@@ -174,10 +156,11 @@ with gr.Blocks() as demo:
|
|
174 |
|
175 |
def regenerate_response(chat_history, user_state, model_tag):
|
176 |
last_row = chat_history.pop(-1)
|
177 |
-
model = model_mapping[model_tag]
|
178 |
chat_history.append((last_row[0], None))
|
179 |
-
|
180 |
-
|
|
|
|
|
181 |
chat_history[-1] = (last_row[0], bot_message)
|
182 |
return "", chat_history
|
183 |
|
|
|
5 |
import firebase_admin
|
6 |
from firebase_admin import db
|
7 |
from firebase_admin import firestore
|
8 |
+
from conversation import Conversation
|
9 |
+
from models.base import BaseModel
|
10 |
import requests
|
11 |
import json
|
12 |
|
|
|
14 |
FIREBASE_URL = os.environ.get("FIREBASE_URL")
|
15 |
CERTIFICATE = json.loads(os.environ.get("CERTIFICATE"))
|
16 |
API_BASE_PATH = str(os.environ.get("API_BASE_PATH")).replace("\{\}", "{}")
|
17 |
+
|
18 |
+
models = [
|
19 |
+
BaseModel(
|
20 |
+
name="mosaicml/mpt-7b",
|
21 |
+
endpoint="mpt-7b",
|
22 |
+
namespace="tenant-chairesearch-test",
|
23 |
+
generation_params={
|
24 |
+
'temperature': 1.0,
|
25 |
+
'repetition_penalty': 1.0,
|
26 |
+
'max_new_tokens': 128,
|
27 |
+
'top_k': 1,
|
28 |
+
'top_p': 1.0,
|
29 |
+
'do_sample': False,
|
30 |
+
'eos_token_id': 187,
|
31 |
+
}
|
32 |
+
),
|
33 |
+
BaseModel(
|
34 |
+
name="mosaicml/mpt-7b-storywriter",
|
35 |
+
endpoint="mpt-7b-storywriter",
|
36 |
+
namespace="tenant-chairesearch-test",
|
37 |
+
generation_params={
|
38 |
+
'temperature': 1.0,
|
39 |
+
'repetition_penalty': 1.0,
|
40 |
+
'max_new_tokens': 128,
|
41 |
+
'top_k': 1,
|
42 |
+
'top_p': 1.0,
|
43 |
+
'do_sample': False,
|
44 |
+
'eos_token_id': 187,
|
45 |
+
}
|
46 |
+
)
|
47 |
+
]
|
48 |
+
model_mapping = {model.name: model for model in models}
|
49 |
|
50 |
|
51 |
def get_connection():
|
|
|
70 |
else:
|
71 |
out = {col: bot_config.get(col, None) for col in cols}
|
72 |
out['bot_id'] = bot_id
|
|
|
73 |
return out
|
74 |
|
75 |
|
|
|
98 |
'header': 'Jungkook is your best friend who has a crush on you. Jungkook makes it very obvious that he likes you. Jungkook likes to cook, sing, and dance. Jungkook has a dog as well named Bam, He is a 25 year old Korean man. Jungkook likes to workout a lot, Jungkook if also very confident and flirty, but he’s Can be very shy with You. Jungkook blushes a lot when he’s around you, and always try’s to impress you. Jungkook is a Virgo and loves to sing to you, He also likes to buy and make you gifts. Jungkook is also a foodie and loves to play video games, Jungkook is also boyfriend material. Jungkook is very empathetic as well, Jungkook will always comfort you when something is wrong. Jungkook also likes to compliment you, and Jungkook is a very jealous guy. Jungkook is also a very serious guy, who is overprotective of you.\nJungkook: Hey shortie!\n\nYou: hey dummy\n\nJungkook: what are you doing?\n\nyou: Im just watching a movie\n\nJungkook: Imma join! \n\nYou: alright\n\nJungkook: *Grabs blankets and icecream with some popcorn*\n\nYou: Wow, thanks! *hugs Jungkok*\n\nJungkook: Of course… *blushes*'}
|
99 |
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
def get_bot_profile(bot_config):
|
102 |
model_html = f"""
|
103 |
<div class="inline-flex flex-col" style="line-height: 1.5;">
|
|
|
114 |
return model_html
|
115 |
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
with gr.Blocks() as demo:
|
118 |
default_bot_id = "_bot_1ec22e2e-3e07-42c7-8508-dfa0278c1b33"
|
119 |
bot_config = download_bot_config(default_bot_id)
|
|
|
125 |
reload_bot_button = gr.Button("Reload bot")
|
126 |
|
127 |
bot_profile = gr.HTML(get_bot_profile(bot_config))
|
128 |
+
|
129 |
first_message = (None, bot_config["firstMessage"])
|
130 |
chatbot = gr.Chatbot([first_message])
|
131 |
+
|
132 |
msg = gr.Textbox(label="Message", value="Hi there!")
|
133 |
with gr.Row():
|
134 |
clear = gr.Button("Clear")
|
|
|
138 |
|
139 |
|
140 |
def respond(message, chat_history, user_state, model_tag):
|
141 |
+
conv = Conversation(user_state)
|
142 |
+
conv.set_chat_history(chat_history)
|
143 |
+
conv.add_user_message(message)
|
144 |
model = model_mapping[model_tag]
|
145 |
+
bot_message = model.generate_response(conv)
|
146 |
+
chat_history.append(
|
147 |
+
(message, bot_message)
|
148 |
+
)
|
149 |
return "", chat_history
|
150 |
|
151 |
|
|
|
156 |
|
157 |
def regenerate_response(chat_history, user_state, model_tag):
|
158 |
last_row = chat_history.pop(-1)
|
|
|
159 |
chat_history.append((last_row[0], None))
|
160 |
+
model = model_mapping[model_tag]
|
161 |
+
conv = Conversation(user_state)
|
162 |
+
conv.set_chat_history(chat_history)
|
163 |
+
bot_message = model.generate_response(conv)
|
164 |
chat_history[-1] = (last_row[0], bot_message)
|
165 |
return "", chat_history
|
166 |
|
conversation.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class Conversation:
|
2 |
+
memory: str
|
3 |
+
prompt: str
|
4 |
+
bot_label: str
|
5 |
+
user_label: str
|
6 |
+
messages: list
|
7 |
+
|
8 |
+
def __init__(self, bot_config):
|
9 |
+
self.memory = bot_config.get("memory", "")
|
10 |
+
self.prompt = bot_config.get("prompt", "")
|
11 |
+
self.bot_label = bot_config.get("botLabel", "Character")
|
12 |
+
self.user_label = bot_config.get("userLabel", "User")
|
13 |
+
self.first_message = bot_config.get("firstMessage", f"Hi, my name is {self.bot_label}!")
|
14 |
+
self.reset_conversation()
|
15 |
+
|
16 |
+
def reset_conversation(self):
|
17 |
+
self.messages = [
|
18 |
+
{
|
19 |
+
"from": self.bot_label,
|
20 |
+
"value": self.first_message
|
21 |
+
}
|
22 |
+
]
|
23 |
+
|
24 |
+
def set_chat_history(self, chat_history):
|
25 |
+
messages = []
|
26 |
+
for conversation_pair in chat_history:
|
27 |
+
for item in conversation_pair:
|
28 |
+
if item:
|
29 |
+
messages.append(item)
|
30 |
+
self.messages = []
|
31 |
+
for i, message in enumerate(messages):
|
32 |
+
label = self.bot_label if i % 2 == 0 else self.user_label
|
33 |
+
self.messages.append(
|
34 |
+
{
|
35 |
+
"from": label,
|
36 |
+
"value": message.strip()
|
37 |
+
}
|
38 |
+
)
|
39 |
+
|
40 |
+
def add_user_message(self, message):
|
41 |
+
self.messages.append(
|
42 |
+
{
|
43 |
+
"from": self.user_label,
|
44 |
+
"value": message.strip()
|
45 |
+
}
|
46 |
+
)
|
47 |
+
|
48 |
+
def reset_last_message(self, message):
|
49 |
+
self.messages[-1]["value"] = message.strip()
|
models/base.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import requests
|
3 |
+
|
4 |
+
from conversation import Conversation
|
5 |
+
|
6 |
+
|
7 |
+
class BaseModel:
|
8 |
+
name: str
|
9 |
+
endpoint: str
|
10 |
+
namespace: str
|
11 |
+
generation_params: dict
|
12 |
+
|
13 |
+
def __init__(self, name, endpoint, namespace, generation_params):
|
14 |
+
self.name = name
|
15 |
+
self.endpoint = endpoint
|
16 |
+
self.namespace = namespace
|
17 |
+
self.generation_params = generation_params
|
18 |
+
|
19 |
+
def generate_response(self, conversation):
|
20 |
+
prompt = self._get_prompt(conversation)
|
21 |
+
response = self._get_response(prompt)
|
22 |
+
return response
|
23 |
+
|
24 |
+
def _get_prompt(self, conversation: Conversation):
|
25 |
+
print(conversation.__dict__)
|
26 |
+
prompt = "\n".join(
|
27 |
+
[conversation.memory, conversation.prompt]
|
28 |
+
).strip()
|
29 |
+
|
30 |
+
for message in conversation.messages:
|
31 |
+
prompt += f"\n{message['from'].strip()}: {message['value'].strip()}"
|
32 |
+
prompt += f"\n{conversation.bot_label}:"
|
33 |
+
print(prompt)
|
34 |
+
return prompt
|
35 |
+
|
36 |
+
def _get_response(self, text):
|
37 |
+
api = str(os.environ.get("API_BASE_PATH")).replace("\{\}", "{}")
|
38 |
+
api = api.format(self.endpoint, self.namespace)
|
39 |
+
|
40 |
+
payload = {'instances': [text], "parameters": self.generation_params}
|
41 |
+
resp = requests.post(api, json=payload, timeout=600)
|
42 |
+
assert resp.status_code == 200, (resp.content, resp.status_code)
|
43 |
+
return resp.json()["predictions"][0].strip()
|