File size: 5,421 Bytes
7cc686b
 
 
 
 
 
 
 
 
 
 
 
119226a
7cc686b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2f200e
5d23cab
 
7cc686b
 
 
 
 
 
5d23cab
 
7cc686b
 
 
 
 
 
 
 
 
d2f200e
 
 
7cc686b
 
 
 
 
 
 
 
 
d2f200e
7cc686b
 
 
d2f200e
7cc686b
 
 
 
f0812f2
d2f200e
23eb0dd
 
 
 
 
f0812f2
d2f200e
f0812f2
7cc686b
 
 
 
 
 
 
 
 
 
 
d2f200e
 
 
 
 
 
7cc686b
 
23eb0dd
7cc686b
23eb0dd
7cc686b
 
 
23eb0dd
 
 
 
 
 
 
 
 
f0812f2
7cc686b
 
d2f200e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import time
import numpy as np
from torch.nn import functional as F
import os
from threading import Thread

# init: load the chat checkpoint once at import time (tokenizer + fp16 weights on GPU 0).
_CHECKPOINT = "togethercomputer/RedPajama-INCITE-Chat-3B-v1"
tok = AutoTokenizer.from_pretrained(_CHECKPOINT)
m = AutoModelForCausalLM.from_pretrained(_CHECKPOINT, torch_dtype=torch.float16).to('cuda:0')

class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation once the newest token is a stop id.

    By default it stops on token ids 29 and 0. For this tokenizer these are
    presumably ``"<"`` (the first piece of the next ``<human>:`` turn marker)
    and the end-of-text token -- TODO confirm against the loaded tokenizer.

    Only the first sequence of the batch (``input_ids[0]``) is inspected, so
    this assumes generation runs with batch size 1.
    """

    def __init__(self, stop_ids=(29, 0)):
        """Remember which token ids should end generation.

        Args:
            stop_ids: iterable of integer token ids; defaults to ``(29, 0)``,
                the ids the original implementation hard-coded.
        """
        super().__init__()
        # Built once here instead of on every generation step.
        self.stop_ids = frozenset(int(token_id) for token_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop id."""
        return int(input_ids[0, -1]) in self.stop_ids

        
def user(message, history):
    """Append the user's message to the conversation as a new, unanswered turn.

    Args:
        message: text from the chat textbox.
        history: current Chatbot value, a list of ``[user_text, bot_text]``
            pairs. ``None`` (e.g. after the chat was cleared) is treated as an
            empty conversation instead of raising ``TypeError``.

    Returns:
        ``("", new_history)``. The empty string clears the textbox.
        ``new_history`` is a new list with ``[message, ""]`` appended; the
        empty bot slot is filled in later by the streaming step. The incoming
        ``history`` list is not mutated.
    """
    return "", (history or []) + [[message, ""]]



def chat(history, top_p, top_k, temperature):
    """Stream the model's reply to the last turn of ``history``.

    This is a generator used as a Gradio event handler. It runs ``m.generate``
    in a background thread, fills ``history[-1][1]`` incrementally, and yields
    the whole (mutated) ``history`` after each new chunk so the Chatbot
    re-renders.

    Args:
        history: list of ``[user_text, bot_text]`` pairs. The last pair is the
            turn to answer; its bot slot is expected to be ``""``.
        top_p: nucleus-sampling probability mass from the UI slider.
        top_k: top-k cutoff from the UI slider. Cast to ``int`` because the
            transformers top-k warper only accepts integers; 0 disables top-k.
        temperature: sampling temperature from the UI slider. Cast to
            ``float``. transformers rejects a temperature of 0 when sampling,
            so values <= 0 fall back to greedy decoding.

    Yields:
        The updated ``history`` list.
    """
    print(f"history is - {history}")
    # Ends generation as soon as a stop token is produced.
    stop = StopOnTokens()

    # Build the prompt in the "<human>: ... <bot>: ..." turn format. The last
    # turn's bot slot is empty, so the prompt ends with "\n<bot>:" and the
    # model continues from there.
    messages = "".join(
        "\n<human>:" + user_text + "\n<bot>:" + bot_text
        for user_text, bot_text in history
    )
    print(f"messages is - {messages}")

    # Place the inputs on whatever device the model was moved to.
    model_inputs = tok([messages], return_tensors="pt").to(m.device)
    # skip_prompt=True: stream only newly generated text. With False the prompt
    # is streamed first, and a fragment of it can briefly appear as the reply.
    streamer = TextIteratorStreamer(
        tok, timeout=10., skip_prompt=True, skip_special_tokens=True)

    temperature = float(temperature)
    if temperature > 0:
        sampling_kwargs = dict(
            do_sample=True,
            top_p=float(top_p),
            top_k=int(top_k),
            temperature=temperature,
        )
    else:
        # Temperature 0 would make generate() raise inside the worker thread,
        # leaving this loop to hang until the streamer times out.
        sampling_kwargs = dict(do_sample=False)

    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop]),
        **sampling_kwargs,
    )
    t = Thread(target=m.generate, kwargs=generate_kwargs)
    t.start()

    # Accumulates all generated text received so far.
    partial_text = ""
    for new_text in streamer:
        # A lone "<" is the stop token that opens the next "<human>:" turn.
        if new_text == '<':
            continue
        partial_text += new_text
        # The stop token can also arrive as the tail of the final chunk
        # (e.g. "done.<"); hide it from the displayed reply.
        reply = partial_text[:-1] if partial_text.endswith('<') else partial_text
        history[-1][1] = reply
        yield history
    return partial_text


