chansung commited on
Commit
dd486e6
1 Parent(s): 7dca0e3

Create llama2.py

Browse files
Files changed (1) hide show
  1. llama2.py +93 -0
llama2.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from pingpong import PingPong
4
+ from pingpong.pingpong import PPManager
5
+ from pingpong.pingpong import PromptFmt
6
+ from pingpong.pingpong import UIFmt
7
+ from pingpong.gradio import GradioChatUIFmt
8
+
9
+ class LLaMA2ChatPromptFmt(PromptFmt):
10
+ @classmethod
11
+ def ctx(cls, context):
12
+ if context is None or context == "":
13
+ return ""
14
+ else:
15
+ return f"""<<SYS>>
16
+ {context}
17
+ <</SYS>>
18
+ """
19
+
20
+ @classmethod
21
+ def prompt(cls, pingpong, truncate_size):
22
+ ping = pingpong.ping[:truncate_size]
23
+ pong = "" if pingpong.pong is None else pingpong.pong[:truncate_size]
24
+ return f"""[INST] {ping} [/INST] {pong}"""
25
+
26
+ class LLaMA2ChatPPManager(PPManager):
27
+ def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=LLaMA2ChatPromptFmt, truncate_size: int=None):
28
+ if to_idx == -1 or to_idx >= len(self.pingpongs):
29
+ to_idx = len(self.pingpongs)
30
+
31
+ results = fmt.ctx(self.ctx)
32
+
33
+ for idx, pingpong in enumerate(self.pingpongs[from_idx:to_idx]):
34
+ results += fmt.prompt(pingpong, truncate_size=truncate_size)
35
+
36
+ return results
37
+
38
+ class GradioLLaMA2ChatPPManager(LLaMA2ChatPPManager):
39
+ def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
40
+ if to_idx == -1 or to_idx >= len(self.pingpongs):
41
+ to_idx = len(self.pingpongs)
42
+
43
+ results = []
44
+
45
+ for pingpong in self.pingpongs[from_idx:to_idx]:
46
+ results.append(fmt.ui(pingpong))
47
+
48
+ return results
49
+
50
+ async def gen_text(
51
+ prompt,
52
+ hf_model='meta-llama/Llama-2-70b-chat-hf',
53
+ hf_token=None,
54
+ parameters=None
55
+ ):
56
+ if hf_token is None:
57
+ raise ValueError("Hugging Face Token is not set")
58
+
59
+ if parameters is None:
60
+ parameters = {
61
+ 'max_new_tokens': 512,
62
+ 'do_sample': True,
63
+ 'return_full_text': False,
64
+ 'temperature': 1.0,
65
+ 'top_k': 50,
66
+ # 'top_p': 1.0,
67
+ 'repetition_penalty': 1.2
68
+ }
69
+
70
+ url = f'https://api-inference.huggingface.co/models/{hf_model}'
71
+ headers={
72
+ 'Authorization': f'Bearer {hf_token}',
73
+ 'Content-type': 'application/json'
74
+ }
75
+ data = {
76
+ 'inputs': prompt,
77
+ 'stream': True,
78
+ 'options': {
79
+ 'use_cache': False,
80
+ },
81
+ 'parameters': parameters
82
+ }
83
+
84
+ r = requests.post(
85
+ url,
86
+ headers=headers,
87
+ data=json.dumps(data),
88
+ stream=True
89
+ )
90
+
91
+ client = sseclient.SSEClient(r)
92
+ for event in client.events():
93
+ yield json.loads(event.data)['token']['text']