JohnSmith9982's picture
Upload 98 files
0cc999a
raw history blame
No virus
2.32 kB
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import logging
import colorama
from .base_model import BaseLLMModel
from ..presets import MODEL_METADATA
class Qwen_Client(BaseLLMModel):
    """Chat client backed by a locally loaded ``transformers`` causal LM.

    The tokenizer and model are loaded from the repository named by
    ``MODEL_METADATA[model_name]["repo_id"]`` with ``trust_remote_code=True``.
    Replies are produced by the loaded model's own ``chat`` and
    ``chat_stream`` methods.
    """

    def __init__(self, model_name, user_name="") -> None:
        """Load the tokenizer and the model for *model_name*.

        Args:
            model_name: Key into ``MODEL_METADATA``; also forwarded to the
                base class as ``model_name``.
            user_name: Forwarded to the base class as ``user``.

        Raises:
            KeyError: If *model_name* (or its ``"repo_id"`` entry) is missing
                from ``MODEL_METADATA``.
        """
        super().__init__(model_name=model_name, user=user_name)
        # Resolve the repository once; tokenizer and weights share it.
        repo_id = MODEL_METADATA[model_name]["repo_id"]
        self.tokenizer = AutoTokenizer.from_pretrained(
            repo_id,
            trust_remote_code=True,
            resume_download=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            repo_id,
            device_map="auto",
            trust_remote_code=True,
            resume_download=True,
        ).eval()

    def generation_config(self):
        """Build a ``GenerationConfig`` from the current sampling settings.

        ``max_length``, ``top_p`` and ``temperature`` are read from the
        instance each time, so a fresh config reflects their latest values;
        every other entry is a fixed constant.

        Returns:
            GenerationConfig: A newly constructed configuration object.
        """
        settings = {
            "chat_format": "chatml",
            "do_sample": True,
            "eos_token_id": 151643,
            "max_length": self.token_upper_limit,
            "max_new_tokens": 512,
            "max_window_size": 6144,
            "pad_token_id": 151643,
            "top_k": 0,
            "top_p": self.top_p,
            "transformers_version": "4.33.2",
            "trust_remote_code": True,
            "temperature": self.temperature,
        }
        return GenerationConfig.from_dict(settings)

    def _get_glm_style_input(self):
        """Split ``self.history`` into earlier exchange pairs and the last query.

        The ``"content"`` value of every history entry is collected; the last
        one becomes the query and the remaining ones are grouped two by two
        (presumably alternating user/assistant turns -- verify against the
        code that fills ``self.history``).

        Returns:
            tuple: ``(history, query)`` where ``history`` is a list of
            two-item lists and ``query`` is the final ``"content"`` value.

        Raises:
            IndexError: If ``self.history`` is empty.
            AssertionError: If the number of entries before the query is odd.
        """
        contents = [message["content"] for message in self.history]
        if not contents:
            # Same exception type as popping an empty list, but with a
            # message that says what is actually wrong.
            raise IndexError("History is empty; there is no query to answer.")
        query = contents.pop()
        # Lazy %-formatting: the list is only rendered when DEBUG is enabled.
        logging.debug(
            "%s%s%s", colorama.Fore.YELLOW, contents, colorama.Fore.RESET
        )
        # Raised explicitly (instead of via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept unchanged.
        if len(contents) % 2 != 0:
            raise AssertionError(
                f"History should be even length. current history is: {contents}"
            )
        pairs = [
            [first, second]
            for first, second in zip(contents[::2], contents[1::2])
        ]
        return pairs, query

    def get_answer_at_once(self):
        """Generate the complete reply to the latest query in one call.

        Returns:
            tuple: ``(response, len(response))``.
        """
        history, query = self._get_glm_style_input()
        self.model.generation_config = self.generation_config()
        # ``chat`` also returns an updated history, which is not needed here.
        response, _ = self.model.chat(self.tokenizer, query, history=history)
        # NOTE(review): the second value is the character length of the
        # reply, not a tokenizer-based count -- confirm what callers expect.
        return response, len(response)

    def get_answer_stream_iter(self):
        """Yield each value produced by the model's ``chat_stream``.

        Yields:
            Whatever ``self.model.chat_stream`` yields for the latest query,
            unchanged.
        """
        history, query = self._get_glm_style_input()
        self.model.generation_config = self.generation_config()
        yield from self.model.chat_stream(self.tokenizer, query, history)