Spaces:
Runtime error
Runtime error
File size: 5,421 Bytes
7cc686b 119226a 7cc686b d2f200e 5d23cab 7cc686b 5d23cab 7cc686b d2f200e 7cc686b d2f200e 7cc686b d2f200e 7cc686b f0812f2 d2f200e 23eb0dd f0812f2 d2f200e f0812f2 7cc686b d2f200e 7cc686b 23eb0dd 7cc686b 23eb0dd 7cc686b 23eb0dd f0812f2 7cc686b d2f200e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import time
import numpy as np
from torch.nn import functional as F
import os
from threading import Thread
# Model setup: runs once at import time.
# The tokenizer and the half-precision weights come from the same checkpoint,
# and the model is moved to the first CUDA device right after loading.
_CHECKPOINT = "togethercomputer/RedPajama-INCITE-Chat-3B-v1"
tok = AutoTokenizer.from_pretrained(_CHECKPOINT)
m = AutoModelForCausalLM.from_pretrained(
    _CHECKPOINT,
    torch_dtype=torch.float16,
).to('cuda:0')
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation when the newest token is a stop id."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is 29 or 0.

        Only row 0 of ``input_ids`` is inspected; ``scores`` and ``kwargs`` are
        accepted for interface compatibility and are not used.
        """
        # NOTE(review): presumably 29 is the '<' that opens a '<human>:' /
        # '<bot>:' marker and 0 is the end-of-text id -- confirm against the
        # tokenizer vocabulary.
        newest_token = input_ids[0][-1]
        return any(newest_token == candidate for candidate in (29, 0))
def user(message, history):
    """Append the user's message to the conversation as a not-yet-answered turn.

    Args:
        message: Text the user submitted.
        history: Conversation so far as a sequence of ``[user_text, bot_text]``
            pairs. ``None`` (the value the Clear button writes to the chatbot)
            or an empty sequence is treated as an empty conversation.

    Returns:
        A 2-tuple ``("", new_history)``. The empty string is written back to
        the message textbox, and ``new_history`` is a new list ending with
        ``[message, ""]``. The ``history`` argument itself is not mutated.
    """
    # Copy into a fresh list so a None history does not raise TypeError and
    # the caller's list is never modified in place.
    previous_turns = list(history) if history else []
    return "", previous_turns + [[message, ""]]
def _build_prompt(history):
    """Render the conversation as the ``<human>:`` / ``<bot>:`` prompt string.

    Each turn contributes ``"\\n<human>:" + user_text + "\\n<bot>:" + bot_text``.
    A missing (``None``) user or bot text is rendered as an empty string.
    """
    return "".join(
        "\n<human>:" + (item[0] or "") + "\n<bot>:" + (item[1] or "")
        for item in history
    )


def chat(history, top_p, top_k, temperature):
    """Stream the model's reply to the latest turn into ``history``.

    Generation runs on a background thread and is consumed through a
    ``TextIteratorStreamer``. After every received chunk the bot half of the
    last turn is updated and the whole history is yielded, so the chatbot
    component can refresh incrementally.

    Args:
        history: List of mutable ``[user_text, bot_text]`` pairs. The last
            pair is the turn being answered; its bot text is overwritten.
        top_p: Nucleus-sampling probability mass (coerced to ``float``).
        top_k: Number of highest-probability tokens kept (coerced to ``int``).
        temperature: Sampling temperature (coerced to ``float``). A value
            ``<= 0`` switches to greedy decoding, because sampling requires a
            strictly positive temperature.

    Yields:
        The updated ``history`` after each streamed chunk.

    Returns:
        The accumulated generated text (as the generator's return value).
    """
    if not history:
        # Nothing to answer: history[-1] below would raise IndexError.
        yield history
        return ""

    print(f"history is - {history}")
    messages = _build_prompt(history)
    print(f"messages is - {messages}")

    # Put the inputs on whatever device the model actually lives on instead
    # of a second hard-coded device string.
    model_inputs = tok([messages], return_tensors="pt").to(m.device)

    # skip_prompt=True: only newly generated text is streamed. Echoing the
    # prompt would let prompt fragments show up in the bot bubble before the
    # final '<bot>:' marker has been received.
    streamer = TextIteratorStreamer(
        tok, timeout=10., skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )

    # Slider values may arrive as int or float. The logits warpers validate
    # the exact types (float temperature / top_p, int top_k), so coerce here.
    temperature = float(temperature)
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            top_p=float(top_p),
            top_k=int(top_k),
            temperature=temperature,
        )
    else:
        # Temperature 0 is not valid for sampling; decode greedily instead.
        generate_kwargs["do_sample"] = False

    # generate() blocks until finished, so it runs on its own thread while
    # this generator drains the streamer.
    # NOTE(review): nothing here signals the thread to stop if the consumer
    # is cancelled (Stop button) -- generation presumably runs to completion.
    t = Thread(target=m.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    for new_text in streamer:
        # A lone '<' chunk is skipped (presumably the stop marker emitted by
        # StopOnTokens -- confirm); it is never shown to the user.
        if new_text != '<':
            partial_text += new_text
            # Defensive: keep only the text after the last '<bot>:' marker,
            # should one ever appear in the generated output.
            history[-1][1] = partial_text.split('<bot>:')[-1]
            yield history
    return partial_text
# Static HTML shown above (title) and below (description) the chat UI.
title = """<h1 align="center">🔥RedPajama-INCITE-Chat-3B-v1</h1><br><h2 align="center">🏃♂️💨Streaming with Transformers & Gradio💪</h2>"""
description = """<br><br><h3 align="center">This is a RedPajama Chat model fine-tuned using data from Dolly 2.0 and Open Assistant over the RedPajama-INCITE-Base-3B-v1 base model.</h3>"""

# Eleven hex shades passed positionally to gr.themes.Color for the primary hue.
_PRIMARY_SHADES = (
    "#ededed", "#fee2e2", "#fecaca", "#fca5a5", "#f87171", "#ef4444",
    "#dc2626", "#b91c1c", "#991b1b", "#7f1d1d", "#6c1e1e",
)
theme = gr.themes.Soft(
    primary_hue=gr.themes.Color(*_PRIMARY_SHADES),
    neutral_hue="red",
)
# Build the chat UI, wire its events, then start the queued server.
# NOTE(review): the Row/Column nesting below was reconstructed from a copy of
# this file that had lost its indentation -- confirm against the deployed app.
with gr.Blocks(theme=theme) as demo:
    gr.HTML(title)
    gr.HTML('''<center><a href="https://huggingface.co/spaces/ysharma/RedPajama-Chat-3B?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
    chatbot = gr.Chatbot().style(height=500)

    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")

    # Advanced options - top_p, temperature, top_k
    with gr.Accordion("Advanced Options:", open=False):
        top_p = gr.Slider(minimum=0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p")
        top_k = gr.Slider(minimum=0.0, maximum=1000, value=1000, step=1, interactive=True, label="Top-k")
        temperature = gr.Slider(minimum=0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature")

    # Pressing Enter in the textbox and clicking Submit run the same two-step
    # chain: `user` records the message immediately (unqueued), then `chat`
    # streams the reply through the queue.
    submit_event = msg.submit(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False,
    ).then(
        fn=chat, inputs=[chatbot, top_p, top_k, temperature], outputs=[chatbot], queue=True,
    )
    submit_click_event = submit.click(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False,
    ).then(
        fn=chat, inputs=[chatbot, top_p, top_k, temperature], outputs=[chatbot], queue=True,
    )

    # Stop cancels whichever of the two chains is currently running.
    stop.click(
        fn=None, inputs=None, outputs=None,
        cancels=[submit_event, submit_click_event], queue=False,
    )
    # Clear resets the chatbot by writing None to it.
    clear.click(lambda: None, None, [chatbot], queue=False)

    gr.Examples(
        [
            ["Hello there! How are you doing?"],
            ["Can you explain to me briefly what is Python programming language?"],
            ["Explain the plot of Cinderella in a sentence."],
            ["What are some common mistakes to avoid when writing code?"],
            # Closing quote added: the original prompt opened “ but never closed it.
            ["Write a 500-word blog post on “Benefits of Artificial Intelligence”"],
        ],
        inputs=msg,
        label="Click on any example and press the 'Submit' button",
    )
    gr.HTML(description)

demo.queue(max_size=32, concurrency_count=2)
demo.launch(debug=True)