title = """<h1 align="center">🔥RedPajama-INCITE-Chat-3B-v1</h1><br><h2 align="center">🏃‍♂️💨Streaming with Transformers & Gradio💪</h2>"""
description = """<br><br><h3 align="center">This is a RedPajama Chat model fine-tuned using data from Dolly 2.0 and Open Assistant over the RedPajama-INCITE-Base-3B-v1 base model.</h3>"""

# Eleven shades for the primary hue, lightest first, passed positionally to
# gr.themes.Color.
_PRIMARY_SHADES = (
    "#ededed", "#fee2e2", "#fecaca", "#fca5a5", "#f87171", "#ef4444",
    "#dc2626", "#b91c1c", "#991b1b", "#7f1d1d", "#6c1e1e",
)
theme = gr.themes.Soft(
    primary_hue=gr.themes.Color(*_PRIMARY_SHADES),
    neutral_hue="red",
)


with gr.Blocks(theme=theme) as demo:
    gr.HTML(title)
    gr.HTML('''<center><a href="https://huggingface.co/spaces/ysharma/RedPajama-Chat-3B?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
    chatbot = gr.Chatbot().style(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(label="Chat Message Box", placeholder="Chat Message Box",
                             show_label=False).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")

    # Advanced sampling options forwarded to chat(): top_p, top_k, temperature.
    with gr.Accordion("Advanced Options:", open=False):
        top_p = gr.Slider(minimum=0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p")
        top_k = gr.Slider(minimum=0.0, maximum=1000, value=1000, step=1, interactive=True, label="Top-k")
        temperature = gr.Slider(minimum=0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature")

    # Enter and the Submit button run the same two steps. user() appends the
    # message and clears the textbox (unqueued, so it is instant). chat() then
    # streams the reply into the Chatbot through the queue.
    submit_event = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=chat, inputs=[chatbot, top_p, top_k, temperature], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=chat, inputs=[chatbot, top_p, top_k, temperature], outputs=[chatbot], queue=True)
    # Stop cancels the in-flight streaming event(s).
    stop.click(fn=None, inputs=None, outputs=None, cancels=[
               submit_event, submit_click_event], queue=False)
    # Clear empties the Chatbot. It must also cancel in-flight streaming;
    # otherwise chat() keeps yielding the old history back into the cleared
    # Chatbot. (Cancelling stops the UI stream; the generate() worker thread
    # inside chat() is not signalled.)
    clear.click(lambda: None, None, [chatbot], queue=False,
                cancels=[submit_event, submit_click_event])

    gr.Examples([
        ["Hello there! How are you doing?"],
        ["Can you explain to me briefly what is Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["What are some common mistakes to avoid when writing code?"],
        ["Write a 500-word blog post on “Benefits of Artificial Intelligence”"]
    ], inputs=msg, label= "Click on any example and press the 'Submit' button"
      )
    gr.HTML(description)

# Route events through Gradio's queue: at most 32 waiting requests, 2 served at a time.
demo.queue(concurrency_count=2, max_size=32)
# debug=True blocks the main thread and prints errors to the console.
demo.launch(debug=True)