# Copyright 2023 MosaicML spaces authors
# SPDX-License-Identifier: Apache-2.0
"""Gradio chat demo that streams responses from the MPT-7B-Chat model."""
import datetime
import os
from threading import Event, Thread
from uuid import uuid4

import gradio as gr
import requests
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

model_name = "mosaicml/mpt-7b-chat"
max_new_tokens = 1536

# # small testing model:
# model_name = "gpt2"
# max_new_tokens = 128

# Upper bound (seconds) on the best-effort logging request so that an
# unresponsive endpoint cannot keep the logging thread alive indefinitely.
LOGGING_TIMEOUT_SECONDS = 10

auth_token = os.getenv("HF_TOKEN", None)

print(f"Starting to load the model {model_name} into memory")

m = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    use_auth_token=auth_token,
    max_seq_len=8192,
).cuda()
tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=auth_token)

# Token ids that mark the end of an assistant turn; checked by StopOnTokens.
stop_token_ids = tok.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])

print(f"Successfully loaded the model {model_name} into memory")

# System prompt prepended to every conversation. It follows the same
# "<|im_start|>role\n...<|im_end|>" layout that convert_history_to_text uses
# for the user and assistant turns.
start_message = """<|im_start|>system
- You are a helpful assistant chatbot trained by MosaicML.
- You answer questions.
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>
"""


class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation when a stop token is emitted."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the newest token of the first sequence is a stop token.

        Only ``input_ids[0]`` is inspected; ``scores`` and ``kwargs`` are unused.
        """
        last_token_id = int(input_ids[0][-1])
        return last_token_id in stop_token_ids


def convert_history_to_text(history):
    """Render the chat history as a single prompt string for the model.

    Args:
        history: Sequence of ``[user_message, assistant_message]`` pairs. The
            last pair is the turn currently being generated.

    Returns:
        The system prompt followed by every completed turn, then the final
        turn with its assistant section left open (no closing ``<|im_end|>``)
        so the model continues from there. An empty history yields just the
        system prompt.
    """
    if not history:
        return start_message

    completed_turns = "".join(
        f"<|im_start|>user\n{turn[0]}<|im_end|><|im_start|>assistant\n{turn[1]}<|im_end|>"
        for turn in history[:-1]
    )
    current_turn = history[-1]
    open_turn = f"<|im_start|>user\n{current_turn[0]}<|im_end|><|im_start|>assistant\n{current_turn[1]}"
    return start_message + completed_turns + open_turn


def log_conversation(conversation_id, history, messages, generate_kwargs):
    """POST the conversation as JSON to ``LOGGING_URL``, if that variable is set.

    Logging is best-effort: request failures (including timeouts) are printed
    and otherwise ignored.

    Args:
        conversation_id: Identifier of the chat session.
        history: Chat history as ``[user_message, assistant_message]`` pairs.
        messages: The prompt text that was sent to the model.
        generate_kwargs: Sampling parameters to record alongside the conversation.
    """
    logging_url = os.getenv("LOGGING_URL", None)
    if logging_url is None:
        return

    payload = {
        "conversation_id": conversation_id,
        "timestamp": datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S"),
        "history": history,
        "messages": messages,
        "generate_kwargs": generate_kwargs,
    }

    try:
        # Without a timeout requests may wait forever on a stalled endpoint.
        requests.post(logging_url, json=payload, timeout=LOGGING_TIMEOUT_SECONDS)
    except requests.exceptions.RequestException as e:
        print(f"Error logging conversation: {e}")


def user(message, history):
    """Append the user's message to the history and clear the input box.

    Returns:
        A tuple of the new textbox value (empty string) and a new history list
        ending with ``[message, ""]``; the empty string is the assistant reply
        placeholder later filled in by ``bot``.
    """
    return "", history + [[message, ""]]


def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    """Generate the assistant reply for the last turn, streaming partial results.

    Generation runs on a background thread and feeds a ``TextIteratorStreamer``;
    a second thread waits for generation to finish and then logs the
    conversation.

    Args:
        history: Chat history whose last entry holds the pending user message.
        temperature: Sampling temperature; sampling is disabled when it is 0.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.
        top_k: Top-k shortlist size forwarded to ``generate``.
        repetition_penalty: Repetition penalty forwarded to ``generate``.
        conversation_id: Identifier recorded with the logged conversation.

    Yields:
        ``history`` after each streamed chunk, with the last assistant message
        updated in place to the text generated so far.
    """
    print(f"history: {history}")

    # Build the full prompt (system message + conversation) and tokenize it.
    messages = convert_history_to_text(history)
    input_ids = tok(messages, return_tensors="pt").input_ids.to(m.device)

    streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )

    stream_complete = Event()

    def generate_and_signal_complete():
        # Always signal completion, even if generation raises; otherwise the
        # logging thread would block on the event forever.
        try:
            m.generate(**generate_kwargs)
        finally:
            stream_complete.set()

    def log_after_stream_complete():
        stream_complete.wait()
        log_conversation(
            conversation_id,
            history,
            messages,
            {
                "top_k": top_k,
                "top_p": top_p,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
            },
        )

    Thread(target=generate_and_signal_complete).start()
    Thread(target=log_after_stream_complete).start()

    # Accumulate streamed chunks and publish the growing reply to the UI.
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history


def get_uuid():
    """Return a fresh random UUID string, used as a conversation id."""
    return str(uuid4())


with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    # Passing the callable (not its result) as the State's initial value.
    conversation_id = gr.State(get_uuid)
    gr.Markdown(
        """

