# Copyright 2023 MosaicML spaces authors
# SPDX-License-Identifier: Apache-2.0
# anything broken is Luke's fault
import datetime
import os
from threading import Event, Thread
from uuid import uuid4

import gradio as gr
import requests
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
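
# Overview: a Gradio chat demo for mosaicml/mpt-7b-chat. Each user turn is
# appended to the chat history, the history is rendered into a ChatML-style
# prompt, and the reply is streamed back token by token via
# TextIteratorStreamer, while a background thread POSTs the finished
# conversation to an optional logging endpoint.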

model_name = "mosaicml/mpt-7b-chat"
max_new_tokens = 1536

# # small testing model:
# model_name = "gpt2"
# max_new_tokens = 128

auth_token = os.getenv("HF_TOKEN", None)
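
# mpt-7b-chat ships custom modeling code, hence trust_remote_code=True below.
# The max_seq_len kwarg is forwarded to the model config, stretching the
# context window to 8192 tokens; MPT's ALiBi position encoding tolerates
# sequences longer than its 2048-token training context. bfloat16 halves
# memory use relative to float32.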
print(f"Starting to load the model {model_name} into memory")
m = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    use_auth_token=auth_token,
    max_seq_len=8192,
).cuda()
tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=auth_token)
stop_token_ids = tok.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
print(f"Successfully loaded the model {model_name} into memory")
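
# The prompt follows the ChatML convention that mpt-7b-chat was tuned on:
# every turn is wrapped as <|im_start|>{role}\n...<|im_end|>, e.g.
#
#   <|im_start|>user
#   Hello!<|im_end|>
#   <|im_start|>assistant
#   Hi there!<|im_end|>
#
# which is also why generation stops on <|im_end|> (end of turn), not just
# <|endoftext|>.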
start_message = """<|im_start|>system
- You are an assistant who helps people recover from the experience of COVID by offering them the perspective of those who were there before. Your knowledge and training are entirely focused on the six months after the United States first locked down for COVID. Pretend you are speaking with someone in the year 2023, and react to the information they tell you as though you are a person in 2019 who has no knowledge beyond that. They are a poster on reddit's the_donald and other fringe right-wing sites. Show exactly how they would react to the news and share all feedback.<|im_end|>
"""


class StopOnTokens(StoppingCriteria):
    """Halt generation as soon as the most recently sampled token is a stop token."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


def convert_history_to_text(history):
    text = start_message + "".join(
        [
            "".join(
                [
                    f"<|im_start|>user\n{item[0]}<|im_end|>",
                    f"<|im_start|>assistant\n{item[1]}<|im_end|>",
                ]
            )
            for item in history[:-1]
        ]
    )
    text += "".join(
        [
            "".join(
                [
                    f"<|im_start|>user\n{history[-1][0]}<|im_end|>",
                    f"<|im_start|>assistant\n{history[-1][1]}",
                ]
            )
        ]
    )
    return text
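
# Illustration: with history == [["How are you?", ""]] the rendered prompt is
# the system block followed by
#
#   <|im_start|>user
#   How are you?<|im_end|><|im_start|>assistant
#
# i.e. it ends with an open assistant turn for the model to complete.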


def log_conversation(conversation_id, history, messages, generate_kwargs):
    logging_url = os.getenv("LOGGING_URL", None)
    if logging_url is None:
        return

    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    data = {
        "conversation_id": conversation_id,
        "timestamp": timestamp,
        "history": history,
        "messages": messages,
        "generate_kwargs": generate_kwargs,
    }

    try:
        requests.post(logging_url, json=data)
    except requests.exceptions.RequestException as e:
        print(f"Error logging conversation: {e}")


def user(message, history):
    # Append the user's message to the conversation history
    return "", history + [[message, ""]]
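

# bot() is a generator: each yielded history becomes the new value of the
# Chatbot component, so the reply renders token by token. generate() runs on
# its own thread and feeds the TextIteratorStreamer, which the main thread
# drains below; a second thread waits for completion and logs the exchange.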
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    print(f"history: {history}")
    # Initialize a StopOnTokens object
    stop = StopOnTokens()

    # Construct the input message string for the model by concatenating the
    # current system message and conversation history
    messages = convert_history_to_text(history)

    # Tokenize the messages string
    input_ids = tok(messages, return_tensors="pt").input_ids
    input_ids = input_ids.to(m.device)
    streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    stream_complete = Event()

    def generate_and_signal_complete():
        m.generate(**generate_kwargs)
        stream_complete.set()

    def log_after_stream_complete():
        stream_complete.wait()
        log_conversation(
            conversation_id,
            history,
            messages,
            {
                "top_k": top_k,
                "top_p": top_p,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
            },
        )

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    t2 = Thread(target=log_after_stream_complete)
    t2.start()

    # Initialize an empty string to store the generated text
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history


def get_uuid():
    return str(uuid4())


with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    conversation_id = gr.State(get_uuid)
    gr.Markdown(
        """<h1><center>cxntextMPT</center></h1>

        This model engages three Matrix LLMs; others are pending integration.

        Running on a potato, so be patient.
"""
    )
    chatbot = gr.Chatbot().style(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")
    with gr.Row():
        with gr.Accordion("Advanced", open=False):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.1,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                with gr.Column():
                    with gr.Row():
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=1.0,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            interactive=True,
                            info=(
                                "Sample from the smallest possible set of tokens whose cumulative probability "
                                "exceeds top_p. Set to 1 to disable and sample from all tokens."
                            ),
                        )
                with gr.Column():
                    with gr.Row():
                        top_k = gr.Slider(
                            label="Top-k",
                            value=0,
                            minimum=0.0,
                            maximum=200,
                            step=1,
                            interactive=True,
                            info="Sample from a shortlist of top-k tokens; set to 0 to disable and sample from all tokens.",
                        )
                with gr.Column():
                    with gr.Row():
                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            value=1.1,
                            minimum=1.0,
                            maximum=2.0,
                            step=0.1,
                            interactive=True,
                            info="Penalize repetition; set to 1.0 to disable.",
                        )
    with gr.Row():
        gr.Markdown(
            "Disclaimer: All included models can produce factually incorrect output, and if they don't they will be forced to by Elon.",
            elem_classes=["disclaimer"],
        )
    with gr.Row():
        gr.Markdown(
            "[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)",
            elem_classes=["disclaimer"],
        )
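
    # Wiring: user() runs unqueued so the textbox clears and the new message
    # echoes immediately; .then() chains bot() on the queue to stream the
    # reply. The same pair is bound to both the Enter key (msg.submit) and the
    # Submit button, and the returned events are what the Stop button cancels.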
    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            conversation_id,
        ],
        outputs=chatbot,
        queue=True,
    )
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            conversation_id,
        ],
        outputs=chatbot,
        queue=True,
    )
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue(max_size=128, concurrency_count=2)
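
# Note: if the demo ever needs to be reachable outside Hugging Face, launch()
# also accepts share=True or server_name="0.0.0.0"; the defaults are fine for
# a Space.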
demo.launch()