valeriylo committed
Commit cf30256 · 1 Parent(s): a5ae538

Create interact_llamacpp.py

Files changed (1):
  1. interact_llamacpp.py +73 -0
interact_llamacpp.py ADDED
@@ -0,0 +1,73 @@
+ import fire
+ from llama_cpp import Llama
+
+ # System prompt (in Russian): "You are Saiga, a Russian-language automated assistant.
+ # You talk to people and help them."
+ SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
+ # Tokenizer ids for the role markers and the newline token used by the prompt format.
+ SYSTEM_TOKEN = 1788
+ USER_TOKEN = 1404
+ BOT_TOKEN = 9225
+ LINEBREAK_TOKEN = 13
+
+ ROLE_TOKENS = {
+     "user": USER_TOKEN,
+     "bot": BOT_TOKEN,
+     "system": SYSTEM_TOKEN
+ }
+
+
+ def get_message_tokens(model, role, content):
+     # Build one turn as [BOS, role token, linebreak, content tokens..., EOS].
+     message_tokens = model.tokenize(content.encode("utf-8"))
+     message_tokens.insert(1, ROLE_TOKENS[role])
+     message_tokens.insert(2, LINEBREAK_TOKEN)
+     message_tokens.append(model.token_eos())
+     return message_tokens
+
+
+ def get_system_tokens(model):
+     system_message = {
+         "role": "system",
+         "content": SYSTEM_PROMPT
+     }
+     return get_message_tokens(model, **system_message)
+
+
+ def interact(
+     model_path,
+     n_ctx=2000,
+     top_k=30,
+     top_p=0.9,
+     temperature=0.2,
+     repeat_penalty=1.1
+ ):
+     model = Llama(
+         model_path=model_path,
+         n_ctx=n_ctx,
+         n_parts=1,
+     )
+
+     # Evaluate the system prompt once; the full conversation history lives in `tokens`.
+     system_tokens = get_system_tokens(model)
+     tokens = system_tokens
+     model.eval(tokens)
+
+     while True:
+         user_message = input("User: ")
+         message_tokens = get_message_tokens(model=model, role="user", content=user_message)
+         # Open the bot turn so the model continues as the assistant.
+         role_tokens = [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
+         tokens += message_tokens + role_tokens
+         generator = model.generate(
+             tokens,
+             top_k=top_k,
+             top_p=top_p,
+             temp=temperature,
+             repeat_penalty=repeat_penalty
+         )
+         # Stream the answer token by token until EOS.
+         for token in generator:
+             token_str = model.detokenize([token]).decode("utf-8", errors="ignore")
+             tokens.append(token)
+             if token == model.token_eos():
+                 break
+             print(token_str, end="", flush=True)
+         print()
+
+
+ if __name__ == "__main__":
+     fire.Fire(interact)
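
A possible way to run the script, assuming fire and a llama-cpp-python build compatible with this code are installed (the Llama constructor is called with n_parts, an argument from early llama-cpp-python releases), and with path/to/model.bin standing in as a placeholder for a Saiga GGML checkpoint:

    pip install fire llama-cpp-python
    python interact_llamacpp.py path/to/model.bin

Because the entry point is fire.Fire(interact), the remaining parameters can also be passed as command-line flags, for example --n_ctx=4000 or --temperature=0.5.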