# pylint: disable=invalid-name, line-too-long, missing-module-docstring
import gc
import os
import time

import gradio
import rich
import torch
from huggingface_hub import snapshot_download
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"
if not torch.cuda.is_available():
    # gradio.Error only has an effect inside a running Gradio callback; raise to abort the script
    raise SystemError(f"No CUDA available, can't run {model_name}")

# download the quantized checkpoint to ./model; loc is the local path
loc = snapshot_download(repo_id=model_name, local_dir="model")

# fix timezone in Linux
os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore # pylint: disable=no-member
except Exception:  # pylint: disable=broad-except
    # time.tzset() is only available on Unix
    logger.warning("time.tzset() is not available on Windows, skipping")

model = None
gc.collect()  # for interactive testing

logger.info("start")
has_cuda = torch.cuda.is_available()

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)

if has_cuda:
    model = AutoModelForCausalLM.from_pretrained(
        loc,  # local path returned by snapshot_download
        device_map="auto",
        torch_dtype=torch.bfloat16,  # pylint: disable=no-member
        trust_remote_code=True,
        # the checkpoint is already 4-bit quantized; load_in_8bit=True would
        # attempt to re-quantize it and conflicts with the bundled quantization config
    )
else:
    try:
        model = AutoModelForCausalLM.from_pretrained(
            loc,
            trust_remote_code=True,
        )  # .float() is not supported for this quantized checkpoint
    except Exception as exc:  # pylint: disable=broad-except
        logger.error(exc)
        logger.warning("The 4-bit checkpoint does not seem to load on CPU, exiting")
        raise SystemExit(1) from exc

model = model.eval()  # inference mode for both the CUDA and CPU paths

rich.print(f"{model=}")

logger.info("done")

model.generation_config = GenerationConfig.from_pretrained(model_name)
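# Optional: the loaded generation_config can be tweaked before chatting.
# The value below is illustrative, not part of the original script:
# model.generation_config.temperature = 0.3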
# Baichuan's remote code exposes model.chat(tokenizer, messages) for multi-turn chat
messages = []
messages.append({"role": "user", "content": "解释一下“温故而知新”"})  # "Explain '温故而知新' (learn the new by revisiting the old)"
response = model.chat(tokenizer, messages)

rich.print(response)

logger.info(f"{response=}")
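
# A minimal multi-turn sketch, assuming model.chat returns the reply as a string
# (the follow-up question below is illustrative, not from the original script):
messages.append({"role": "assistant", "content": response})
messages.append({"role": "user", "content": "请举一个例子。"})  # "Please give an example."
response = model.chat(tokenizer, messages)
rich.print(response)
logger.info(f"{response=}")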