# Baichuan2-13B-Chat / model.py

from typing import Iterator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_id = 'baichuan-inc/Baichuan2-13B-Chat'

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # device_map='auto',
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
    # Baichuan2's remote code ships an in-place `quantize` method; 4-bit
    # quantization lets the 13B checkpoint fit on a single GPU.
    model = model.quantize(4).cuda()
    model.generation_config = GenerationConfig.from_pretrained(model_id)
else:
    model = None

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    use_fast=False,
    trust_remote_code=True
)

def run(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 1.0,
    top_p: float = 0.95,
    top_k: int = 5
) -> Iterator[str]:
    # Baichuan2's custom `chat` method reads its sampling parameters from
    # `model.generation_config`, so override them per request.
    model.generation_config.max_new_tokens = max_new_tokens
    model.generation_config.temperature = temperature
    model.generation_config.top_p = top_p
    model.generation_config.top_k = top_k

    history = []
    result = ""
    # Rebuild the conversation as the role/content message list the model expects.
    for user_msg, assistant_msg in chat_history:
        history.append({"role": "user", "content": user_msg})
        history.append({"role": "assistant", "content": assistant_msg})
    print(history)
    history.append({"role": "user", "content": message})
    # With streaming disabled, `model.chat` returns the full reply as one
    # string; iterating over it yields single characters, so `result` grows
    # character by character and the generator still emits partial replies.
    for response in model.chat(
        tokenizer,
        history,
        # stream=True,
    ):
        result = result + response
        yield result
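

# Usage sketch (illustrative; not part of the original Space, which drives
# `run` from a Gradio app): consume the generator to print partial replies.
# Guarded because `model` is None on CPU-only machines.
if __name__ == '__main__':
    if model is not None:
        for partial in run('Hello, who are you?', chat_history=[]):
            print(partial)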