# Duplicated from the Hugging Face repo mikeee/baichuan-13b-chat-try
# (commit db2a738, by Cran-May).
# pylint: disable=invalid-name, line-too-long, missing-module-docstring
import gc
import os
import time
import gradio
import rich
import torch
from huggingface_hub import snapshot_download
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"

# Refuse to run without CUDA, and do so before starting the snapshot download.
if not torch.cuda.is_available():
    # Constructing gradio.Error without raising it has no effect; log instead so
    # the reason is recorded next to the rest of the script's logging.
    logger.error(f"No cuda, cant run {model_name}")
    raise SystemError(f"No cuda, cant run {model_name}")

# Mirror the whole repo (weights, configs, remote code) into ./model;
# `loc` is the local path of that snapshot.
loc = snapshot_download(repo_id=model_name, local_dir="model")
# Make this process's local time Asia/Shanghai (Unix only).
os.environ["TZ"] = "Asia/Shanghai"
try:
    # time.tzset() re-reads TZ from the environment; it exists only on Unix.
    time.tzset()  # type: ignore # pylint: disable=no-member
except AttributeError:
    # Windows: time has no tzset(), so the TZ change does not take effect.
    logger.warning("Windows, cant run time.tzset()")
# Drop any previously loaded model and reclaim its memory before loading again.
model = None
gc.collect()  # for interactive testing

logger.info("start")
has_cuda = torch.cuda.is_available()

# Load the tokenizer from the same local snapshot as the weights, so both
# always come from the same revision of the repo.
tokenizer = AutoTokenizer.from_pretrained(loc, use_fast=False, trust_remote_code=True)

if has_cuda:
    # This repo is a pre-quantized 4-bit checkpoint, judging by its name and by
    # Baichuan2's usage example, which passes no quantization flags.
    # Requesting `load_in_8bit=True` on top of it asks bitsandbytes for a
    # conflicting quantization. Rely on the checkpoint's own quantization
    # settings instead (NOTE(review): presumably in its config.json — verify).
    model = AutoModelForCausalLM.from_pretrained(
        loc,
        device_map="auto",
        torch_dtype=torch.bfloat16,  # pylint: disable=no-member
        trust_remote_code=True,
    )
else:
    # NOTE(review): unreachable while the script raises earlier when CUDA is
    # missing; kept as a best-effort CPU attempt that exits cleanly on failure.
    try:
        model = AutoModelForCausalLM.from_pretrained(
            loc,
            trust_remote_code=True,
        )  # .float() not supported
    except Exception as exc:  # pylint: disable=broad-except
        logger.error(exc)
        logger.warning("Doesnt seem to load for CPU...")
        raise SystemExit(1) from exc

# Inference mode (disables dropout etc.).
model = model.eval()
rich.print(f"{model=}")
logger.info("done")
# The tokenizer is already loaded above with identical arguments, so it is not
# reloaded here. Attach the repo's generation settings (generation_config.json)
# to the model so model.chat uses them.
model.generation_config = GenerationConfig.from_pretrained(model_name)
# One-turn demo: ask the model to explain the idiom "温故而知新"
# ("review the old to learn the new").
messages = [{"role": "user", "content": "解释一下“温故而知新”"}]
response = model.chat(tokenizer, messages)

# Echo the answer to the console, then record it (with its repr) in the log.
rich.print(response)
logger.info(f"{response=}")