"""Test various models."""
# pylint: disable=invalid-name, line-too-long,broad-exception-caught, protected-access
import os
import time
from pathlib import Path

import gradio as gr
import torch
from loguru import logger
from transformers import AutoModel, AutoTokenizer

# ruff: noqa: E402
# os.system("pip install --upgrade torch transformers sentencepiece scipy cpm_kernels accelerate bitsandbytes loguru")

# os.system("pip install torch transformers sentencepiece loguru")



# fix timezone in Linux
os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore # pylint: disable=no-member
except Exception:
    # time.tzset() is only available on Unix; skip on Windows
    logger.warning("Windows, can't run time.tzset()")

model_name = "THUDM/chatglm2-6b-int4"  # 3.9G
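# the ChatGLM repos ship custom tokenizer/model code, hence trust_remote_code=True below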

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

has_cuda = torch.cuda.is_available()
# has_cuda = False  # force cpu

logger.debug("load")
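# int4 checkpoints are already quantized, so they are moved to the GPU as-is;
# other checkpoints are cast to fp16, and the CPU path also uses .half() to cut RAM usage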
if has_cuda:
    if model_name.endswith("int4"):
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
    else:
        model = (
            AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().half()
        )
else:
    model = AutoModel.from_pretrained(
        model_name, trust_remote_code=True
    ).half()  # .float() .half().float()

model = model.eval()
logger.debug("done load")

# tokenizer = AutoTokenizer.from_pretrained("openchat/openchat_v2_w")
# model = AutoModelForCausalLM.from_pretrained("openchat/openchat_v2_w", load_in_8bit_fp32_cpu_offload=True, load_in_8bit=True)

# locate model file cache
cache_loc = Path("~/.cache/huggingface/hub").expanduser()
model_cache_path = [
    elm
    for elm in cache_loc.rglob("*")
    if Path(model_name).name in elm.as_posix() and "pytorch_model.bin" in elm.as_posix()
]

logger.debug(f"{model_cache_path=}")

if model_cache_path:
    model_size_gb = model_cache_path[0].stat().st_size / 2**30
    logger.info(f"{model_name=} {model_size_gb=:.2f} GB")


def respond(message, chat_history):
    """Run one chat turn; model.chat returns the history with the new turn already appended."""
    _, chat_history = model.chat(
        tokenizer, message, history=chat_history,
        temperature=0.7, repetition_penalty=1.2, max_length=128,
    )
    return "", chat_history

theme = gr.themes.Soft(text_size="sm")
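# minimal chat UI: a Chatbot pane, a prompt Textbox, a submit button and a clear button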
with gr.Blocks(theme=theme) as block:
    chatbot = gr.Chatbot()

    with gr.Row():  # Row so that the scale/min_width values below take effect
        with gr.Column(scale=12):
            msg = gr.Textbox()
        with gr.Column(scale=1, min_width=16):
            btn = gr.Button()
        with gr.Column(scale=1, min_width=16):
            clear = gr.ClearButton([msg, chatbot])

    # keep the prompt in the textbox: re-emit the submitted text instead of respond()'s empty string
    msg.submit(lambda x, y: [x] + respond(x, y)[1:], [msg, chatbot], [msg, chatbot])

    btn.click(respond, [msg, chatbot], [msg, chatbot])

block.queue().launch()  # serves on http://127.0.0.1:7860 by default; launch(share=True) gives a public link