	:recycle: [Refactor] Move STOP_SEQUENCES_MAP and TOKEN_LIMIT_MAP to constants

Files changed:
- constants/models.py (+19 -0)
- networks/message_streamer.py (+10 -22)

constants/models.py CHANGED
@@ -6,3 +6,22 @@ MODEL_MAP = {
     "gemma-7b": "google/gemma-7b-it",
     "default": "mistralai/Mixtral-8x7B-Instruct-v0.1",
 }
+
+
+STOP_SEQUENCES_MAP = {
+    "mixtral-8x7b": "</s>",
+    "nous-mixtral-8x7b": "<|im_end|>",
+    "mistral-7b": "</s>",
+    "openchat-3.5": "<|end_of_turn|>",
+    "gemma-7b": "<eos>",
+}
+
+TOKEN_LIMIT_MAP = {
+    "mixtral-8x7b": 32768,
+    "nous-mixtral-8x7b": 32768,
+    "mistral-7b": 32768,
+    "openchat-3.5": 8192,
+    "gemma-7b": 8192,
+}
+
+TOKEN_RESERVED = 20
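For reference, a minimal sketch of how the relocated constants can be consumed from any module after this change (the key "openchat-3.5" is just one entry from the maps above; the variable names are illustrative, not part of the commit):

    from constants.models import STOP_SEQUENCES_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED

    model = "openchat-3.5"
    stop_sequence = STOP_SEQUENCES_MAP[model]               # "<|end_of_turn|>"
    token_budget = TOKEN_LIMIT_MAP[model] - TOKEN_RESERVED  # 8192 - 20 = 8172
    print(model, repr(stop_sequence), token_budget)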
    	
networks/message_streamer.py CHANGED
@@ -5,28 +5,18 @@ import requests
 from tiktoken import get_encoding as tiktoken_get_encoding
 from transformers import AutoTokenizer
 
+from constants.models import (
+    MODEL_MAP,
+    STOP_SEQUENCES_MAP,
+    TOKEN_LIMIT_MAP,
+    TOKEN_RESERVED,
+)
 from messagers.message_outputer import OpenaiStreamOutputer
-from constants.models import MODEL_MAP
 from utils.logger import logger
 from utils.enver import enver
 
 
 class MessageStreamer:
-    STOP_SEQUENCES_MAP = {
-        "mixtral-8x7b": "</s>",
-        "mistral-7b": "</s>",
-        "nous-mixtral-8x7b": "<|im_end|>",
-        "openchat-3.5": "<|end_of_turn|>",
-        "gemma-7b": "<eos>",
-    }
-    TOKEN_LIMIT_MAP = {
-        "mixtral-8x7b": 32768,
-        "mistral-7b": 32768,
-        "nous-mixtral-8x7b": 32768,
-        "openchat-3.5": 8192,
-        "gemma-7b": 8192,
-    }
-    TOKEN_RESERVED = 20
 
     def __init__(self, model: str):
         if model in MODEL_MAP.keys():
@@ -92,9 +82,7 @@ class MessageStreamer:
         top_p = min(top_p, 0.99)
 
         token_limit = int(
-            self.TOKEN_LIMIT_MAP[self.model]
-            - self.TOKEN_RESERVED
-            - self.count_tokens(prompt)
+            TOKEN_LIMIT_MAP[self.model] - TOKEN_RESERVED - self.count_tokens(prompt)
         )
         if token_limit <= 0:
             raise ValueError("Prompt exceeded token limit!")
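As a worked example of the token_limit arithmetic above (the numbers are illustrative; 500 stands in for whatever self.count_tokens(prompt) returns):

    # For "gemma-7b": TOKEN_LIMIT_MAP["gemma-7b"] == 8192, TOKEN_RESERVED == 20.
    # With a prompt that tokenizes to 500 tokens:
    token_limit = int(8192 - 20 - 500)  # 7672 tokens left for the completion
    assert token_limit == 7672  # a non-positive result raises "Prompt exceeded token limit!"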
@@ -125,8 +113,8 @@ class MessageStreamer:
             "stream": True,
         }
 
-        if self.model in self.STOP_SEQUENCES_MAP.keys():
-            self.stop_sequences = self.STOP_SEQUENCES_MAP[self.model]
+        if self.model in STOP_SEQUENCES_MAP.keys():
+            self.stop_sequences = STOP_SEQUENCES_MAP[self.model]
         #     self.request_body["parameters"]["stop_sequences"] = [
         #         self.STOP_SEQUENCES[self.model]
         #     ]
@@ -176,7 +164,7 @@ class MessageStreamer:
                 logger.back(content, end="")
                 final_content += content
 
-        if self.model in self.STOP_SEQUENCES_MAP.keys():
+        if self.model in STOP_SEQUENCES_MAP.keys():
             final_content = final_content.replace(self.stop_sequences, "")
 
         final_content = final_content.strip()
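A small sketch of what the stop-sequence cleanup in the last hunk does (the stop sequence is taken from STOP_SEQUENCES_MAP; the sample string is made up):

    stop_sequence = "<|end_of_turn|>"  # STOP_SEQUENCES_MAP["openchat-3.5"]
    final_content = "Hello there!<|end_of_turn|>"
    final_content = final_content.replace(stop_sequence, "").strip()
    assert final_content == "Hello there!"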