File size: 2,297 Bytes
5caf4e2 421e099 9b15988 421e099 9b15988 421e099 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
# Cấu hình mô hình
MODEL = "Viet-Mistral/Vistral-7B-Chat"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device =', device)
# Load mô hình và tokenizer
model = AutoModelForCausalLM.from_pretrained(
'Viet-Mistral/Vistral-7B-Chat',
torch_dtype=torch.bfloat16, # change to torch.float16 if you're using V100
device_map="auto"#,
#use_cache=True,
#cache_dir='/workspace/thviet/hf_cache'
)
tokenizer = AutoTokenizer.from_pretrained(MODEL, cache_dir='/workspace/thviet/hf_cache')
lora_config = LoraConfig.from_pretrained(
"thviet79/model-QA-medical"# Thay bằng đường dẫn đến mô hình LoRA trên Hugging Face
#cache_dir='/workspace/thviet/hf_cache'
)
# Áp dụng cấu hình LoRA vào mô hình
model = get_peft_model(model, lora_config)
# Chuẩn bị hội thoại và input
system_prompt = "Bạn là một trợ lí ảo Tiếng Việt về lĩnh vực y tế."
question = "Chào bác sĩ,\nRăng cháu hiện tại có mủ ở dưới lợi nhưng khi đau cháu sẽ không ngủ được (quá đau). Tuy nhiên chỉ vài ngày là hết mà thỉnh thoảng nó lại bị đau. Chị cháu bảo là trước chị cháu cũng bị như vậy chỉ là đau răng tuổi dậy thì thôi. Bác sĩ cho cháu hỏi đau răng kèm có mủ dưới lợi là bệnh gì? Cháu có cần đi chữa trị không? Cháu cảm ơn."
conversation = [{"role": "system", "content": system_prompt }]
human = f"Vui lòng trả lời câu hỏi sau: {question}"
conversation.append({"role": "user", "content": human })
# Chuyển các tensor đầu vào sang đúng thiết bị
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(device)
# Tạo đầu ra từ mô hình
out_ids = model.generate(
input_ids=input_ids,
max_new_tokens=768,
do_sample=True,
top_p=0.95,
top_k=40,
temperature=0.1,
repetition_penalty=1.05,
)
# Giải mã và in kết quả
assistant = tokenizer.batch_decode(out_ids[:, input_ids.size(1):], skip_special_tokens=True)[0].strip()
print("Assistant: ", assistant)
|