"""Gradio demo that streams chat replies from RedPajama-INCITE-Chat-3B-v1.

The script loads the tokenizer and model once at import time, defines the
callbacks used by the UI (``user`` and ``chat``), builds the Blocks layout
and launches the app.
"""

import os
import time
from threading import Thread

import gradio as gr
import numpy as np
import torch
from torch.nn import functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    pipeline,
)

# init
MODEL_NAME = "togethercomputer/RedPajama-INCITE-Chat-3B-v1"

# Speaker tags of the chat prompt format documented on the model card
# ("<human>: ...\n<bot>:"). The previous bare "\n:" separators carried no
# speaker information at all.
HUMAN_TAG = "<human>"
BOT_TAG = "<bot>"

# Token ids that end generation.
# NOTE(review): presumably 29 is "<" (the start of a new "<human>" turn) and
# 0 is the end-of-text token -- verify against the tokenizer vocabulary.
STOP_TOKEN_IDS = (29, 0)

tok = AutoTokenizer.from_pretrained(MODEL_NAME)
m = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
m = m.to('cuda:0')


class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the last generated token is a stop id."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the newest token of the first sequence is a stop id."""
        return input_ids[0][-1].item() in STOP_TOKEN_IDS


def user(message, history):
    """Append the user's message to the conversation as a new, unanswered turn.

    Args:
        message: Text typed into the message box.
        history: Conversation so far as a list of ``[user_text, bot_text]``
            pairs; ``None`` is treated as an empty conversation.

    Returns:
        A tuple ``("", new_history)``: the empty string clears the textbox and
        ``new_history`` ends with ``[message, ""]``.
    """
    return "", (history or []) + [[message, ""]]


def chat(history, top_p, top_k, temperature):
    """Stream the model's reply to the last turn of ``history``.

    Generation runs in a background thread; this generator consumes the
    streamer and yields the updated history after every received chunk so the
    chatbot widget refreshes incrementally.

    Args:
        history: List of ``[user_text, bot_text]`` pairs; the last pair's bot
            text is overwritten in place with the reply generated so far.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        temperature: Sampling temperature (must be > 0).

    Yields:
        ``history`` with the partial reply written into the last turn.

    Returns:
        The full generated text (available as ``StopIteration.value``).
    """
    print(f"history is - {history}")

    # Build the prompt from every turn; the final turn has an empty bot
    # message, so the prompt ends on the bot tag and the model continues it.
    messages = "".join(
        f"\n{HUMAN_TAG}:{turn[0]}\n{BOT_TAG}:{turn[1]}" for turn in history
    )
    print(f"messages is - {messages}")

    # Put the inputs on whatever device the model lives on instead of a
    # separately hard-coded device string.
    model_inputs = tok([messages], return_tensors="pt").to(m.device)

    # skip_prompt=True makes the streamer emit only newly generated text.
    # Previously the prompt was echoed and the reply was recovered with
    # split(':')[-1], which truncated any reply that itself contained a colon.
    streamer = TextIteratorStreamer(
        tok, timeout=10., skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=float(top_p),
        # Slider values may arrive as int or float; the logits processors
        # type-check their arguments, so cast explicitly.
        top_k=int(top_k),
        temperature=float(temperature),
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )
    t = Thread(target=m.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    for new_text in streamer:
        # A lone "<" chunk is the stop marker, not part of the reply.
        if new_text != '<':
            partial_text += new_text
            history[-1][1] = partial_text
            yield history
    return partial_text


title = """

🔥RedPajama-INCITE-Chat-3B-v1


🏃‍♂️💨Streaming with Transformers & Gradio💪

"""

description = """

This is a RedPajama Chat model fine-tuned using data from Dolly 2.0 and Open Assistant over the RedPajama-INCITE-Base-3B-v1 base model.

"""

theme = gr.themes.Soft(
    primary_hue=gr.themes.Color(
        "#ededed", "#fee2e2", "#fecaca", "#fca5a5", "#f87171", "#ef4444",
        "#dc2626", "#b91c1c", "#991b1b", "#7f1d1d", "#6c1e1e"),
    neutral_hue="red",
)

with gr.Blocks(theme=theme) as demo:
    gr.HTML(title)
    gr.HTML('''
Duplicate SpaceDuplicate the Space to skip the queue and run in a private space
''')
    chatbot = gr.Chatbot().style(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(label="Chat Message Box", placeholder="Chat Message Box",
                             show_label=False).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")

    # Advanced options - top_p, temperature, top_k
    with gr.Accordion("Advanced Options:", open=False):
        top_p = gr.Slider(minimum=0, maximum=1.0, value=0.95, step=0.05,
                          interactive=True, label="Top-p",)
        top_k = gr.Slider(minimum=0, maximum=1000, value=1000, step=1,
                          interactive=True, label="Top-k",)
        # Minimum is 0.1 rather than 0: sampling requires a strictly positive
        # temperature, and 0 would make generation fail in the worker thread.
        temperature = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, step=0.1,
                                interactive=True, label="Temperature",)

    # Both the Enter key and the Submit button run the same two-step chain:
    # record the user's turn immediately, then stream the reply via the queue.
    record_turn = dict(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False)
    stream_reply = dict(fn=chat, inputs=[chatbot, top_p, top_k, temperature],
                        outputs=[chatbot], queue=True)

    submit_event = msg.submit(**record_turn).then(**stream_reply)
    submit_click_event = submit.click(**record_turn).then(**stream_reply)

    stop.click(fn=None, inputs=None, outputs=None,
               cancels=[submit_event, submit_click_event], queue=False)
    clear.click(lambda: None, None, [chatbot], queue=False)

    gr.Examples([
        ["Hello there! How are you doing?"],
        ["Can you explain to me briefly what is Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["What are some common mistakes to avoid when writing code?"],
        ["Write a 500-word blog post on “Benefits of Artificial Intelligence"]
    ], inputs=msg, label="Click on any example and press the 'Submit' button")
    gr.HTML(description)

demo.queue(max_size=32, concurrency_count=2)
demo.launch(debug=True)