# Copyright 2023 MosaicML spaces authors # SPDX-License-Identifier: Apache-2.0 import datetime import os from threading import Event, Thread from uuid import uuid4 import gradio as gr import requests import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, ) model_name = "mosaicml/mpt-7b-chat" max_new_tokens = 1536 # # small testing model: # model_name = "gpt2" # max_new_tokens = 128 auth_token = os.getenv("HF_TOKEN", None) print(f"Starting to load the model {model_name} into memory") m = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, use_auth_token=auth_token, max_seq_len=8192, ).cuda() tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=auth_token) stop_token_ids = tok.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"]) print(f"Successfully loaded the model {model_name} into memory") start_message = """<|im_start|>system - You are a helpful assistant chatbot trained by MosaicML. - You answer questions. - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|> """ class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: for stop_id in stop_token_ids: if input_ids[0][-1] == stop_id: return True return False def convert_history_to_text(history): text = start_message + "".join( [ "".join( [ f"<|im_start|>user\n{item[0]}<|im_end|>", f"<|im_start|>assistant\n{item[1]}<|im_end|>", ] ) for item in history[:-1] ] ) text += "".join( [ "".join( [ f"<|im_start|>user\n{history[-1][0]}<|im_end|>", f"<|im_start|>assistant\n{history[-1][1]}", ] ) ] ) return text def log_conversation(conversation_id, history, messages, generate_kwargs): logging_url = os.getenv("LOGGING_URL", None) if logging_url is None: return timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") data = { "conversation_id": conversation_id, "timestamp": timestamp, "history": history, "messages": messages, "generate_kwargs": generate_kwargs, } try: requests.post(logging_url, json=data) except requests.exceptions.RequestException as e: print(f"Error logging conversation: {e}") def user(message, history): # Append the user's message to the conversation history return "", history + [[message, ""]] def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id): print(f"history: {history}") # Initialize a StopOnTokens object stop = StopOnTokens() # Construct the input message string for the model by concatenating the current system message and conversation history messages = convert_history_to_text(history) # Tokenize the messages string input_ids = tok(messages, return_tensors="pt").input_ids input_ids = input_ids.to(m.device) streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( input_ids=input_ids, max_new_tokens=max_new_tokens, temperature=temperature, do_sample=temperature > 0.0, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, streamer=streamer, stopping_criteria=StoppingCriteriaList([stop]), ) stream_complete = Event() def generate_and_signal_complete(): m.generate(**generate_kwargs) stream_complete.set() def log_after_stream_complete(): stream_complete.wait() log_conversation( conversation_id, history, messages, { "top_k": top_k, "top_p": top_p, "temperature": temperature, "repetition_penalty": repetition_penalty, }, ) t1 = Thread(target=generate_and_signal_complete) t1.start() t2 = Thread(target=log_after_stream_complete) t2.start() # Initialize an empty string to store the generated text partial_text = "" for new_text in streamer: partial_text += new_text history[-1][1] = partial_text yield history def get_uuid(): return str(uuid4()) with gr.Blocks( theme=gr.themes.Soft(), css=".disclaimer {font-variant-caps: all-small-caps;}", ) as demo: conversation_id = gr.State(get_uuid) gr.Markdown( """