Husnain committed on
Commit
b5f45b3
1 Parent(s): 85bab23

💎 [Feature] Enable gpt-3.5 in chat_api

Files changed (1)
  1. networks/huggingface_streamer.py +199 -0
networks/huggingface_streamer.py ADDED
@@ -0,0 +1,199 @@
import json
import re
import requests

from tclogger import logger
from transformers import AutoTokenizer

from constants.models import (
    MODEL_MAP,
    STOP_SEQUENCES_MAP,
    TOKEN_LIMIT_MAP,
    TOKEN_RESERVED,
)
from constants.envs import PROXIES
from messagers.message_outputer import OpenaiStreamOutputer


class HuggingfaceStreamer:
    def __init__(self, model: str):
        if model in MODEL_MAP:
            self.model = model
        else:
            self.model = "default"
        self.model_fullname = MODEL_MAP[self.model]
        self.message_outputer = OpenaiStreamOutputer(model=self.model)
        # default stop sequence; overridden per model in chat_response()
        self.stop_sequences = None

        if self.model == "gemma-7b":
            # not a mistake: the repo `google/gemma-7b-it` is gated and requires
            # authentication, so use the mistral-7b tokenizer as a fallback
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_MAP["mistral-7b"])
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)

    def parse_line(self, line):
        line = line.decode("utf-8")
        line = re.sub(r"data:\s*", "", line)
        data = json.loads(line)
        try:
            content = data["token"]["text"]
        except KeyError:
            # malformed chunk: log it and emit nothing for this line
            content = ""
            logger.err(data)
        return content

    def count_tokens(self, text):
        tokens = self.tokenizer.encode(text)
        token_count = len(tokens)
        logger.note(f"Prompt Token Count: {token_count}")
        return token_count

    def chat_response(
        self,
        prompt: str = None,
        temperature: float = 0.5,
        top_p: float = 0.95,
        max_new_tokens: int = None,
        api_key: str = None,
        use_cache: bool = False,
    ):
        # https://huggingface.co/docs/api-inference/detailed_parameters?code=curl
        # curl --proxy http://<server>:<port> https://api-inference.huggingface.co/models/<org>/<model_name> -X POST -d '{"inputs":"who are you?","parameters":{"max_new_tokens":64}}' -H 'Content-Type: application/json' -H 'Authorization: Bearer <HF_TOKEN>'
        self.request_url = (
            f"https://api-inference.huggingface.co/models/{self.model_fullname}"
        )
        self.request_headers = {
            "Content-Type": "application/json",
        }

        if api_key:
            logger.note(
                f"Using API Key: {api_key[:3]}{(len(api_key)-7)*'*'}{api_key[-4:]}"
            )
            self.request_headers["Authorization"] = f"Bearer {api_key}"

        if temperature is None or temperature < 0:
            temperature = 0.0
        # temperature must be in (0, 1) for HF LLM models
        temperature = max(temperature, 0.01)
        temperature = min(temperature, 0.99)
        top_p = max(top_p, 0.01)
        top_p = min(top_p, 0.99)

        token_limit = int(
            TOKEN_LIMIT_MAP[self.model] - TOKEN_RESERVED - self.count_tokens(prompt)
        )
        if token_limit <= 0:
            raise ValueError("Prompt exceeded token limit!")

        if max_new_tokens is None or max_new_tokens <= 0:
            max_new_tokens = token_limit
        else:
            max_new_tokens = min(max_new_tokens, token_limit)

        # References:
        #   huggingface_hub/inference/_client.py:
        #     class InferenceClient > def text_generation()
        #   huggingface_hub/inference/_text_generation.py:
        #     class TextGenerationRequest > param `stream`
        #   https://huggingface.co/docs/text-generation-inference/conceptual/streaming#streaming-with-curl
        #   https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
        self.request_body = {
            "inputs": prompt,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_tokens,
                "return_full_text": False,
            },
            "options": {
                "use_cache": use_cache,
            },
            "stream": True,
        }

        if self.model in STOP_SEQUENCES_MAP:
            self.stop_sequences = STOP_SEQUENCES_MAP[self.model]
            # self.request_body["parameters"]["stop_sequences"] = [
            #     self.stop_sequences
            # ]

        logger.back(self.request_url)
        stream_response = requests.post(
            self.request_url,
            headers=self.request_headers,
            json=self.request_body,
            proxies=PROXIES,
            stream=True,
        )
        status_code = stream_response.status_code
        if status_code == 200:
            logger.success(status_code)
        else:
            logger.err(status_code)

        return stream_response

    def chat_return_dict(self, stream_response):
        # https://platform.openai.com/docs/guides/text-generation/chat-completions-response-format
        final_output = self.message_outputer.default_data.copy()
        final_output["choices"] = [
            {
                "index": 0,
                "finish_reason": "stop",
                "message": {
                    "role": "assistant",
                    "content": "",
                },
            }
        ]
        logger.back(final_output)

        final_content = ""
        for line in stream_response.iter_lines():
            if not line:
                continue
            content = self.parse_line(line)

            if content.strip() == self.stop_sequences:
                logger.success("\n[Finished]")
                break
            else:
                logger.back(content, end="")
                final_content += content

        if self.model in STOP_SEQUENCES_MAP:
            final_content = final_content.replace(self.stop_sequences, "")

        final_content = final_content.strip()
        final_output["choices"][0]["message"]["content"] = final_content
        return final_output

    def chat_return_generator(self, stream_response):
        is_finished = False
        line_count = 0
        for line in stream_response.iter_lines():
            if line:
                line_count += 1
            else:
                continue

            content = self.parse_line(line)

            if content.strip() == self.stop_sequences:
                content_type = "Finished"
                logger.success("\n[Finished]")
                is_finished = True
            else:
                content_type = "Completions"
                if line_count == 1:
                    content = content.lstrip()
                logger.back(content, end="")

            output = self.message_outputer.output(
                content=content, content_type=content_type
            )
            yield output

        if not is_finished:
            yield self.message_outputer.output(content="", content_type="Finished")
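
For context, a minimal usage sketch of the streamer added above. It assumes "mistral-7b" is a valid key in MODEL_MAP and that the Hugging Face API token is read from a hypothetical HF_TOKEN environment variable; it is illustrative only, not part of this commit.

# Usage sketch (illustrative only).
# Assumes "mistral-7b" exists in MODEL_MAP and HF_TOKEN is a hypothetical env var.
import os

from networks.huggingface_streamer import HuggingfaceStreamer

streamer = HuggingfaceStreamer(model="mistral-7b")

# Blocking call: collect the whole reply as an OpenAI-style response dict
stream_response = streamer.chat_response(
    prompt="Who are you?",
    temperature=0.5,
    max_new_tokens=64,
    api_key=os.getenv("HF_TOKEN"),
)
result = streamer.chat_return_dict(stream_response)
print(result["choices"][0]["message"]["content"])

# Streaming call: yield OpenAI-style chunks as they arrive
for chunk in streamer.chat_return_generator(
    streamer.chat_response(prompt="Who are you?", api_key=os.getenv("HF_TOKEN"))
):
    print(chunk)  # forward each chunk to the client, e.g. over SSE

The two return helpers mirror OpenAI's chat-completions response and streaming-chunk formats (via OpenaiStreamOutputer), which is what lets chat_api expose these Hugging Face models behind an OpenAI-compatible endpoint.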