Hansimov committed
Commit bf8c5bd
1 Parent(s): 3f608c6

:recycle: [Refactor] Move STOP_SEQUENCES_MAP and TOKEN_LIMIT_MAP to constants

Files changed (2)
  1. constants/models.py +19 -0
  2. networks/message_streamer.py +10 -22
constants/models.py CHANGED
@@ -6,3 +6,22 @@ MODEL_MAP = {
     "gemma-7b": "google/gemma-7b-it",
     "default": "mistralai/Mixtral-8x7B-Instruct-v0.1",
 }
+
+
+STOP_SEQUENCES_MAP = {
+    "mixtral-8x7b": "</s>",
+    "nous-mixtral-8x7b": "<|im_end|>",
+    "mistral-7b": "</s>",
+    "openchat-3.5": "<|end_of_turn|>",
+    "gemma-7b": "<eos>",
+}
+
+TOKEN_LIMIT_MAP = {
+    "mixtral-8x7b": 32768,
+    "nous-mixtral-8x7b": 32768,
+    "mistral-7b": 32768,
+    "openchat-3.5": 8192,
+    "gemma-7b": 8192,
+}
+
+TOKEN_RESERVED = 20
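
As a quick sanity check (not part of this commit), the new module-level constants can be exercised directly; remaining_tokens below is a hypothetical helper that mirrors the token-budget computation in MessageStreamer:

    from constants.models import STOP_SEQUENCES_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED

    def remaining_tokens(model: str, prompt_tokens: int) -> int:
        # Tokens left for generation after reserving TOKEN_RESERVED for overhead.
        return TOKEN_LIMIT_MAP[model] - TOKEN_RESERVED - prompt_tokens

    print(remaining_tokens("mixtral-8x7b", 1000))  # 31748
    print(STOP_SEQUENCES_MAP["gemma-7b"])          # <eos>
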
networks/message_streamer.py CHANGED
@@ -5,28 +5,18 @@ import requests
 from tiktoken import get_encoding as tiktoken_get_encoding
 from transformers import AutoTokenizer
 
+from constants.models import (
+    MODEL_MAP,
+    STOP_SEQUENCES_MAP,
+    TOKEN_LIMIT_MAP,
+    TOKEN_RESERVED,
+)
 from messagers.message_outputer import OpenaiStreamOutputer
-from constants.models import MODEL_MAP
 from utils.logger import logger
 from utils.enver import enver
 
 
 class MessageStreamer:
-    STOP_SEQUENCES_MAP = {
-        "mixtral-8x7b": "</s>",
-        "mistral-7b": "</s>",
-        "nous-mixtral-8x7b": "<|im_end|>",
-        "openchat-3.5": "<|end_of_turn|>",
-        "gemma-7b": "<eos>",
-    }
-    TOKEN_LIMIT_MAP = {
-        "mixtral-8x7b": 32768,
-        "mistral-7b": 32768,
-        "nous-mixtral-8x7b": 32768,
-        "openchat-3.5": 8192,
-        "gemma-7b": 8192,
-    }
-    TOKEN_RESERVED = 20
 
     def __init__(self, model: str):
         if model in MODEL_MAP.keys():
@@ -92,9 +82,7 @@ class MessageStreamer:
         top_p = min(top_p, 0.99)
 
         token_limit = int(
-            self.TOKEN_LIMIT_MAP[self.model]
-            - self.TOKEN_RESERVED
-            - self.count_tokens(prompt)
+            TOKEN_LIMIT_MAP[self.model] - TOKEN_RESERVED - self.count_tokens(prompt)
         )
         if token_limit <= 0:
             raise ValueError("Prompt exceeded token limit!")
@@ -125,8 +113,8 @@
             "stream": True,
         }
 
-        if self.model in self.STOP_SEQUENCES_MAP.keys():
-            self.stop_sequences = self.STOP_SEQUENCES_MAP[self.model]
+        if self.model in STOP_SEQUENCES_MAP.keys():
+            self.stop_sequences = STOP_SEQUENCES_MAP[self.model]
         # self.request_body["parameters"]["stop_sequences"] = [
         #     self.STOP_SEQUENCES[self.model]
         # ]
@@ -176,7 +164,7 @@ class MessageStreamer:
             logger.back(content, end="")
             final_content += content
 
-        if self.model in self.STOP_SEQUENCES_MAP.keys():
+        if self.model in STOP_SEQUENCES_MAP.keys():
             final_content = final_content.replace(self.stop_sequences, "")
 
         final_content = final_content.strip()
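
For reference, a minimal sketch of the post-refactor stop-sequence cleanup (strip_stop_sequence is a hypothetical helper, not code from this commit; it assumes only the STOP_SEQUENCES_MAP added above):

    from constants.models import STOP_SEQUENCES_MAP

    def strip_stop_sequence(model: str, text: str) -> str:
        # Drop the model-specific stop sequence, if one is defined for this model.
        stop = STOP_SEQUENCES_MAP.get(model)
        return (text.replace(stop, "") if stop else text).strip()

    print(strip_stop_sequence("openchat-3.5", "Hello!<|end_of_turn|>"))  # Hello!
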