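# NOTE: the lines below are residue from the Hugging Face file-viewer page,
# kept as comments so that the file parses as Python.
# LoneStriker's picture
# Upload folder using huggingface_hub
# 4007954 verified
# raw
# history blame contribute delete
# No virus
# 2.31 kB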
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline, logging, TextStreamer
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os, torch, wandb, platform, warnings
from datasets import load_dataset
from trl import SFTTrainer
# Hugging Face access token used when downloading the base model. The HF_TOKEN
# environment variable takes precedence so the secret does not have to be
# edited into this file; the placeholder literal is kept as the fallback.
hf_token = os.environ.get("HF_TOKEN", '..........')

# Path (or Hub repo id) of the fine-tuned PEFT adapter to load on top of the
# base model. It is resolved before the expensive base-model load so that a
# missing value fails immediately with an actionable message.
CHECKPOINT_PATH = os.environ.get("CHECKPOINT_PATH")
if not CHECKPOINT_PATH:
    raise RuntimeError(
        "CHECKPOINT_PATH is not set: export it to the directory (or Hub repo "
        "id) of the fine-tuned PEFT adapter before running this script."
    )

# Tokenizer is loaded from a local directory next to the script.
tokenizer = AutoTokenizer.from_pretrained('./vistral-tokenizer')

# 4-bit NF4 quantization with double quantization; matmuls run in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Quantized base model, placed on the available device(s) automatically.
model = AutoModelForCausalLM.from_pretrained(
    'Viet-Mistral/Vistral-7B-Chat',
    device_map="auto",
    token=hf_token,
    quantization_config=bnb_config,
)

# Wrap the base model with the fine-tuned adapter weights.
ft_model = PeftModel.from_pretrained(model, CHECKPOINT_PATH)

# Optional switches: uncomment to turn off the memory-efficient / flash
# scaled-dot-product-attention CUDA kernels.
#torch.backends.cuda.enable_mem_efficient_sdp(False)
#torch.backends.cuda.enable_flash_sdp(False)

# System message placed at the start of every conversation (Vietnamese).
system_prompt = "Bạn là một trợ lí Tiếng Việt nhiệt tình và trung thực. Hãy luôn trả lời một cách hữu ích nhất có thể, đồng thời giữ an toàn."

# Generation stops on either the tokenizer's EOS id or the last id produced by
# tokenizing the literal '<|im_end|>' marker.
stop_tokens = [tokenizer.eos_token_id, tokenizer('<|im_end|>')['input_ids'].pop()]
def chat_test(max_new_tokens=50):
    """Run an interactive chat loop with the fine-tuned model on stdin/stdout.

    Every user turn is appended to a running conversation, the whole
    conversation is rendered through the tokenizer's chat template, and the
    sampled reply is printed and stored so that later turns can see it.

    Commands (case-insensitive, surrounding whitespace ignored):
        reset -- discard the history, keeping only the system prompt.
        exit  -- leave the loop.

    End-of-input (Ctrl-D) or Ctrl-C at the prompt also leaves the loop
    instead of propagating an exception. Blank input is ignored.

    Uses the module-level ``tokenizer``, ``ft_model``, ``system_prompt`` and
    ``stop_tokens``.

    Args:
        max_new_tokens: Upper bound on the number of tokens generated for
            each reply. Defaults to 50.
    """

    def new_history():
        # A conversation always begins with the system message.
        return [{"role": "system", "content": system_prompt}]

    conversation = new_history()
    while True:
        try:
            human = input("Human: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C: finish quietly rather than with a traceback.
            print()
            break

        command = human.strip().lower()
        if command == "reset":
            conversation = new_history()
            print("The chat history has been cleared!")
            continue
        if command == "exit":
            break
        if not command:
            # Nothing typed; do not spend a generation on an empty turn.
            continue

        conversation.append({"role": "user", "content": human})

        # The assistant header is appended by hand after rendering the template.
        # NOTE(review): presumably the local tokenizer's template does not emit
        # the generation prompt itself -- confirm before switching to
        # add_generation_prompt=True.
        formatted = (
            tokenizer.apply_chat_template(conversation, tokenize=False)
            + "<|im_start|>assistant"
        )
        tok = tokenizer(formatted, return_tensors="pt").to(ft_model.device)
        input_ids = tok['input_ids']

        out_ids = ft_model.generate(
            input_ids=input_ids,
            attention_mask=tok['attention_mask'],
            eos_token_id=stop_tokens,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.95,
            top_k=40,
            temperature=0.1,
            repetition_penalty=1.05,
        )

        # Drop the prompt tokens so only the newly generated reply is decoded.
        reply_ids = out_ids[:, input_ids.size(1):]
        assistant = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0].strip()
        print("Assistant: ", assistant)
        conversation.append({"role": "assistant", "content": assistant})
# Start the interactive session only when the file is executed as a script,
# so that importing the module does not block waiting on stdin.
if __name__ == "__main__":
    chat_test()