import json
from typing import Optional

import requests
import sseclient  # sseclient-py: SSEClient(response).events() parses Server-Sent Events

from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt

class LLaMA2ChatPromptFmt(PromptFmt):
    """Formats a conversation into LLaMA-2 chat prompt markup ([INST]/<<SYS>> tags)."""

    @classmethod
    def ctx(cls, context):
        # Wrap the system context in LLaMA-2's <<SYS>> ... <</SYS>> block.
        if context is None or context == "":
            return ""
        else:
            return f"""<<SYS>>
{context}
<</SYS>>
"""

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        # Truncate both sides of the exchange; slicing with None keeps the full string.
        ping = pingpong.ping[:truncate_size]
        pong = "" if pingpong.pong is None else pingpong.pong[:truncate_size]
        return f"""[INST] {ping} [/INST] {pong}"""
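
# For reference, the markup the formatter produces looks like this (illustrative,
# assuming the pingpong library's PingPong(ping, pong) constructor):
#
#     LLaMA2ChatPromptFmt.ctx("Be brief.")
#     # -> '<<SYS>>\nBe brief.\n<</SYS>>\n'
#     LLaMA2ChatPromptFmt.prompt(PingPong("Hi", "Hello!"), truncate_size=None)
#     # -> '[INST] Hi [/INST] Hello!'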

class LLaMA2ChatPPManager(PPManager):
    """Conversation manager that serializes ping-pong history into a LLaMA-2 prompt."""

    def build_prompts(self, from_idx: int = 0, to_idx: int = -1, fmt: PromptFmt = LLaMA2ChatPromptFmt, truncate_size: Optional[int] = None):
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        # Start with the system context, then append each exchange in order.
        results = fmt.ctx(self.ctx)

        for pingpong in self.pingpongs[from_idx:to_idx]:
            results += fmt.prompt(pingpong, truncate_size=truncate_size)

        return results

class GradioLLaMA2ChatPPManager(LLaMA2ChatPPManager):
    """Extends the prompt manager with Gradio Chatbot-compatible UI rendering."""

    def build_uis(self, from_idx: int = 0, to_idx: int = -1, fmt: UIFmt = GradioChatUIFmt):
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        return [fmt.ui(pingpong) for pingpong in self.pingpongs[from_idx:to_idx]]
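
def _demo_build_prompts():
    # Illustrative sketch only (not part of the original module). Assumes the
    # pingpong library's PingPong(ping, pong) constructor and
    # PPManager.add_pingpong(), as used by this project.
    ppm = GradioLLaMA2ChatPPManager()
    ppm.ctx = "You are a helpful, concise assistant."
    ppm.add_pingpong(PingPong("Hi there!", "Hello! How can I help?"))
    ppm.add_pingpong(PingPong("Tell me a joke.", None))  # pong=None: reply pending

    prompt = ppm.build_prompts()  # single LLaMA-2 prompt string for the API
    uis = ppm.build_uis()         # (user, assistant) pairs for a Gradio Chatbot
    return prompt, uis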

async def gen_text(
    prompt,
    hf_model='meta-llama/Llama-2-70b-chat-hf',
    hf_token=None,
    parameters=None
):
    """Stream generated tokens from the Hugging Face Inference API via Server-Sent Events.

    Note: requests/sseclient are blocking, so this async generator does not
    yield control to the event loop between tokens.
    """
    if hf_token is None:
        raise ValueError("Hugging Face Token is not set")

    if parameters is None:
        parameters = {
            'max_new_tokens': 512,
            'do_sample': True,
            'return_full_text': False,
            'temperature': 1.0,
            'top_k': 50,
            # 'top_p': 1.0,
            'repetition_penalty': 1.2
        }

    url = f'https://api-inference.huggingface.co/models/{hf_model}'
    headers = {
        'Authorization': f'Bearer {hf_token}',
        'Content-Type': 'application/json'
    }
    data = {
        'inputs': prompt,
        'stream': True,
        'options': {
            'use_cache': False,
        },
        'parameters': parameters
    }

    r = requests.post(
        url,
        headers=headers,
        data=json.dumps(data),
        stream=True
    )
    r.raise_for_status()

    # Each SSE event carries one generated token as a JSON payload.
    client = sseclient.SSEClient(r)
    for event in client.events():
        yield json.loads(event.data)['token']['text']
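
async def _demo_stream(hf_token):
    # Illustrative sketch only (not part of the original module): consume the
    # token stream from gen_text above. 'hf_token' is assumed to be a valid
    # Hugging Face API token, e.g. os.environ['HF_TOKEN'].
    async for token in gen_text("[INST] Explain SSE briefly. [/INST]", hf_token=hf_token):
        print(token, end='', flush=True)
    # Run with: asyncio.run(_demo_stream(hf_token))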

def gen_text_none_stream(
    prompt,
    hf_model='meta-llama/Llama-2-70b-chat-hf',
    hf_token=None,
):
    """Generate text in a single (non-streaming) request to the Hugging Face Inference API."""
    if hf_token is None:
        raise ValueError("Hugging Face Token is not set")

    parameters = {
        'max_new_tokens': 64,
        'do_sample': True,
        'return_full_text': False,
        'temperature': 0.7,
        'top_k': 10,
        # 'top_p': 1.0,
        'repetition_penalty': 1.2
    }

    url = f'https://api-inference.huggingface.co/models/{hf_model}'
    headers = {
        'Authorization': f'Bearer {hf_token}',
        'Content-Type': 'application/json'
    }
    data = {
        'inputs': prompt,
        'stream': False,
        'options': {
            'use_cache': False,
        },
        'parameters': parameters
    }

    r = requests.post(
        url,
        headers=headers,
        data=json.dumps(data),
    )
    r.raise_for_status()

    # The API returns a list with one result per input.
    return r.json()[0]["generated_text"]
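
def _demo_round_trip(hf_token):
    # Illustrative end-to-end sketch only: build a prompt from the conversation
    # manager, call the non-streaming endpoint, and record the reply. Assumes
    # PingPong/add_pingpong from the pingpong library and a valid API token.
    ppm = LLaMA2ChatPPManager()
    ppm.ctx = "You are a helpful assistant."
    ppm.add_pingpong(PingPong("Summarize SSE in one sentence.", None))

    reply = gen_text_none_stream(ppm.build_prompts(), hf_token=hf_token)
    ppm.pingpongs[-1].pong = reply  # complete the pending exchange with the reply
    return reply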