# Copyright 2023 MosaicML spaces authors
# SPDX-License-Identifier: Apache-2.0
# and
# the https://huggingface.co/spaces/HuggingFaceH4/databricks-dolly authors
import datetime
import os
from threading import Event, Thread
from uuid import uuid4

import gradio as gr
import requests
import torch
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

from quick_pipeline import InstructionTextGenerationPipeline as pipeline

# Configuration
HF_TOKEN = os.getenv("HF_TOKEN", None)
examples = [
    # TODO: add coupled hparams so e.g. poem has higher temp
    "Write a travel blog about a 3-day trip to Thailand.",
    "Write a short story about a robot that has a nice day.",
    "Convert the following to a single line of JSON:\n\n```name: John\nage: 30\naddress:\n  street: 123 Main St.\n  city: San Francisco\n  state: CA\n  zip: 94101\n```",
    "Write a quick email to congratulate MosaicML about the launch of their inference offering.",
    "Explain how a candle works to a 6-year-old in a few sentences.",
    "What are some of the most common misconceptions about birds?",
]

# Initialize the model and tokenizer
generate = pipeline(
    "mosaicml/mpt-7b-instruct",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    use_auth_token=HF_TOKEN,
)
stop_token_ids = generate.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])


# Custom stopping criterion: halt generation as soon as the most recently
# generated token is the end-of-text token.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


def log_conversation(session_id, instruction, response, generate_kwargs):
    # Log the exchange to an external endpoint, if one is configured.
    logging_url = os.getenv("LOGGING_URL", None)
    if logging_url is None:
        return

    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

    data = {
        "session_id": session_id,
        "timestamp": timestamp,
        "instruction": instruction,
        "response": response,
        "generate_kwargs": generate_kwargs,
    }

    try:
        # Bound the request so a slow logging endpoint cannot hang the thread.
        requests.post(logging_url, json=data, timeout=5)
    except requests.exceptions.RequestException as e:
        print(f"Error logging conversation: {e}")


def process_stream(instruction, temperature, top_p, top_k, max_new_tokens, session_id):
    # Tokenize the formatted instruction and move it to the model's device
    input_ids = generate.tokenizer(
        generate.format_instruction(instruction), return_tensors="pt"
    ).input_ids
    input_ids = input_ids.to(generate.model.device)

    # Initialize the streamer and stopping criteria
    streamer = TextIteratorStreamer(
        generate.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    stop = StopOnTokens()

    # Treat very low temperatures as greedy decoding; sampling at a
    # near-zero temperature is numerically unstable.
    if temperature < 0.1:
        temperature = 0.0
        do_sample = False
    else:
        do_sample = True

    # Merge the pipeline's default generation kwargs with the UI settings.
    gkw = {
        **generate.generate_kwargs,
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "do_sample": do_sample,
        "top_p": top_p,
        "top_k": top_k,
        "streamer": streamer,
        "stopping_criteria": StoppingCriteriaList([stop]),
    }

    response = ""
    stream_complete = Event()

    def generate_and_signal_complete():
        # Run generation on a worker thread; tokens are pushed into `streamer`.
        try:
            generate.model.generate(**gkw)
        finally:
            # Set the event even if generation raises, so the logging
            # thread below never waits forever.
            stream_complete.set()

    def log_after_stream_complete():
        # Wait until generation finishes so the full response gets logged.
        stream_complete.wait()
        log_conversation(
            session_id,
            instruction,
            response,
            {
                "top_k": top_k,
                "top_p": top_p,
                "temperature": temperature,
            },
        )

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    t2 = Thread(target=log_after_stream_complete)
    t2.start()

    # Yield the accumulated response as tokens stream in so the UI updates live.
    for new_text in streamer:
        response += new_text
        yield response


with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    # One random session ID per page load, used to group logged turns.
    session_id = gr.State(lambda: str(uuid4()))
    gr.Markdown(
        """