LoneStriker
commited on
Commit
•
8563ead
1
Parent(s):
0322f34
Upload folder using huggingface_hub
Browse files- .gitattributes +9 -35
- .gitignore +1 -0
- README.md +70 -0
- Vistral-7B-ChatML-Q3_K_L.gguf +3 -0
- Vistral-7B-ChatML-Q3_K_M.gguf +3 -0
- Vistral-7B-ChatML-Q3_K_S.gguf +3 -0
- Vistral-7B-ChatML-Q4_K_M.gguf +3 -0
- Vistral-7B-ChatML-Q4_K_S.gguf +3 -0
- Vistral-7B-ChatML-Q5_K_M.gguf +3 -0
- Vistral-7B-ChatML-Q5_K_S.gguf +3 -0
- Vistral-7B-ChatML-Q6_K.gguf +3 -0
- Vistral-7B-ChatML-Q8_0.gguf +3 -0
- finetune.py +133 -0
- run.py +64 -0
.gitattributes
CHANGED
@@ -1,35 +1,9 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
Vistral-7B-ChatML-Q3_K_L.gguf filter=lfs diff=lfs merge=lfs -text
|
2 |
+
Vistral-7B-ChatML-Q3_K_M.gguf filter=lfs diff=lfs merge=lfs -text
|
3 |
+
Vistral-7B-ChatML-Q3_K_S.gguf filter=lfs diff=lfs merge=lfs -text
|
4 |
+
Vistral-7B-ChatML-Q4_K_M.gguf filter=lfs diff=lfs merge=lfs -text
|
5 |
+
Vistral-7B-ChatML-Q4_K_S.gguf filter=lfs diff=lfs merge=lfs -text
|
6 |
+
Vistral-7B-ChatML-Q5_K_M.gguf filter=lfs diff=lfs merge=lfs -text
|
7 |
+
Vistral-7B-ChatML-Q5_K_S.gguf filter=lfs diff=lfs merge=lfs -text
|
8 |
+
Vistral-7B-ChatML-Q6_K.gguf filter=lfs diff=lfs merge=lfs -text
|
9 |
+
Vistral-7B-ChatML-Q8_0.gguf filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.ipynb_checkpoints
|
README.md
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- vi
|
4 |
+
library_name: transformers
|
5 |
+
tags:
|
6 |
+
- LLMs
|
7 |
+
- NLP
|
8 |
+
- Vietnamese
|
9 |
+
license: mit
|
10 |
+
---
|
11 |
+
|
12 |
+
## Model Description
|
13 |
+
|
14 |
+
This model is finetuned from [Viet-Mistral/Vistral-7B-Chat](https://huggingface.co/Viet-Mistral/Vistral-7B-Chat). The dataset is taken from [bkai-foundation-models/vi-self-chat-sharegpt-format](https://huggingface.co/datasets/bkai-foundation-models/vi-self-chat-sharegpt-format)
|
15 |
+
|
16 |
+
This is a **low rank** finetune to add support for chatml template. While the template does not affect model performance, it would be nice to support chatml since most of models based on Mistral already using it.
|
17 |
+
|
18 |
+
The format looks like this:
|
19 |
+
|
20 |
+
```
|
21 |
+
<|im_start|>system
|
22 |
+
Provide some context and/or instructions to the model.
|
23 |
+
<|im_end|>
|
24 |
+
<|im_start|>user
|
25 |
+
The user’s message goes here
|
26 |
+
<|im_end|>
|
27 |
+
<|im_start|>assistant
|
28 |
+
```
|
29 |
+
|
30 |
+
## How to use it
|
31 |
+
|
32 |
+
The recommended way is to use the GGUF `vistral-7b-chatml-Q4_K_M.gguf` file included in this repository. Run it via llama.cpp (remember to pass `-cml` argument to use chatml template)
|
33 |
+
|
34 |
+
```bash
|
35 |
+
./main -m vistral-7b-chatml-Q4_K_M.gguf -p "Bạn là một trợ lí Tiếng Việt nhiệt tình và trung thực." -cml
|
36 |
+
```
|
37 |
+
|
38 |
+
Additionally, you can run the `python3 run.py` inside this repository to try the model using transformers library. This it not the recommended way since you may need to change some params inside in order to make it work.
|
39 |
+
|
40 |
+
This is an example of a conversation using llama.cpp:
|
41 |
+
|
42 |
+
- <|im_start|>system
|
43 |
+
- Bạn là một trợ lí Tiếng Việt nhiệt tình và trung thực.<|im_end|>
|
44 |
+
- <|im_start|>user
|
45 |
+
- > xin chào
|
46 |
+
- Tôi là một trợ lý AI được tạo ra để hỗ trợ con người. Bạn có câu hỏi gì về chủ đề nào đó?<|im_end|>
|
47 |
+
- > trợ lý AI là gì? bạn giải thích được không?
|
48 |
+
- Trợ lý AI là một chương trình máy tính sử dụng trí tuệ nhân tạo và học máy để hỗ trợ con người trong nhiều công việc khác nhau. Nó có thể hiểu và xử lý ngôn ngữ tự nhiên, nhận biết hình ảnh và âm thanh cũng như đưa ra các đề xuất hay câu trả lời phù hợp cho các yêu cầu của bạn.<|im_end|>
|
49 |
+
- > ồ, cảm ơn, vậy bạn có thể làm gì?
|
50 |
+
- Tôi được thiết kế để giúp đỡ con người trong nhiều công việc như giải đáp các câu hỏi, trả lời những vấn đề liên quan đến cuộc sống hàng ngày của họ. Tôi cũng hỗ trợ nghiên cứu và phân tích dữ liệu để đưa ra dự đoán hay tư vấn cho doanh nghiệp hoặc cá nhân.<|im_end|>
|
51 |
+
- > tức là sao?
|
52 |
+
- Ví dụ, nếu bạn đang thắc mắc về một chủ đề nào đó và không biết phải bắt đầu tìm kiếm ở đâu hay làm thế nào để hiểu rõ hơn về nó thì tôi có thể giúp bạn. Tôi cũng có khả năng đọc và phân tích dữ liệu từ nhiều nguồn khác nhau như các tài liệu nghiên cứu khoa học hoặc thông tin trên mạng xã hội, sau đó đưa ra những giải thích chi tiết cho vấn đề của bạn.<|im_end|>
|
53 |
+
|
54 |
+
## Training
|
55 |
+
|
56 |
+
You can also look at the training code in the `finetune.py` file.
|
57 |
+
|
58 |
+
For tokenizer, I changed these things:
|
59 |
+
- Token `[INST]` to `<|im_start|>`, make it become special token
|
60 |
+
- Token `[/INST]` to `<|im_end|>`, make it become special token
|
61 |
+
- Change to `eos_token` to `<|im_end|>`
|
62 |
+
- Update `chat_template` to chatml, taken from [this example](https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json#L52)
|
63 |
+
|
64 |
+
Additionally, there is a checkpoint file in my repository if you want to merge the LORA yourself.
|
65 |
+
|
66 |
+
## More information
|
67 |
+
|
68 |
+
Disclaimer: I'm not expert in machine learning, my background is from cybersecurity so the making of this model is a "hobby" to me. Training is done using a VPS on Google Cloud, I paid with my own money.
|
69 |
+
|
70 |
+
If you want to discuss, feel free to contact me at `contact at ngxson dot com` - [ngxson.com](https://ngxson.com)
|
Vistral-7B-ChatML-Q3_K_L.gguf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1ab6a751d538ec035bceaf88e2c68d9df217d7d8836747ef47f3f1bf2d8da88f
|
3 |
+
size 3854781984
|
Vistral-7B-ChatML-Q3_K_M.gguf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f57fff05385a79ddce9154f4589973cbb437ed0ff17dea8b92f1ac1f174978c3
|
3 |
+
size 3551743520
|
Vistral-7B-ChatML-Q3_K_S.gguf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b61fd10da1e28641b20cd78ab7e25ddf5b93a2259ee1895027988293460b9cb0
|
3 |
+
size 3197324832
|
Vistral-7B-ChatML-Q4_K_M.gguf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cdca9c4f26ae61dbd453524bdd71f8f0c03b1a71ba8d2caee219efee146dc010
|
3 |
+
size 4404661312
|
Vistral-7B-ChatML-Q4_K_S.gguf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:306fa0ffd9ae67b1c57a04c9335f011cf8d0f6929f5b3798f119a6d16bcbed94
|
3 |
+
size 4176596032
|
Vistral-7B-ChatML-Q5_K_M.gguf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21605ad7f2077e91d346779d36eeaf854adc13924b7ac822b9435b8ddb028b8b
|
3 |
+
size 5170892352
|
Vistral-7B-ChatML-Q5_K_S.gguf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fff2be0f87c5a8a419d78dd0d8a4ca3fb5f06b3a40b0a5fccd7d068507ecaa96
|
3 |
+
size 5037198912
|
Vistral-7B-ChatML-Q6_K.gguf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9f0d076da5edddfa8e4ec3d27c15245ae2bf24b9256eceb18e383142e5339d31
|
3 |
+
size 5985012832
|
Vistral-7B-ChatML-Q8_0.gguf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:560de40a3daddc551331d1fdf21b6c94b0ed31877d3ce360c6feb839186e02e6
|
3 |
+
size 7751441440
|
finetune.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline, logging, TextStreamer
|
3 |
+
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
|
4 |
+
import os, torch, wandb, platform, warnings
|
5 |
+
from datasets import load_dataset
|
6 |
+
from trl import SFTTrainer
|
7 |
+
|
8 |
+
hf_token = ''
|
9 |
+
wnb_token = ''
|
10 |
+
wnb_name = 'vistral-chatml'
|
11 |
+
MODEL = 'Viet-Mistral/Vistral-7B-Chat'
|
12 |
+
resume_from_checkpoint = False
|
13 |
+
output_dir = 'vistral-chatml'
|
14 |
+
tokenizer_path = '.'
|
15 |
+
|
16 |
+
#######################################################
|
17 |
+
## DATASET
|
18 |
+
|
19 |
+
|
20 |
+
from datasets import load_dataset
|
21 |
+
|
22 |
+
|
23 |
+
def generate_system_prompt(i):
|
24 |
+
system_prompt = "Bạn là một trợ lí Tiếng Việt nhiệt tình và trung thực. Hãy luôn trả lời một cách hữu ích nhất có thể, đồng thời giữ an toàn."
|
25 |
+
if i % 2 == 0:
|
26 |
+
system_prompt += "\nCâu trả lời của bạn không nên chứa bất kỳ nội dung gây hại, phân biệt chủng tộc, phân biệt giới tính, độc hại, nguy hiểm hoặc bất hợp pháp nào. Hãy đảm bảo rằng các câu trả lời của bạn không có thiên kiến xã hội và mang tính tích cực."
|
27 |
+
if i % 5 == 0:
|
28 |
+
system_prompt += "\nNếu một câu hỏi không có ý nghĩa hoặc không hợp lý về mặt thông tin, hãy giải thích tại sao thay vì trả lời một điều gì đó không chính xác. Nếu bạn không biết câu trả lời cho một câu hỏi, hãy trẳ lời là bạn không biết và vui lòng không chia sẻ thông tin sai lệch."
|
29 |
+
return system_prompt
|
30 |
+
|
31 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
32 |
+
|
33 |
+
def tokenize_chat(input, i):
|
34 |
+
print(generate_system_prompt(i))
|
35 |
+
conversation = [{'role': 'system', 'content': generate_system_prompt(i)}]
|
36 |
+
for msg in input['conversations']:
|
37 |
+
output = {'role': 'user', 'content': msg['value']}
|
38 |
+
if msg['from'] == 'gpt':
|
39 |
+
output['role'] = 'assistant'
|
40 |
+
conversation.append(output)
|
41 |
+
formatted = tokenizer.apply_chat_template(conversation, tokenize=False)
|
42 |
+
return tokenizer(formatted)
|
43 |
+
|
44 |
+
sharegpt_dataset = load_dataset('bkai-foundation-models/vi-self-chat-sharegpt-format')
|
45 |
+
train_data = sharegpt_dataset['train'].shuffle(seed=42)\
|
46 |
+
.select(range(800))\
|
47 |
+
.map(lambda x, i: tokenize_chat(x, i), remove_columns=["conversations"], with_indices=True)
|
48 |
+
|
49 |
+
|
50 |
+
#######################################################
|
51 |
+
## SETUP
|
52 |
+
|
53 |
+
wandb.login(key=wnb_token)
|
54 |
+
wandb.init(name=wnb_name)
|
55 |
+
# use custom tokenizer instead of one comes from the model
|
56 |
+
#tokenizer = AutoTokenizer.from_pretrained(
|
57 |
+
# MODEL,
|
58 |
+
# add_eos_token=False,
|
59 |
+
# add_bos_token=False,
|
60 |
+
# token=hf_token,
|
61 |
+
#)
|
62 |
+
bnb_config = BitsAndBytesConfig(
|
63 |
+
load_in_4bit=True,
|
64 |
+
bnb_4bit_quant_type="nf4",
|
65 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
66 |
+
bnb_4bit_use_double_quant=True,
|
67 |
+
)
|
68 |
+
model = AutoModelForCausalLM.from_pretrained(
|
69 |
+
MODEL,
|
70 |
+
device_map="auto",
|
71 |
+
token=hf_token,
|
72 |
+
quantization_config=bnb_config,
|
73 |
+
trust_remote_code=True,
|
74 |
+
)
|
75 |
+
|
76 |
+
|
77 |
+
#######################################################
|
78 |
+
## LORA CONFIG
|
79 |
+
|
80 |
+
model.gradient_checkpointing_enable()
|
81 |
+
model = prepare_model_for_kbit_training(model)
|
82 |
+
peft_config = LoraConfig(
|
83 |
+
r=8,
|
84 |
+
lora_alpha=16,
|
85 |
+
target_modules=[
|
86 |
+
"q_proj",
|
87 |
+
"k_proj",
|
88 |
+
"v_proj",
|
89 |
+
"o_proj",
|
90 |
+
"gate_proj",
|
91 |
+
"up_proj",
|
92 |
+
"down_proj",
|
93 |
+
"lm_head",
|
94 |
+
],
|
95 |
+
bias="none",
|
96 |
+
lora_dropout=0.05, # Conventional
|
97 |
+
task_type="CAUSAL_LM",
|
98 |
+
)
|
99 |
+
model = get_peft_model(model, peft_config)
|
100 |
+
model.print_trainable_parameters()
|
101 |
+
|
102 |
+
from accelerate import Accelerator
|
103 |
+
accelerator = Accelerator()
|
104 |
+
model = accelerator.prepare_model(model)
|
105 |
+
|
106 |
+
|
107 |
+
#######################################################
|
108 |
+
## TRAIN
|
109 |
+
|
110 |
+
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
|
111 |
+
trainer = Trainer(
|
112 |
+
model=model,
|
113 |
+
train_dataset=train_data,
|
114 |
+
args=TrainingArguments(
|
115 |
+
report_to='wandb',
|
116 |
+
warmup_steps=1,
|
117 |
+
per_device_train_batch_size=1,
|
118 |
+
gradient_accumulation_steps=4,
|
119 |
+
gradient_checkpointing=True,
|
120 |
+
num_train_epochs=4,
|
121 |
+
learning_rate=2.5e-5,
|
122 |
+
logging_steps=1,
|
123 |
+
optim="paged_adamw_8bit",
|
124 |
+
save_strategy="steps",
|
125 |
+
save_steps=10,
|
126 |
+
save_total_limit=4,
|
127 |
+
output_dir=output_dir
|
128 |
+
),
|
129 |
+
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
130 |
+
)
|
131 |
+
model.config.use_cache = False
|
132 |
+
|
133 |
+
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
run.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline, logging, TextStreamer
|
3 |
+
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
|
4 |
+
import os, torch, wandb, platform, warnings
|
5 |
+
from datasets import load_dataset
|
6 |
+
from trl import SFTTrainer
|
7 |
+
|
8 |
+
hf_token = '..........'
|
9 |
+
|
10 |
+
tokenizer = AutoTokenizer.from_pretrained('./vistral-tokenizer')
|
11 |
+
bnb_config = BitsAndBytesConfig(
|
12 |
+
load_in_4bit=True,
|
13 |
+
bnb_4bit_quant_type="nf4",
|
14 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
15 |
+
bnb_4bit_use_double_quant=True,
|
16 |
+
)
|
17 |
+
model = AutoModelForCausalLM.from_pretrained(
|
18 |
+
'Viet-Mistral/Vistral-7B-Chat',
|
19 |
+
device_map="auto",
|
20 |
+
token=hf_token,
|
21 |
+
quantization_config=bnb_config,
|
22 |
+
)
|
23 |
+
ft_model = PeftModel.from_pretrained(model, CHECKPOINT_PATH)
|
24 |
+
|
25 |
+
#torch.backends.cuda.enable_mem_efficient_sdp(False)
|
26 |
+
#torch.backends.cuda.enable_flash_sdp(False)
|
27 |
+
|
28 |
+
system_prompt = "Bạn là một trợ lí Tiếng Việt nhiệt tình và trung thực. Hãy luôn trả lời một cách hữu ích nhất có thể, đồng thời giữ an toàn."
|
29 |
+
|
30 |
+
stop_tokens = [tokenizer.eos_token_id, tokenizer('<|im_end|>')['input_ids'].pop()]
|
31 |
+
|
32 |
+
def chat_test():
|
33 |
+
conversation = [{"role": "system", "content": system_prompt }]
|
34 |
+
while True:
|
35 |
+
human = input("Human: ")
|
36 |
+
if human.lower() == "reset":
|
37 |
+
conversation = [{"role": "system", "content": system_prompt }]
|
38 |
+
print("The chat history has been cleared!")
|
39 |
+
continue
|
40 |
+
|
41 |
+
if human.lower() == "exit":
|
42 |
+
break
|
43 |
+
|
44 |
+
conversation.append({"role": "user", "content": human })
|
45 |
+
formatted = tokenizer.apply_chat_template(conversation, tokenize=False) + "<|im_start|>assistant"
|
46 |
+
tok = tokenizer(formatted, return_tensors="pt").to(ft_model.device)
|
47 |
+
input_ids = tok['input_ids']
|
48 |
+
|
49 |
+
out_ids = ft_model.generate(
|
50 |
+
input_ids=input_ids,
|
51 |
+
attention_mask=tok['attention_mask'],
|
52 |
+
eos_token_id=stop_tokens,
|
53 |
+
max_new_tokens=50,
|
54 |
+
do_sample=True,
|
55 |
+
top_p=0.95,
|
56 |
+
top_k=40,
|
57 |
+
temperature=0.1,
|
58 |
+
repetition_penalty=1.05,
|
59 |
+
)
|
60 |
+
assistant = tokenizer.batch_decode(out_ids[:, input_ids.size(1): ], skip_special_tokens=True)[0].strip()
|
61 |
+
print("Assistant: ", assistant)
|
62 |
+
conversation.append({"role": "assistant", "content": assistant })
|
63 |
+
|
64 |
+
chat_test()
|