File size: 1,394 Bytes
7f9f96d
 
 
 
e4f695b
7f9f96d
 
 
 
 
e4f695b
7f9f96d
e4f695b
 
7f9f96d
e4f695b
 
 
 
 
 
 
 
 
7f9f96d
 
e4f695b
 
 
 
 
 
 
 
3360664
 
 
 
7f9f96d
e4f695b
 
7f9f96d
e4f695b
 
 
 
 
7f9f96d
e4f695b
 
b878b56
 
 
cefe426
b878b56
 
 
3360664
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from threading import Thread
from typing import Iterator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

# Hugging Face Hub repository used for both the model and the tokenizer.
model_id = 'baichuan-inc/Baichuan2-13B-Chat'

if not torch.cuda.is_available():
  # Without a GPU the model is not loaded at all; `model` stays None.
  model = None
else:
  # Load fp16 weights, apply the repository's own `quantize` method (provided
  # by the remote code enabled via trust_remote_code; the argument 4 is
  # presumably the bit width — confirm against the model card), then move
  # the result to the GPU. Chained so no extra reference to the unquantized
  # model is kept around.
  model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    trust_remote_code=True
  ).quantize(4).cuda()
  # Use the generation defaults published alongside the checkpoint.
  model.generation_config = GenerationConfig.from_pretrained(model_id)

# The slow (Python) tokenizer is requested explicitly via use_fast=False.
tokenizer = AutoTokenizer.from_pretrained(
  model_id,
  use_fast=False,
  trust_remote_code=True
)

def run(
  message: str,
  chat_history: list[tuple[str, str]],
  max_new_tokens: int = 1024,
  temperature: float = 1.0,
  top_p: float = 0.95,
  top_k: int = 5
) -> Iterator[str]:
  """Generate a reply to `message` and yield it progressively.

  Args:
    message: The new user message.
    chat_history: Earlier turns as (user_message, assistant_reply) pairs,
      oldest first.
    max_new_tokens: Upper bound on the number of generated tokens.
    temperature: Sampling temperature.
    top_p: Nucleus-sampling probability mass.
    top_k: Number of highest-probability tokens considered when sampling.

  Yields:
    Growing prefixes of the reply, each one character longer than the
    previous; the final value is the complete reply. Nothing is yielded
    when the reply is empty.

  Raises:
    RuntimeError: If the model was not loaded (no CUDA device at import).
  """
  import copy

  if model is None:
    raise RuntimeError(
      'Model is not loaded: a CUDA device is required to run ' + model_id
    )

  # Per-call copy so concurrent requests don't overwrite each other's
  # sampling settings on the shared, module-level model.
  generation_config = copy.deepcopy(model.generation_config)
  generation_config.max_new_tokens = max_new_tokens
  generation_config.temperature = temperature
  generation_config.top_p = top_p
  generation_config.top_k = top_k

  messages = []
  for user_turn, assistant_turn in chat_history:
    messages.append({"role": "user", "content": user_turn})
    messages.append({"role": "assistant", "content": assistant_turn})
  messages.append({"role": "user", "content": message})

  # Without stream=True, model.chat returns the whole reply as one string.
  # Emit it as growing prefixes so the caller can render it incrementally.
  response = model.chat(
    tokenizer,
    messages,
    generation_config=generation_config,
  )
  for end in range(1, len(response) + 1):
    yield response[:end]