bofenghuang committed on
Commit a4b1443 • 0 Parent(s):

Initial commit

Files changed (4)
  1. .gitattributes +34 -0
  2. README.md +10 -0
  3. app.py +356 -0
  4. requirements.txt +8 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: Vigogne-Chat
+ emoji: 🦙
+ colorFrom: purple
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 3.27.0
+ app_file: app.py
+ pinned: false
+ ---
app.py ADDED
@@ -0,0 +1,356 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2023 Bofeng Huang
+
+ """
+ Modified from: https://huggingface.co/spaces/mosaicml/mpt-7b-chat/raw/main/app.py
+
+ Usage:
+ CUDA_VISIBLE_DEVICES=0
+
+ python vigogne/demo/demo_chat.py \
+     --base_model_name_or_path huggyllama/llama-7b \
+     --lora_model_name_or_path bofenghuang/vigogne-chat-7b
+ """
+
+ # import datetime
+ import logging
+ import os
+ import re
+ from threading import Event, Thread
+ from typing import List, Optional
+
+
+ # from uuid import uuid4
+
+ import fire
+ import json
+ import gradio as gr
+
+ # import requests
+ import torch
+ from peft import PeftModel
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     GenerationConfig,
+     StoppingCriteriaList,
+     TextIteratorStreamer,
+ )
+
+ from vigogne.constants import ASSISTANT, USER
+ from vigogne.preprocess import generate_inference_chat_prompt
+ from vigogne.inference.inference_utils import StopWordsCriteria
+
+ logging.basicConfig(
+     format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
+     datefmt="%Y-%m-%dT%H:%M:%SZ",
+ )
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.DEBUG)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ try:
+     if torch.backends.mps.is_available():
+         device = "mps"
+ except:
+     pass
+
+ logger.info(f"Model will be loaded on device `{device}`")
+
+
+ # def log_conversation(conversation_id, history, messages, generate_kwargs):
+ #     logging_url = os.getenv("LOGGING_URL", None)
+ #     if logging_url is None:
+ #         return
+
+ #     timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
+
+ #     data = {
+ #         "conversation_id": conversation_id,
+ #         "timestamp": timestamp,
+ #         "history": history,
+ #         "messages": messages,
+ #         "generate_kwargs": generate_kwargs,
+ #     }
+
+ #     try:
+ #         requests.post(logging_url, json=data)
+ #     except requests.exceptions.RequestException as e:
+ #         print(f"Error logging conversation: {e}")
+
+
+ def user(message, history):
+     # Append the user's message to the conversation history
+     return "", history + [[message, ""]]
+
+
+ # def get_uuid():
+ #     return str(uuid4())
+
+
+ def main(
+     base_model_name_or_path: str = "huggyllama/llama-7b",
+     lora_model_name_or_path: str = "bofenghuang/vigogne-chat-7b",
+     load_8bit: bool = False,
+     server_name: Optional[str] = "0.0.0.0",
+     server_port: Optional[str] = None,
+     share: bool = False,
+ ):
+     tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, padding_side="right", use_fast=False)
+
+     if device == "cuda":
+         model = AutoModelForCausalLM.from_pretrained(
+             base_model_name_or_path,
+             load_in_8bit=load_8bit,
+             torch_dtype=torch.float16,
+             device_map="auto",
+         )
+         model = PeftModel.from_pretrained(
+             model,
+             lora_model_name_or_path,
+             torch_dtype=torch.float16,
+         )
+     elif device == "mps":
+         model = AutoModelForCausalLM.from_pretrained(
+             base_model_name_or_path,
+             device_map={"": device},
+             torch_dtype=torch.float16,
+         )
+         model = PeftModel.from_pretrained(
+             model,
+             lora_model_name_or_path,
+             device_map={"": device},
+             torch_dtype=torch.float16,
+         )
+     else:
+         model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, device_map={"": device}, low_cpu_mem_usage=True)
+         model = PeftModel.from_pretrained(
+             model,
+             lora_model_name_or_path,
+             device_map={"": device},
+         )
+
+     if not load_8bit and device != "cpu":
+         model.half()  # seems to fix bugs for some users.
+
+     model.eval()
+
+     # NB: stop generation as soon as the model emits the next role token,
+     # and strip any such trailing token from the streamed output (see bot() below)
+     stop_words = [f"<|{ASSISTANT}|>", f"<|{USER}|>"]
+     stop_words_criteria = StopWordsCriteria(stop_words=stop_words, tokenizer=tokenizer)
+     pattern_trailing_stop_words = re.compile(rf'(?:{"|".join([re.escape(stop_word) for stop_word in stop_words])})\W*$')
+
+     def bot(history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, conversation_id=None):
+         # logger.info(f"History: {json.dumps(history, indent=4, ensure_ascii=False)}")
+
+         # Construct the input message string for the model by concatenating the current system message and conversation history
+         messages = generate_inference_chat_prompt(history, tokenizer)
+         logger.info(messages)
+         assert messages is not None, "User input is too long!"
+
+         # Tokenize the messages string
+         input_ids = tokenizer(messages, return_tensors="pt")["input_ids"].to(device)
+         streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+         generate_kwargs = dict(
+             input_ids=input_ids,
+             generation_config=GenerationConfig(
+                 temperature=temperature,
+                 do_sample=temperature > 0.0,
+                 top_p=top_p,
+                 top_k=top_k,
+                 repetition_penalty=repetition_penalty,
+                 max_new_tokens=max_new_tokens,
+             ),
+             streamer=streamer,
+             stopping_criteria=StoppingCriteriaList([stop_words_criteria]),
+         )
+
+         # stream_complete = Event()
+
+         def generate_and_signal_complete():
+             model.generate(**generate_kwargs)
+             # stream_complete.set()
+
+         # def log_after_stream_complete():
+         #     stream_complete.wait()
+         #     log_conversation(
+         #         conversation_id,
+         #         history,
+         #         messages,
+         #         {
+         #             "top_k": top_k,
+         #             "top_p": top_p,
+         #             "temperature": temperature,
+         #             "repetition_penalty": repetition_penalty,
+         #         },
+         #     )
+
+         t1 = Thread(target=generate_and_signal_complete)
+         t1.start()
+
+         # t2 = Thread(target=log_after_stream_complete)
+         # t2.start()
+
+         # Initialize an empty string to store the generated text
+         partial_text = ""
+         for new_text in streamer:
+             # NB: drop any trailing stop word (role token) leaked into the streamed chunk
+             new_text = pattern_trailing_stop_words.sub("", new_text)
+
+             partial_text += new_text
+             history[-1][1] = partial_text
+             yield history
+
+         logger.info(f"Response: {history[-1][1]}")
+
+     with gr.Blocks(
+         theme=gr.themes.Soft(),
+         css=".disclaimer {font-variant-caps: all-small-caps;}",
+     ) as demo:
+         # conversation_id = gr.State(get_uuid)
+         gr.Markdown(
+             """<h1><center>🦙 Vigogne Chat</center></h1>
+
+ This is a demo of [Vigogne-Chat-7B](https://huggingface.co/bofenghuang/vigogne-chat-7b), a [LLaMA-7B](https://github.com/facebookresearch/llama) model fine-tuned to conduct French 🇫🇷 dialogues between a user and an AI assistant.
+
+ For more information, please visit the [GitHub repo](https://github.com/bofenghuang/vigogne) of the Vigogne project.
+ """
+         )
+         chatbot = gr.Chatbot().style(height=500)
+         with gr.Row():
+             with gr.Column():
+                 msg = gr.Textbox(
+                     label="Chat Message Box",
+                     placeholder="Chat Message Box",
+                     show_label=False,
+                 ).style(container=False)
+             with gr.Column():
+                 with gr.Row():
+                     submit = gr.Button("Submit")
+                     stop = gr.Button("Stop")
+                     clear = gr.Button("Clear")
+         with gr.Row():
+             with gr.Accordion("Advanced Options:", open=False):
+                 with gr.Row():
+                     with gr.Column():
+                         with gr.Row():
+                             max_new_tokens = gr.Slider(
+                                 label="Max New Tokens",
+                                 value=512,
+                                 minimum=0,
+                                 maximum=1024,
+                                 step=1,
+                                 interactive=True,
+                                 info="The maximum number of new tokens to generate.",
+                             )
+                     with gr.Column():
+                         with gr.Row():
+                             temperature = gr.Slider(
+                                 label="Temperature",
+                                 value=0.1,
+                                 minimum=0.0,
+                                 maximum=1.0,
+                                 step=0.1,
+                                 interactive=True,
+                                 info="Higher values produce more diverse outputs.",
+                             )
+                     with gr.Column():
+                         with gr.Row():
+                             top_p = gr.Slider(
+                                 label="Top-p (nucleus sampling)",
+                                 value=1.0,
+                                 minimum=0.0,
+                                 maximum=1,
+                                 step=0.01,
+                                 interactive=True,
+                                 info=(
+                                     "Sample from the smallest possible set of tokens whose cumulative probability "
+                                     "exceeds top_p. Set to 1 to disable and sample from all tokens."
+                                 ),
+                             )
+                     with gr.Column():
+                         with gr.Row():
+                             top_k = gr.Slider(
+                                 label="Top-k",
+                                 value=0,
+                                 minimum=0.0,
+                                 maximum=200,
+                                 step=1,
+                                 interactive=True,
+                                 info="Sample from a shortlist of top-k tokens; set to 0 to disable and sample from all tokens.",
+                             )
+                     with gr.Column():
+                         with gr.Row():
+                             repetition_penalty = gr.Slider(
+                                 label="Repetition Penalty",
+                                 value=1.0,
+                                 minimum=1.0,
+                                 maximum=2.0,
+                                 step=0.1,
+                                 interactive=True,
+                                 info="Penalize repetition; 1.0 to disable.",
+                             )
+         with gr.Row():
+             gr.Markdown(
+                 "Disclaimer: Vigogne is still under development, and there are many limitations that have to be addressed. Please note that the model may generate harmful or biased content, incorrect information, or generally unhelpful answers.",
+                 elem_classes=["disclaimer"],
+             )
+         with gr.Row():
+             gr.Markdown(
+                 "Acknowledgements: This demo is built on top of [MPT-7B-Chat](https://huggingface.co/spaces/mosaicml/mpt-7b-chat). Thanks for their contribution!",
+                 elem_classes=["disclaimer"],
+             )
+
+         submit_event = msg.submit(
+             fn=user,
+             inputs=[msg, chatbot],
+             outputs=[msg, chatbot],
+             queue=False,
+         ).then(
+             fn=bot,
+             inputs=[
+                 chatbot,
+                 max_new_tokens,
+                 temperature,
+                 top_p,
+                 top_k,
+                 repetition_penalty,
+                 # conversation_id,
+             ],
+             outputs=chatbot,
+             queue=True,
+         )
+         submit_click_event = submit.click(
+             fn=user,
+             inputs=[msg, chatbot],
+             outputs=[msg, chatbot],
+             queue=False,
+         ).then(
+             fn=bot,
+             inputs=[
+                 chatbot,
+                 max_new_tokens,
+                 temperature,
+                 top_p,
+                 top_k,
+                 repetition_penalty,
+                 # conversation_id,
+             ],
+             outputs=chatbot,
+             queue=True,
+         )
+         stop.click(
+             fn=None,
+             inputs=None,
+             outputs=None,
+             cancels=[submit_event, submit_click_event],
+             queue=False,
+         )
+         clear.click(lambda: None, None, chatbot, queue=False)
+
+     demo.queue(max_size=128, concurrency_count=2)
+     demo.launch(enable_queue=True, share=share, server_name=server_name, server_port=server_port)
+
+ fire.Fire(main)  # parse the CLI flags documented in the module docstring
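
Note on the stopping logic: generation is cut off as soon as the model emits the next role token (built from the ASSISTANT and USER constants), via the StopWordsCriteria imported from vigogne.inference.inference_utils, and any role token that still leaks into the streamed text is stripped by pattern_trailing_stop_words. That class is not part of this commit, so the snippet below is only a minimal sketch of such a criterion written against the standard transformers StoppingCriteria interface; the class name StopWordsCriteriaSketch and the 10-token decode window are assumptions, not the actual implementation.

```python
# Minimal sketch of a stop-words stopping criterion (assumed behaviour; the
# real StopWordsCriteria lives in vigogne.inference.inference_utils).
import torch
from transformers import StoppingCriteria


class StopWordsCriteriaSketch(StoppingCriteria):
    def __init__(self, stop_words, tokenizer):
        self.stop_words = stop_words
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode a short tail of the generated sequence and stop as soon as
        # any stop word (e.g. the next role token) appears in it.
        tail = self.tokenizer.decode(input_ids[0, -10:])
        return any(stop_word in tail for stop_word in self.stop_words)
```

Such a criterion is wired into model.generate through StoppingCriteriaList([stop_words_criteria]), exactly as done in bot() above.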
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ datasets
+ loralib
+ sentencepiece
+ git+https://github.com/huggingface/transformers.git
+ accelerate
+ bitsandbytes
+ git+https://github.com/huggingface/peft.git
+ gradio