strauss23 and jon-tow committed
Commit 144b8a2 · 0 Parent(s)

Duplicate from CarperAI/StableVicuna


Co-authored-by: Jonathan Tow <jon-tow@users.noreply.huggingface.co>

Files changed (4)
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +138 -0
  4. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
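
This is the stock `.gitattributes` the Hub generates for new repos: any binary artifact matching these patterns (checkpoints, archives, serialized tensors) is routed through Git LFS instead of being committed into the Git history.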
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: StableVicuna
+ emoji: 🦙
+ colorFrom: blue
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.27.0
+ app_file: app.py
+ pinned: false
+ license: cc-by-nc-4.0
+ duplicated_from: CarperAI/StableVicuna
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,138 @@
+ import os
+ import gc
+ from string import Template
+ from threading import Thread
+
+ import torch
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BatchEncoding, TextIteratorStreamer
+
+
+ auth_token = os.environ.get("HUGGINGFACE_TOKEN")
+ tokenizer = AutoTokenizer.from_pretrained(
+     "CarperAI/stable-vicuna-13b-fp16",
+     use_auth_token=auth_token if auth_token else True,
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+     "CarperAI/stable-vicuna-13b-fp16",
+     torch_dtype=torch.float16,
+     low_cpu_mem_usage=True,
+     device_map="auto",
+     use_auth_token=auth_token if auth_token else True,
+ )
+ model.eval()
+
+
+ max_context_length = model.config.max_position_embeddings
+ max_new_tokens = 768
+
+
+ prompt_template = Template("""\
+ ### Human: $human
+ ### Assistant: $bot\
+ """)
+
+
+ system_prompt = "### Assistant: I am StableVicuna, a large language model created by CarperAI. I am here to chat!"
+ system_prompt_tokens = tokenizer([f"{system_prompt}\n\n"], return_tensors="pt")
+ max_sys_tokens = system_prompt_tokens['input_ids'].size(-1)
+
+
+ def bot(history):
+     history = history or []
+
+     # Inject prompt formatting into the history
+     prompt_history = []
+     for human, bot in history:
+         if bot is not None:
+             bot = bot.replace("<br>", "\n")
+             bot = bot.rstrip()
+         prompt_history.append(
+             prompt_template.substitute(
+                 human=human, bot=bot if bot is not None else "")
+         )
+
+     msg_tokens = tokenizer(
+         "\n\n".join(prompt_history).strip(),
+         return_tensors="pt",
+         add_special_tokens=False  # Use <BOS> from the system prompt
+     )
+
+     # Take only the most recent context up to the max context length and prepend the
+     # system prompt with the messages
+     max_tokens = -max_context_length + max_new_tokens + max_sys_tokens
+     inputs = BatchEncoding({
+         k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
+         for k in msg_tokens
+     }).to('cuda')
+     # Remove `token_type_ids` b/c it's not yet supported for LLaMA `transformers` models
+     if inputs.get("token_type_ids", None) is not None:
+         inputs.pop("token_type_ids")
+
+     streamer = TextIteratorStreamer(
+         tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+     )
+     generate_kwargs = dict(
+         inputs,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=1.0,
+         temperature=1.0,
+     )
+     thread = Thread(target=model.generate, kwargs=generate_kwargs)
+     thread.start()
+
+     partial_text = ""
+     for new_text in streamer:
+         # Process out the prompt separator
+         new_text = new_text.replace("<br>", "\n")
+         if "###" in new_text:
+             new_text = new_text.split("###")[0]
+             partial_text += new_text.strip()
+             history[-1][1] = partial_text
+             yield history  # flush the final chunk; break would otherwise skip the loop-level yield
+             break
+         else:
+             # Filter empty trailing new lines
+             if new_text == "\n":
+                 new_text = new_text.strip()
+             partial_text += new_text
+             history[-1][1] = partial_text
+         yield history
+     return partial_text
+
+
+ def user(user_message, history):
+     return "", history + [[user_message, None]]
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# StableVicuna by CarperAI")
+     gr.HTML("<a href='https://huggingface.co/CarperAI/stable-vicuna-13b-delta'><code>CarperAI/stable-vicuna-13b-delta</code></a>")
+     gr.HTML('''<center><a href="https://huggingface.co/spaces/CarperAI/StableVicuna?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
+
+     chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
+     state = gr.State([])
+     with gr.Row():
+         with gr.Column():
+             msg = gr.Textbox(
+                 label="Send a message",
+                 placeholder="Send a message",
+                 show_label=False
+             ).style(container=False)
+         with gr.Column():
+             with gr.Row():
+                 submit = gr.Button("Send")
+                 stop = gr.Button("Stop")
+                 clear = gr.Button("Clear History")
+
+     submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
+         fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
+     submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
+         fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
+
+     stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event, submit_click_event], queue=False)
+     clear.click(lambda: None, None, [chatbot], queue=True)
+
+ demo.queue(max_size=32)
+ demo.launch()
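
A note on the truncation arithmetic in `bot`: `max_tokens` is negative, so `msg_tokens[k][:, max_tokens:]` keeps only the most recent history tokens, sized so that system prompt + retained history + generation budget exactly fill the context window. A minimal sketch with toy numbers (the real `max_sys_tokens` depends on the tokenizer; the values below are assumptions for illustration):

```python
import torch

max_context_length = 2048  # e.g. model.config.max_position_embeddings for a LLaMA-13B model
max_new_tokens = 768       # generation budget reserved by the app
max_sys_tokens = 32        # hypothetical system-prompt length in tokens

# Negative slice start: -1248 here, i.e. keep only the last 1248 history tokens
max_tokens = -max_context_length + max_new_tokens + max_sys_tokens

msg_ids = torch.arange(5000).unsqueeze(0)  # stand-in for a long chat history
kept = msg_ids[:, max_tokens:]
assert kept.size(-1) + max_sys_tokens + max_new_tokens == max_context_length
```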
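The streaming path works the same way in isolation: `model.generate` blocks, so it runs on a worker thread and feeds a `TextIteratorStreamer`, which the main thread drains piece by piece (in the app, each piece becomes a `yield history` to the Chatbot). A minimal sketch; the small `gpt2` checkpoint is an assumption for illustration, not part of this Space:

```python
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
inputs = tok("### Human: Hi\n### Assistant:", return_tensors="pt")

# generate() runs on a background thread; the streamer is an iterator the
# main thread can consume token-by-token as text is decoded.
Thread(target=lm.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=24)).start()
for piece in streamer:
    print(piece, end="", flush=True)
```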
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ accelerate
+ torch
+ bitsandbytes
+ transformers>=4.28.0,<4.29.0
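
The `transformers` pin is load-bearing: `app.py` depends on LLaMA model support and `TextIteratorStreamer`, both of which landed in the 4.28 release, while the `<4.29.0` upper bound guards against API drift.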