# ChatHaruhi/Qwen118k2GPT.py
import torch
from .BaseLLM import BaseLLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel  # kept for the commented-out LoRA merge below
# Module-level cache so repeated constructions reuse the same weights.
tokenizer_qwen = None
model_qwen = None
def initialize_Qwen2LORA(model):
    """Load and cache the Qwen model and tokenizer; the LoRA merge below is currently disabled."""
    global model_qwen, tokenizer_qwen

    if model_qwen is None:
        model_qwen = AutoModelForCausalLM.from_pretrained(
            model,
            device_map="auto",
            trust_remote_code=True
        )
        model_qwen = model_qwen.eval()
        # model_qwen = PeftModel.from_pretrained(
        #     model_qwen,
        #     "silk-road/Chat-Haruhi-Fusion_B"
        # )

    if tokenizer_qwen is None:
        tokenizer_qwen = AutoTokenizer.from_pretrained(
            model,
            # use_fast=True,
            trust_remote_code=True
        )

    return model_qwen, tokenizer_qwen
def Qwen_tokenizer(text):
    """Return the token count of `text` using the cached Qwen tokenizer."""
    return len(tokenizer_qwen.encode(text))
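
# Illustrative only: Qwen_tokenizer relies on the module-level cache, so one
# of the loaders must run first. The repo id below is taken from the
# commented-out LoRA block above and is just an example argument:
#   initialize_Qwen2LORA("silk-road/Chat-Haruhi-Fusion_B")
#   n_tokens = Qwen_tokenizer("你好，世界")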
class Qwen118k2GPT(BaseLLM):
    def __init__(self, model):
        super(Qwen118k2GPT, self).__init__()
        global model_qwen, tokenizer_qwen
        if model == "Qwen/Qwen-1_8B-Chat":
            tokenizer_qwen = AutoTokenizer.from_pretrained(
                "Qwen/Qwen-1_8B-Chat",
                trust_remote_code=True
            )
            model_qwen = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen-1_8B-Chat",
                device_map="auto",
                trust_remote_code=True
            ).eval()
            self.model = model_qwen
            self.tokenizer = tokenizer_qwen
        elif "silk-road/" in model:
            # silk-road checkpoints go through the cached loader above.
            self.model, self.tokenizer = initialize_Qwen2LORA(model)
        else:
            raise Exception("Unknown Qwen model: " + model)
        self.messages = ""
    def initialize_message(self):
        self.messages = ""

    def ai_message(self, payload):
        # Append the new turn to the running prompt, labeled with its role.
        self.messages = self.messages + "\nAI: " + payload

    def system_message(self, payload):
        self.messages = self.messages + "\nSYSTEM PROMPT: " + payload

    def user_message(self, payload):
        self.messages = self.messages + "\nUser: " + payload
    def get_response(self):
        with torch.no_grad():
            response, history = self.model.chat(self.tokenizer, self.messages, history=[])
        # print(response)
        return response
    def print_prompt(self):
        print(type(self.messages))
        print(self.messages)
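
# Minimal usage sketch, assuming the "Qwen/Qwen-1_8B-Chat" weights can be
# downloaded and fit on the available device; the prompt text is illustrative.
if __name__ == "__main__":
    llm = Qwen118k2GPT("Qwen/Qwen-1_8B-Chat")
    llm.initialize_message()
    llm.system_message("You are Haruhi Suzumiya; stay in character.")
    llm.user_message("Please introduce yourself.")
    llm.print_prompt()
    print(llm.get_response())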