# pylint: disable=invalid-name, line-too-long, missing-module-docstring
"""Download Baichuan2-13B-Chat-4bits, load it on a CUDA device and run one chat turn.

This is a flat smoke-test script: everything runs at import time. It needs a
CUDA-capable GPU and raises ``SystemError`` immediately when none is present.
"""
import gc
import os
import time

import gradio  # noqa: F401  # pylint: disable=unused-import
import rich
import torch
from huggingface_hub import snapshot_download
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"

# Fail fast before downloading anything: the script only supports CUDA.
has_cuda = torch.cuda.is_available()
if not has_cuda:
    raise SystemError(f"No cuda, cant run {model_name}")

# Fetch the repository into ./model; ``loc`` is the folder the files ended up in.
loc = snapshot_download(repo_id=model_name, local_dir="model")

# Fix the timezone on Linux so log timestamps use Asia/Shanghai.
os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore  # pylint: disable=no-member
except AttributeError:
    # time.tzset() does not exist on Windows.
    logger.warning("Windows, cant run time.tzset()")

# When the script is re-run in an interactive session, drop the reference to a
# previously loaded model and collect it before loading a new one.
model = None
gc.collect()

logger.info("start")

tokenizer = AutoTokenizer.from_pretrained(
    model_name, use_fast=False, trust_remote_code=True
)

# NOTE(review): load_in_8bit=True is passed for a checkpoint published as
# 4-bit; kept as in the original script -- confirm it is intended.
model = AutoModelForCausalLM.from_pretrained(
    loc,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # pylint: disable=no-member
    load_in_8bit=True,
    trust_remote_code=True,
)
model = model.eval()

rich.print(f"{model=}")
logger.info("done")

model.generation_config = GenerationConfig.from_pretrained(model_name)

# Single-turn chat request used as a smoke test.
messages = [{"role": "user", "content": "解释一下“温故而知新”"}]
response = model.chat(tokenizer, messages)

rich.print(response)
logger.info(f"{response=}")