File size: 2,438 Bytes
de87148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f3b818
de87148
 
 
 
6f3b818
de87148
 
 
 
 
 
 
 
 
 
6f3b818
de87148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f3b818
de87148
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
import transformers
from torch import bfloat16
# from dotenv import load_dotenv  # if you wanted to adapt this for a repo that uses auth
from threading import Thread


#HF_AUTH = os.getenv('HF_AUTH')
# Hub repository that the config, weights and tokenizer are all loaded from.
model_id = "stabilityai/StableBeluga-7B"

# Quantization settings: 4-bit NF4 weights with double quantization,
# using bfloat16 as the compute dtype.
bnb_config = transformers.BitsAndBytesConfig(
    bnb_4bit_compute_dtype=bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# For a repo that needs auth, pass use_auth_token=HF_AUTH to each
# from_pretrained call below.
model_config = transformers.AutoConfig.from_pretrained(model_id)

# Load the quantized causal LM; device placement is delegated via device_map.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    config=model_config,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)


# Markdown text shown in the UI (passed to gr.Markdown when the page is built).
DESCRIPTION = """
# Stable Beluga 7B Chat
This is a streaming Chat Interface implementation of [StableBeluga-7B](https://huggingface.co/stabilityai/StableBeluga-7B). We'll use it to deploy a Discord bot that you can add to your server!

Sometimes the model doesn't appropriately hit its stop token. Feel free to hit "stop" and "retry" if this happens to you. Or PR a fix to stop the stream if the tokens for User: get hit or something.
"""

# System message that chat() hands to prompt_build() for every request.
# NOTE(review): the wording looks like it is missing an article ("a helpful AI");
# left untouched because this exact text is what the model receives.
system_prompt = "You are helpful AI."

def prompt_build(system_prompt, user_inp, hist):
    """Render a conversation into the ``### System/User/Assistant`` prompt text.

    Args:
        system_prompt: Text placed under the ``### System:`` header.
        user_inp: The newest user message; it goes under the final
            ``### User:`` header.
        hist: Iterable of earlier turns. Each turn is indexed as ``turn[0]``
            (user text) and ``turn[1]`` (assistant text).

    Returns:
        The full prompt string, ending with an open ``### Assistant:`` header
        so the model continues from there.
    """
    segments = [f"### System:\n{system_prompt}\n\n"]
    # One "User / Assistant" section per earlier turn, in order.
    segments.extend(
        f"### User:\n{turn[0]}\n\n### Assistant:\n{turn[1]}\n\n" for turn in hist
    )
    segments.append(f"### User:\n{user_inp}\n\n### Assistant:")
    return "".join(segments)

def chat(user_input, history):
    """Stream the model's reply to ``user_input`` for the chat interface.

    The prompt is built from the module-level ``system_prompt``, the earlier
    ``history`` and the new message. ``model.generate`` runs on a background
    thread and feeds a ``TextIteratorStreamer``; this generator yields the
    accumulated reply every time the streamer produces more text.

    Args:
        user_input: The latest user message.
        history: Earlier turns, forwarded unchanged to ``prompt_build``.

    Yields:
        The reply text generated so far (it grows with each chunk).

    Returns:
        The complete reply, as the generator's return value.
    """
    prompt = prompt_build(system_prompt, user_input, history)
    # The model is placed with device_map='auto', so send the inputs to
    # wherever the model actually lives instead of assuming "cuda".
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    # skip_prompt keeps the echoed prompt out of the stream; the timeout bounds
    # how long the loop below waits for the next piece of text.
    streamer = transformers.TextIteratorStreamer(
        tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=2048,
        do_sample=True,
        top_p=0.95,
        temperature=0.8,
        top_k=50
    )
    # generate() blocks until it is finished, so it runs on its own thread
    # while this generator consumes the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    # The stream is exhausted, so generation is wrapping up; wait for the
    # worker thread so it does not outlive this request.
    t.join()
    return model_output


# Page layout: the description text on top, the chat widget driven by chat() below.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    chatbot = gr.ChatInterface(chat)

# Turn on request queuing, then start serving the app.
demo.queue()
demo.launch()