Added MPT
app.py CHANGED
@@ -7,19 +7,20 @@ import numpy as np
 from typing import Iterator
 
 import gradio as gr
+from gradio_client import Client
 import pandas as pd
 import torch
 
-from easyllm.clients import huggingface
 from transformers import AutoTokenizer
+from awq import AutoAWQForCausalLM
+from transformers import AutoTokenizer
+

-huggingface.prompt_builder = "llama2"
-huggingface.api_key = os.environ["HUGGINGFACE_TOKEN"]
 MAX_MAX_NEW_TOKENS = 250
 DEFAULT_MAX_NEW_TOKENS = 250
 MAX_INPUT_TOKEN_LENGTH = 4000
 EMBED_DIM = 1024
-K =
+K = 2
 EF = 100
 TEXT_FILE = 'data.txt'
 SEARCH_INDEX = "search_index.bin"
@@ -28,20 +29,24 @@ DOCUMENT_DATASET = "chunked_data.parquet"
 COSINE_THRESHOLD = 0.7


-
-
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 print("Running on device:", torch_device)
 print("CPU threads:", torch.get_num_threads())

-
-
-
+biencoder = SentenceTransformer("intfloat/e5-large-v2", device="cpu")
+cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", max_length=512, device="cpu")
+model_name_or_path = "TheBloke/TinyLlama-1.1B-1T-OpenOrca-AWQ"

-
+# Load model
+# model = AutoAWQForCausalLM.from_quantized(model_name_or_path, fuse_layers=True,
+#                                           trust_remote_code=False, safetensors=True)
+# tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=False)
+tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-30b-chat", trust_remote_code=False)
+chat_client = Client("https://mosaicml-mpt-30b-chat.hf.space/", serialize = False)
+chat_bot = [["", None]]

 def read_text_from_file(file_path):
-    with open(file_path, "r") as text_file:
+    with open(file_path, "r", encoding="utf-8") as text_file:
         text = text_file.read()
     texts = text.split("&&")
     return [t.strip() for t in texts]
@@ -89,37 +94,56 @@ def get_input_token_length(message: str, chat_history: list[tuple[str, str]], sy
     input_ids = tokenizer([prompt], return_tensors="np", add_special_tokens=False)["input_ids"]
     return input_ids.shape[-1]

+def prompt_builder(prompt, system_message="You are a helpful chatbot which gives correct and truthful answers"):
+    return f'''<|im_start|>system
+{system_message}<|im_end|>
+<|im_start|>user
+{prompt}<|im_end|>
+<|im_start|>assistant
+
+
+'''

 # https://www.philschmid.de/llama-2#how-to-prompt-llama-2-chat
 def get_completion(
     prompt,
     system_prompt=None,
-    model=
-    max_new_tokens=
+    # model=model,
+    max_new_tokens=250,
     temperature=0.2,
     top_p=0.95,
     top_k=50,
     stream=False,
     debug=False,
 ):
+    global chat_bot
     if temperature < 1e-2:
         temperature = 1e-2
-
-
-
-
-    response = huggingface.ChatCompletion.create(
-        model=model,
-        messages=messages,
-        temperature=temperature,  # this is the degree of randomness of the model's output
-        max_tokens=250,  # this is the number of new tokens being generated
-        top_p=top_p,
-        top_k=top_k,
-        stream=stream,
-        debug=debug,
+    answer=chat_client.predict(
+        prompt,  # str in 'Type an input and press Enter' Textbox component
+        chat_bot,
+        fn_index=1
     )
-
-
+    chat_bot = answer[1]
+    yield answer[1][0][1]
+    # prompt = prompt_builder(prompt)
+    # tokens = tokenizer(
+    #     prompt,
+    #     return_tensors='pt'
+    # ).input_ids.cuda()
+
+    # # Generate output
+    # for i in range(max_new_tokens):
+    #     generation_output = model.generate(
+    #         tokens,
+    #         do_sample=True,
+    #         temperature=temperature,
+    #         top_p=top_p,
+    #         top_k=top_k,
+    #         max_new_tokens=1
+    #     )
+    #     tokens = generation_output
+    #     yield tokenizer.decode(generation_output[0][-1])

 # load the index for the data
 def load_hnsw_index(index_file):
@@ -177,17 +201,18 @@ def generate_condensed_query(query, history):
         chat_history += f"Assistant: {turn[1]}\n"

     condense_question_prompt = create_condense_question_prompt(query, chat_history)
-    condensed_question = json.loads(get_completion(condense_question_prompt, max_new_tokens=64, temperature=0))
-
+    # condensed_question = json.loads(get_completion(condense_question_prompt, max_new_tokens=64, temperature=0))
+    condensed_question = "".join([token for token in get_completion(condense_question_prompt, max_new_tokens=64, temperature=0)])
+    return condensed_question


 DEFAULT_SYSTEM_PROMPT = """\
 You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
 If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
 """
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = 4000
+# MAX_MAX_NEW_TOKENS = 2048
+# DEFAULT_MAX_NEW_TOKENS = 1024
+# MAX_INPUT_TOKEN_LENGTH = 4000

 DESCRIPTION = """
 # AVA Southampton Chatbot 🤗
@@ -265,7 +290,8 @@ def generate(

     output = ""
     for idx, response in enumerate(generator):
-        token = response["choices"][0]["delta"].get("content", "") or ""
+        # token = response["choices"][0]["delta"].get("content", "") or ""
+        token = response
         output += token
         if idx == 0:
             history.append((message, output))
@@ -273,7 +299,7 @@
         history[-1] = (message, output)

     history = [
-        (wrap_html_code(history[i][0]
+        (wrap_html_code(history[i][0]), wrap_html_code(history[i][1]))
         for i in range(0, len(history))
     ]
     yield history
@@ -483,4 +509,4 @@ with gr.Blocks(css="style.css") as demo:
         api_name=False,
     )

-demo.queue(max_size=20).launch(debug=True
+demo.queue(max_size=20).launch(debug=True)