Hansimov committed on
Commit 8df3985
1 Parent(s): 0d8e943

:gem: [Feature] Modularize TokenChecker, and fix gated model repos with alternatives

messagers/token_checker.py ADDED
@@ -0,0 +1,44 @@
+ from tclogger import logger
+ from transformers import AutoTokenizer
+
+ from constants.models import MODEL_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED
+
+
+ class TokenChecker:
+     def __init__(self, input_str: str, model: str):
+         self.input_str = input_str
+
+         if model in MODEL_MAP.keys():
+             self.model = model
+         else:
+             self.model = "mixtral-8x7b"
+
+         self.model_fullname = MODEL_MAP[self.model]
+
+         # As some models are gated, we need to fetch tokenizers from alternatives
+         GATED_MODEL_MAP = {
+             "llama3-70b": "NousResearch/Meta-Llama-3-70B",
+             "gemma-7b": "unsloth/gemma-7b",
+             "mistral-7b": "dfurman/Mistral-7B-Instruct-v0.2",
+             "mixtral-8x7b": "dfurman/Mixtral-8x7B-Instruct-v0.1",
+         }
+         if self.model in GATED_MODEL_MAP.keys():
+             self.tokenizer = AutoTokenizer.from_pretrained(GATED_MODEL_MAP[self.model])
+         else:
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)
+
+     def count_tokens(self):
+         token_count = len(self.tokenizer.encode(self.input_str))
+         logger.note(f"Prompt Token Count: {token_count}")
+         return token_count
+
+     def get_token_limit(self):
+         return TOKEN_LIMIT_MAP[self.model]
+
+     def get_token_redundancy(self):
+         return int(self.get_token_limit() - TOKEN_RESERVED - self.count_tokens())
+
+     def check_token_limit(self):
+         if self.get_token_redundancy() <= 0:
+             raise ValueError(f"Prompt exceeded token limit: {self.get_token_limit()}")
+         return True
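For orientation, here is a minimal usage sketch of the new TokenChecker, mirroring how the streamers below call it. The prompt string is made up, and "mixtral-8x7b" is simply the fallback model key defined above; neither is part of the commit itself.

```python
# Hypothetical usage sketch (not part of this commit).
# Assumes "mixtral-8x7b" is a key in constants.models.MODEL_MAP, as in this repo.
from messagers.token_checker import TokenChecker

prompt = "Hello, what is the capital of France?"
checker = TokenChecker(input_str=prompt, model="mixtral-8x7b")

checker.check_token_limit()  # raises ValueError if the prompt exceeds the model's token limit
max_new_tokens = checker.get_token_redundancy()  # tokens left after TOKEN_RESERVED and the prompt
```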
networks/huggingchat_streamer.py CHANGED
@@ -2,59 +2,15 @@ import copy
  import json
  import re
  import requests
- import uuid
  
- # from curl_cffi import requests
  from tclogger import logger
- from transformers import AutoTokenizer
-
- from constants.models import (
-     MODEL_MAP,
-     STOP_SEQUENCES_MAP,
-     TOKEN_LIMIT_MAP,
-     TOKEN_RESERVED,
- )
+
+ from constants.models import MODEL_MAP
  from constants.envs import PROXIES
- from constants.headers import (
-     REQUESTS_HEADERS,
-     HUGGINGCHAT_POST_HEADERS,
-     HUGGINGCHAT_SETTINGS_POST_DATA,
- )
+ from constants.headers import HUGGINGCHAT_POST_HEADERS, HUGGINGCHAT_SETTINGS_POST_DATA
  from messagers.message_outputer import OpenaiStreamOutputer
  from messagers.message_composer import MessageComposer
-
-
- class TokenChecker:
-     def __init__(self, input_str: str, model: str):
-         self.input_str = input_str
-
-         if model in MODEL_MAP.keys():
-             self.model = model
-         else:
-             self.model = "mixtral-8x7b"
-
-         self.model_fullname = MODEL_MAP[self.model]
-
-         if self.model == "llama3-70b":
-             # As original llama3 repo is gated and requires auth,
-             # I use NousResearch's version as a workaround
-             self.tokenizer = AutoTokenizer.from_pretrained(
-                 "NousResearch/Meta-Llama-3-70B"
-             )
-         else:
-             self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)
-
-     def count_tokens(self):
-         token_count = len(self.tokenizer.encode(self.input_str))
-         logger.note(f"Prompt Token Count: {token_count}")
-         return token_count
-
-     def check_token_limit(self):
-         token_limit = TOKEN_LIMIT_MAP[self.model]
-         token_redundancy = int(token_limit - TOKEN_RESERVED - self.count_tokens())
-         if token_redundancy <= 0:
-             raise ValueError(f"Prompt exceeded token limit: {token_limit}")
-         return True
+ from messagers.token_checker import TokenChecker
  
  
  class HuggingchatRequester:
networks/huggingface_streamer.py CHANGED
@@ -2,18 +2,11 @@ import json
  import re
  import requests
  
-
  from tclogger import logger
- from transformers import AutoTokenizer
-
- from constants.models import (
-     MODEL_MAP,
-     STOP_SEQUENCES_MAP,
-     TOKEN_LIMIT_MAP,
-     TOKEN_RESERVED,
- )
+ from constants.models import MODEL_MAP, STOP_SEQUENCES_MAP
  from constants.envs import PROXIES
  from messagers.message_outputer import OpenaiStreamOutputer
+ from messagers.token_checker import TokenChecker
  
  
  class HuggingfaceStreamer:
@@ -25,13 +18,6 @@ class HuggingfaceStreamer:
          self.model_fullname = MODEL_MAP[self.model]
          self.message_outputer = OpenaiStreamOutputer(model=self.model)
  
-         if self.model == "gemma-7b":
-             # this is not wrong, as repo `google/gemma-7b-it` is gated and must authenticate to access it
-             # so I use mistral-7b as a fallback
-             self.tokenizer = AutoTokenizer.from_pretrained(MODEL_MAP["mistral-7b"])
-         else:
-             self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)
-
      def parse_line(self, line):
          line = line.decode("utf-8")
          line = re.sub(r"data:\s*", "", line)
@@ -42,12 +28,6 @@
          logger.err(data)
          return content
  
-     def count_tokens(self, text):
-         tokens = self.tokenizer.encode(text)
-         token_count = len(tokens)
-         logger.note(f"Prompt Token Count: {token_count}")
-         return token_count
-
      def chat_response(
          self,
          prompt: str = None,
@@ -80,16 +60,12 @@
          top_p = max(top_p, 0.01)
          top_p = min(top_p, 0.99)
  
-         token_limit = int(
-             TOKEN_LIMIT_MAP[self.model] - TOKEN_RESERVED - self.count_tokens(prompt)
-         )
-         if token_limit <= 0:
-             raise ValueError("Prompt exceeded token limit!")
+         checker = TokenChecker(input_str=prompt, model=self.model)
  
          if max_new_tokens is None or max_new_tokens <= 0:
-             max_new_tokens = token_limit
+             max_new_tokens = checker.get_token_redundancy()
          else:
-             max_new_tokens = min(max_new_tokens, token_limit)
+             max_new_tokens = min(max_new_tokens, checker.get_token_redundancy())
  
          # References:
          # huggingface_hub/inference/_client.py: