Spaces:
Running
Running
:zap: [Enhance] Auto calculate max_tokens if not set
Browse files
- apis/chat_api.py +3 -3
- networks/message_streamer.py +28 -1
- requirements.txt +1 -0
apis/chat_api.py
CHANGED
@@ -56,7 +56,7 @@ class ChatAPIApp:
|
|
56 |
if api_key.startswith("hf_"):
|
57 |
return api_key
|
58 |
else:
|
59 |
-
logger.warn(f"Invalid HF Token")
|
60 |
else:
|
61 |
logger.warn("Not provide HF Token!")
|
62 |
return None
|
@@ -71,11 +71,11 @@ class ChatAPIApp:
|
|
71 |
description="(list) Messages",
|
72 |
)
|
73 |
temperature: float = Field(
|
74 |
-
default=0
|
75 |
description="(float) Temperature",
|
76 |
)
|
77 |
max_tokens: int = Field(
|
78 |
-
default
|
79 |
description="(int) Max tokens",
|
80 |
)
|
81 |
stream: bool = Field(
|
|
|
56 |
if api_key.startswith("hf_"):
|
57 |
return api_key
|
58 |
else:
|
59 |
+
logger.warn(f"Invalid HF Token!")
|
60 |
else:
|
61 |
logger.warn("Not provide HF Token!")
|
62 |
return None
|
|
|
71 |
description="(list) Messages",
|
72 |
)
|
73 |
temperature: float = Field(
|
74 |
+
default=0,
|
75 |
description="(float) Temperature",
|
76 |
)
|
77 |
max_tokens: int = Field(
|
78 |
+
default=-1,
|
79 |
description="(int) Max tokens",
|
80 |
)
|
81 |
stream: bool = Field(
|
networks/message_streamer.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import json
|
2 |
import re
|
3 |
import requests
|
|
|
4 |
from messagers.message_outputer import OpenaiStreamOutputer
|
5 |
from utils.logger import logger
|
6 |
from utils.enver import enver
|
@@ -22,6 +23,12 @@ class MessageStreamer:
|
|
22 |
"mistral-7b": "</s>",
|
23 |
"openchat-3.5": "<|end_of_turn|>",
|
24 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
def __init__(self, model: str):
|
27 |
if model in self.MODEL_MAP.keys():
|
@@ -30,6 +37,7 @@ class MessageStreamer:
|
|
30 |
self.model = "default"
|
31 |
self.model_fullname = self.MODEL_MAP[self.model]
|
32 |
self.message_outputer = OpenaiStreamOutputer()
|
|
|
33 |
|
34 |
def parse_line(self, line):
|
35 |
line = line.decode("utf-8")
|
@@ -38,11 +46,17 @@ class MessageStreamer:
|
|
38 |
content = data["token"]["text"]
|
39 |
return content
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
def chat_response(
|
42 |
self,
|
43 |
prompt: str = None,
|
44 |
temperature: float = 0.01,
|
45 |
-
max_new_tokens: int =
|
46 |
api_key: str = None,
|
47 |
):
|
48 |
# https://huggingface.co/docs/api-inference/detailed_parameters?code=curl
|
@@ -60,6 +74,19 @@ class MessageStreamer:
|
|
60 |
)
|
61 |
self.request_headers["Authorization"] = f"Bearer {api_key}"
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
# References:
|
64 |
# huggingface_hub/inference/_client.py:
|
65 |
# class InferenceClient > def text_generation()
|
|
|
1 |
import json
|
2 |
import re
|
3 |
import requests
|
4 |
+
from tiktoken import get_encoding as tiktoken_get_encoding
|
5 |
from messagers.message_outputer import OpenaiStreamOutputer
|
6 |
from utils.logger import logger
|
7 |
from utils.enver import enver
|
|
|
23 |
"mistral-7b": "</s>",
|
24 |
"openchat-3.5": "<|end_of_turn|>",
|
25 |
}
|
26 |
+
TOKEN_LIMIT_MAP = {
|
27 |
+
"mixtral-8x7b": 32768,
|
28 |
+
"mistral-7b": 32768,
|
29 |
+
"openchat-3.5": 8192,
|
30 |
+
}
|
31 |
+
TOKEN_RESERVED = 32
|
32 |
|
33 |
def __init__(self, model: str):
|
34 |
if model in self.MODEL_MAP.keys():
|
|
|
37 |
self.model = "default"
|
38 |
self.model_fullname = self.MODEL_MAP[self.model]
|
39 |
self.message_outputer = OpenaiStreamOutputer()
|
40 |
+
self.tokenizer = tiktoken_get_encoding("cl100k_base")
|
41 |
|
42 |
def parse_line(self, line):
|
43 |
line = line.decode("utf-8")
|
|
|
46 |
content = data["token"]["text"]
|
47 |
return content
|
48 |
|
49 |
+
def count_tokens(self, text):
|
50 |
+
tokens = self.tokenizer.encode(text)
|
51 |
+
token_count = len(tokens)
|
52 |
+
logger.note(f"Prompt Token Count: {token_count}")
|
53 |
+
return token_count
|
54 |
+
|
55 |
def chat_response(
|
56 |
self,
|
57 |
prompt: str = None,
|
58 |
temperature: float = 0.01,
|
59 |
+
max_new_tokens: int = None,
|
60 |
api_key: str = None,
|
61 |
):
|
62 |
# https://huggingface.co/docs/api-inference/detailed_parameters?code=curl
|
|
|
74 |
)
|
75 |
self.request_headers["Authorization"] = f"Bearer {api_key}"
|
76 |
|
77 |
+
token_limit = (
|
78 |
+
self.TOKEN_LIMIT_MAP[self.model]
|
79 |
+
- self.TOKEN_RESERVED
|
80 |
+
- self.count_tokens(prompt)
|
81 |
+
)
|
82 |
+
if token_limit <= 0:
|
83 |
+
raise ValueError("Prompt exceeded token limit!")
|
84 |
+
|
85 |
+
if max_new_tokens is None or max_new_tokens <= 0:
|
86 |
+
max_new_tokens = token_limit
|
87 |
+
else:
|
88 |
+
max_new_tokens = min(max_new_tokens, token_limit)
|
89 |
+
|
90 |
# References:
|
91 |
# huggingface_hub/inference/_client.py:
|
92 |
# class InferenceClient > def text_generation()
|
requirements.txt
CHANGED
@@ -6,5 +6,6 @@ pydantic
|
|
6 |
requests
|
7 |
sse_starlette
|
8 |
termcolor
|
|
|
9 |
uvicorn
|
10 |
websockets
|
|
|
6 |
requests
|
7 |
sse_starlette
|
8 |
termcolor
|
9 |
+
tiktoken
|
10 |
uvicorn
|
11 |
websockets
|