import os
import json

import requests
import sseclient  # sseclient-py, used to parse the SSE stream in gen_text

from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt

class LLaMA2ChatPromptFmt(PromptFmt):
    @classmethod
    def ctx(cls, context):
        if context is None or context == "":
            return ""
        else:
            return f"""<<SYS>>
{context}
<</SYS>>
"""

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        ping = pingpong.ping[:truncate_size]
        pong = "" if pingpong.pong is None else pingpong.pong[:truncate_size]
        return f"""[INST] {ping} [/INST] {pong}"""

class LLaMA2ChatPPManager(PPManager):
    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=LLaMA2ChatPromptFmt, truncate_size: int=None):
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        results = fmt.ctx(self.ctx)

        for pingpong in self.pingpongs[from_idx:to_idx]:
            results += fmt.prompt(pingpong, truncate_size=truncate_size)

        return results
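
# Illustration (comment added for clarity, not part of the original file): with
# ctx set to "You are a helpful assistant." and a single PingPong whose ping is
# "Hello" and whose pong is "Hi there!", build_prompts() returns:
#
#   <<SYS>>
#   You are a helpful assistant.
#   <</SYS>>
#   [INST] Hello [/INST] Hi there!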

class GradioLLaMA2ChatPPManager(LLaMA2ChatPPManager):
    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        results = []

        for pingpong in self.pingpongs[from_idx:to_idx]:
            results.append(fmt.ui(pingpong))

        return results
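
# Note added for clarity: build_uis() feeds each exchange through
# GradioChatUIFmt, so the returned list is shaped for a Gradio Chatbot
# component (one entry per user/assistant pair).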

async def gen_text(
    prompt,
    hf_model='meta-llama/Llama-2-70b-chat-hf',
    hf_token=None,
    parameters=None
):
    if hf_token is None:
        raise ValueError("Hugging Face Token is not set")

    # Default generation parameters for the text-generation endpoint.
    if parameters is None:
        parameters = {
            'max_new_tokens': 512,
            'do_sample': True,
            'return_full_text': False,
            'temperature': 1.0,
            'top_k': 50,
            # 'top_p': 1.0,
            'repetition_penalty': 1.2
        }

    url = f'https://api-inference.huggingface.co/models/{hf_model}'
    headers = {
        'Authorization': f'Bearer {hf_token}',
        'Content-type': 'application/json'
    }
    data = {
        'inputs': prompt,
        'stream': True,
        'options': {
            'use_cache': False,
        },
        'parameters': parameters
    }

    # Ask the Hugging Face Inference API for a streaming (SSE) response.
    r = requests.post(
        url,
        headers=headers,
        data=json.dumps(data),
        stream=True
    )

    # Parse the SSE stream and yield the generated text token by token.
    client = sseclient.SSEClient(r)
    for event in client.events():
        yield json.loads(event.data)['token']['text']
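

# Minimal end-to-end sketch added for illustration; it is not part of the
# original file. The HF_TOKEN environment variable and the no-argument
# LLaMA2ChatPPManager() constructor are assumptions.
if __name__ == "__main__":
    import asyncio

    async def demo():
        ppm = LLaMA2ChatPPManager()
        ppm.ctx = "You are a helpful assistant."
        prompt = ppm.build_prompts() + "[INST] Tell me a joke [/INST] "
        async for token in gen_text(prompt, hf_token=os.environ["HF_TOKEN"]):
            print(token, end="", flush=True)

    asyncio.run(demo())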