# --- Hugging Face file-viewer residue (not part of the module) ---
# silk-road's picture
# Update ChatHaruhi/Qwen118k2GPT.py
# fd1d127
# raw
# history blame contribute delete
# No virus
# 2.52 kB
import torch
from .BaseLLM import BaseLLM
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
# Module-wide cache for the loaded Qwen tokenizer and model.  Both start out
# empty and are assigned by the loading code further down in this module.
tokenizer_qwen = model_qwen = None
def initialize_Qwen2LORA(model):
    """Return the shared Qwen model and tokenizer, loading them when needed.

    The loaded objects are kept in the module-level ``model_qwen`` and
    ``tokenizer_qwen`` globals so later calls can reuse them.  A cached object
    is reused only if it does not report having been loaded from a different
    source than ``model``; otherwise it is replaced by a fresh load.  Without
    that check a request for another checkpoint would silently receive
    whatever happened to be cached first.

    Args:
        model: Hub repository id or local path handed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    global model_qwen, tokenizer_qwen

    requested = str(model)

    def _loaded_from_other_source(cached):
        # transformers models and tokenizers expose the identifier they were
        # loaded from as ``name_or_path`` (see library docs).  Objects that do
        # not carry a usable value are treated as matching, so an externally
        # injected model/tokenizer is never thrown away.
        origin = getattr(cached, "name_or_path", None)
        return bool(origin) and str(origin) != requested

    # NOTE: despite the function name, no PEFT/LoRA adapter is attached here;
    # the checkpoint is loaded exactly as published.
    if model_qwen is None or _loaded_from_other_source(model_qwen):
        model_qwen = AutoModelForCausalLM.from_pretrained(
            model,
            device_map="auto",
            trust_remote_code=True,
        ).eval()  # inference mode

    if tokenizer_qwen is None or _loaded_from_other_source(tokenizer_qwen):
        tokenizer_qwen = AutoTokenizer.from_pretrained(
            model,
            trust_remote_code=True,
        )

    return model_qwen, tokenizer_qwen
def Qwen_tokenizer(text):
    """Return the number of tokens the shared Qwen tokenizer yields for ``text``.

    Uses the module-level ``tokenizer_qwen``, which must already hold a loaded
    tokenizer when this function is called.
    """
    token_ids = tokenizer_qwen.encode(text)
    return len(token_ids)
class Qwen118k2GPT(BaseLLM):
    """Chat wrapper that builds one plain-text prompt and sends it to Qwen.

    Messages are accumulated in ``self.messages`` as a single string, one
    role-labelled entry per call, and ``get_response`` passes that string to
    the model's ``chat`` method with an empty history.
    """

    def __init__(self, model):
        """Load (or reuse) the model and tokenizer for ``model``.

        Args:
            model: ``"Qwen/Qwen-1_8B-Chat"`` or an identifier containing
                ``"silk-road/"``.

        Raises:
            ValueError: if ``model`` is neither of the supported forms.
        """
        super().__init__()
        global model_qwen, tokenizer_qwen
        if model == "Qwen/Qwen-1_8B-Chat":
            # The stock chat checkpoint is loaded afresh on every construction
            # and replaces whatever the module-level cache currently holds.
            tokenizer_qwen = AutoTokenizer.from_pretrained(
                "Qwen/Qwen-1_8B-Chat",
                trust_remote_code=True,
            )
            model_qwen = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen-1_8B-Chat",
                device_map="auto",
                trust_remote_code=True,
            ).eval()
            self.model = model_qwen
            self.tokenizer = tokenizer_qwen
        elif "silk-road/" in model:
            self.model, self.tokenizer = initialize_Qwen2LORA(model)
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Unknown Qwen model")
        self.messages = ""

    def initialize_message(self):
        """Discard the accumulated prompt and start a new one."""
        self.messages = ""

    def _append_message(self, role_label, payload):
        """Append ``payload`` to the prompt, prefixed by its own role label.

        The label must sit directly in front of the payload it describes.
        Prepending it to the whole accumulated prompt (and appending the
        payload unlabelled) stacks every label at the front, e.g.
        ``"User: SYSTEM PROMPT: \\n sys\\n hi"``.
        """
        entry = role_label + payload
        if self.messages:
            self.messages = self.messages + "\n " + entry
        else:
            # First entry: no separator needed in front of it.
            self.messages = entry

    def ai_message(self, payload):
        """Append an assistant turn to the prompt."""
        self._append_message("AI: ", payload)

    def system_message(self, payload):
        """Append a system-prompt entry to the prompt."""
        self._append_message("SYSTEM PROMPT: ", payload)

    def user_message(self, payload):
        """Append a user turn to the prompt."""
        self._append_message("User: ", payload)

    def get_response(self):
        """Send the accumulated prompt to the model and return its reply."""
        with torch.no_grad():
            # The whole conversation lives in the prompt string, so the
            # model-side history is always passed empty; the returned history
            # is not needed.
            response, _ = self.model.chat(self.tokenizer, self.messages, history=[])
        return response

    def print_prompt(self):
        """Print the type and the content of the accumulated prompt."""
        print(type(self.messages))
        print(self.messages)