MosaicML MPT-7B-Chat

This demo is of [MPT-7B-Chat](https://huggingface.co/mosaicml/mpt-7b-chat). It is based on [MPT-7B](https://huggingface.co/mosaicml/mpt-7b) fine-tuned with approximately [171,000 conversation samples from this dataset](https://huggingface.co/datasets/sam-mosaic/vicuna_alpaca_hc3_chatml) and another [217,000 from this dataset](https://huggingface.co/datasets/sam-mosaic/hhrlhf_evol_chatml). If you're interested in [training](https://www.mosaicml.com/training) and [deploying](https://www.mosaicml.com/inference) your own MPT or LLMs, [sign up](https://forms.mosaicml.com/demo?utm_source=huggingface&utm_medium=referral&utm_campaign=mpt-7b) for MosaicML platform. This is running on a smaller, shared GPU, so it may take a few seconds to respond. If you want to run it on your own GPU, you can [download the model from HuggingFace](https://huggingface.co/mosaicml/mpt-7b-chat) and run it locally. Or [Duplicate the Space](https://huggingface.co/spaces/mosaicml/mpt-7b-chat?duplicate=true) to skip the queue and run in a private space. """
    )
    chatbot = gr.Chatbot().style(height=500)

    # Message input and the Submit / Stop / Clear buttons.
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")

    # Sampling controls, collapsed by default.
    with gr.Row():
        with gr.Accordion("Advanced Options:", open=False):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.1,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                with gr.Column():
                    with gr.Row():
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=1.0,
                            minimum=0.0,
                            maximum=1,
                            step=0.01,
                            interactive=True,
                            info=(
                                "Sample from the smallest possible set of tokens whose cumulative probability "
                                "exceeds top_p. Set to 1 to disable and sample from all tokens."
                            ),
                        )
                with gr.Column():
                    with gr.Row():
                        top_k = gr.Slider(
                            label="Top-k",
                            value=0,
                            minimum=0.0,
                            maximum=200,
                            step=1,
                            interactive=True,
                            info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                        )
                with gr.Column():
                    with gr.Row():
                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            value=1.1,
                            minimum=1.0,
                            maximum=2.0,
                            step=0.1,
                            interactive=True,
                            info="Penalize repetition — 1.0 to disable.",
                        )

    with gr.Row():
        gr.Markdown(
            "Disclaimer: MPT-7B can produce factually incorrect output, and should not be relied on to produce "
            "factually accurate information. MPT-7B was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )
    with gr.Row():
        gr.Markdown(
            "[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)",
            elem_classes=["disclaimer"],
        )

    # Both pressing Enter in the textbox and clicking Submit run the same
    # two-step chain: record the user message immediately (unqueued), then
    # stream the model reply through the queue.
    user_io = [msg, chatbot]
    bot_inputs = [
        chatbot,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        conversation_id,
    ]

    submit_event = msg.submit(
        fn=user,
        inputs=user_io,
        outputs=user_io,
        queue=False,
    ).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
        queue=True,
    )
    submit_click_event = submit.click(
        fn=user,
        inputs=user_io,
        outputs=user_io,
        queue=False,
    ).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
        queue=True,
    )

    # Stop cancels whichever generation chain is currently running.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    # Clear resets the chatbot component to None.
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue(max_size=128, concurrency_count=2)
demo.launch()